I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,22 @@
from .plot import plot_backends
from .plot_implicit import plot_implicit
from .textplot import textplot
from .pygletplot import PygletPlot
from .plot import PlotGrid
from .plot import (plot, plot_parametric, plot3d, plot3d_parametric_surface,
plot3d_parametric_line, plot_contour)
__all__ = [
'plot_backends',
'plot_implicit',
'textplot',
'PygletPlot',
'PlotGrid',
'plot', 'plot_parametric', 'plot3d', 'plot3d_parametric_surface',
'plot3d_parametric_line', 'plot_contour'
]

View File

@ -0,0 +1,419 @@
from sympy.plotting.series import BaseSeries, GenericDataSeries
from sympy.utilities.exceptions import sympy_deprecation_warning
from sympy.utilities.iterables import is_sequence
__doctest_requires__ = {
('Plot.append', 'Plot.extend'): ['matplotlib'],
}
# Global variable
# Set to False when running tests / doctests so that the plots don't show.
_show = True
def unset_show():
"""
Disable show(). For use in the tests.
"""
global _show
_show = False
def _deprecation_msg_m_a_r_f(attr):
sympy_deprecation_warning(
f"The `{attr}` property is deprecated. The `{attr}` keyword "
"argument should be passed to a plotting function, which generates "
"the appropriate data series. If needed, index the plot object to "
"retrieve a specific data series.",
deprecated_since_version="1.13",
active_deprecations_target="deprecated-markers-annotations-fill-rectangles",
stacklevel=4)
def _create_generic_data_series(**kwargs):
keywords = ["annotations", "markers", "fill", "rectangles"]
series = []
for kw in keywords:
dictionaries = kwargs.pop(kw, [])
if dictionaries is None:
dictionaries = []
if isinstance(dictionaries, dict):
dictionaries = [dictionaries]
for d in dictionaries:
args = d.pop("args", [])
series.append(GenericDataSeries(kw, *args, **d))
return series
class Plot:
"""Base class for all backends. A backend represents the plotting library,
which implements the necessary functionalities in order to use SymPy
plotting functions.
For interactive work the function :func:`plot()` is better suited.
This class permits the plotting of SymPy expressions using numerous
backends (:external:mod:`matplotlib`, textplot, the old pyglet module for SymPy, Google
charts api, etc).
The figure can contain an arbitrary number of plots of SymPy expressions,
lists of coordinates of points, etc. Plot has a private attribute _series that
contains all data series to be plotted (expressions for lines or surfaces,
lists of points, etc (all subclasses of BaseSeries)). Those data series are
instances of classes not imported by ``from sympy import *``.
The customization of the figure is on two levels. Global options that
concern the figure as a whole (e.g. title, xlabel, scale, etc) and
per-data series options (e.g. name) and aesthetics (e.g. color, point shape,
line type, etc.).
The difference between options and aesthetics is that an aesthetic can be
a function of the coordinates (or parameters in a parametric plot). The
supported values for an aesthetic are:
- None (the backend uses default values)
- a constant
- a function of one variable (the first coordinate or parameter)
- a function of two variables (the first and second coordinate or parameters)
- a function of three variables (only in nonparametric 3D plots)
Their implementation depends on the backend so they may not work in some
backends.
If the plot is parametric and the arity of the aesthetic function permits
it the aesthetic is calculated over parameters and not over coordinates.
If the arity does not permit calculation over parameters the calculation is
done over coordinates.
Only cartesian coordinates are supported for the moment, but you can use
the parametric plots to plot in polar, spherical and cylindrical
coordinates.
The arguments for the constructor Plot must be subclasses of BaseSeries.
Any global option can be specified as a keyword argument.
The global options for a figure are:
- title : str
- xlabel : str or Symbol
- ylabel : str or Symbol
- zlabel : str or Symbol
- legend : bool
- xscale : {'linear', 'log'}
- yscale : {'linear', 'log'}
- axis : bool
- axis_center : tuple of two floats or {'center', 'auto'}
- xlim : tuple of two floats
- ylim : tuple of two floats
- aspect_ratio : tuple of two floats or {'auto'}
- autoscale : bool
- margin : float in [0, 1]
- backend : {'default', 'matplotlib', 'text'} or a subclass of BaseBackend
- size : optional tuple of two floats, (width, height); default: None
The per data series options and aesthetics are:
There are none in the base series. See below for options for subclasses.
Some data series support additional aesthetics or options:
:class:`~.LineOver1DRangeSeries`, :class:`~.Parametric2DLineSeries`, and
:class:`~.Parametric3DLineSeries` support the following:
Aesthetics:
- line_color : string, or float, or function, optional
Specifies the color for the plot, which depends on the backend being
used.
For example, if ``MatplotlibBackend`` is being used, then
Matplotlib string colors are acceptable (``"red"``, ``"r"``,
``"cyan"``, ``"c"``, ...).
Alternatively, we can use a float number, 0 < color < 1, wrapped in a
string (for example, ``line_color="0.5"``) to specify grayscale colors.
Alternatively, We can specify a function returning a single
float value: this will be used to apply a color-loop (for example,
``line_color=lambda x: math.cos(x)``).
Note that by setting line_color, it would be applied simultaneously
to all the series.
Options:
- label : str
- steps : bool
- integers_only : bool
:class:`~.SurfaceOver2DRangeSeries` and :class:`~.ParametricSurfaceSeries`
support the following:
Aesthetics:
- surface_color : function which returns a float.
Notes
=====
How the plotting module works:
1. Whenever a plotting function is called, the provided expressions are
processed and a list of instances of the
:class:`~sympy.plotting.series.BaseSeries` class is created, containing
the necessary information to plot the expressions
(e.g. the expression, ranges, series name, ...). Eventually, these
objects will generate the numerical data to be plotted.
2. A subclass of :class:`~.Plot` class is instantiaed (referred to as
backend, from now on), which stores the list of series and the main
attributes of the plot (e.g. axis labels, title, ...).
The backend implements the logic to generate the actual figure with
some plotting library.
3. When the ``show`` command is executed, series are processed one by one
to generate numerical data and add it to the figure. The backend is also
going to set the axis labels, title, ..., according to the values stored
in the Plot instance.
The backend should check if it supports the data series that it is given
(e.g. :class:`TextBackend` supports only
:class:`~sympy.plotting.series.LineOver1DRangeSeries`).
It is the backend responsibility to know how to use the class of data series
that it's given. Note that the current implementation of the ``*Series``
classes is "matplotlib-centric": the numerical data returned by the
``get_points`` and ``get_meshes`` methods is meant to be used directly by
Matplotlib. Therefore, the new backend will have to pre-process the
numerical data to make it compatible with the chosen plotting library.
Keep in mind that future SymPy versions may improve the ``*Series`` classes
in order to return numerical data "non-matplotlib-centric", hence if you code
a new backend you have the responsibility to check if its working on each
SymPy release.
Please explore the :class:`MatplotlibBackend` source code to understand
how a backend should be coded.
In order to be used by SymPy plotting functions, a backend must implement
the following methods:
* show(self): used to loop over the data series, generate the numerical
data, plot it and set the axis labels, title, ...
* save(self, path): used to save the current plot to the specified file
path.
* close(self): used to close the current plot backend (note: some plotting
library does not support this functionality. In that case, just raise a
warning).
"""
def __init__(self, *args,
title=None, xlabel=None, ylabel=None, zlabel=None, aspect_ratio='auto',
xlim=None, ylim=None, axis_center='auto', axis=True,
xscale='linear', yscale='linear', legend=False, autoscale=True,
margin=0, annotations=None, markers=None, rectangles=None,
fill=None, backend='default', size=None, **kwargs):
# Options for the graph as a whole.
# The possible values for each option are described in the docstring of
# Plot. They are based purely on convention, no checking is done.
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.zlabel = zlabel
self.aspect_ratio = aspect_ratio
self.axis_center = axis_center
self.axis = axis
self.xscale = xscale
self.yscale = yscale
self.legend = legend
self.autoscale = autoscale
self.margin = margin
self._annotations = annotations
self._markers = markers
self._rectangles = rectangles
self._fill = fill
# Contains the data objects to be plotted. The backend should be smart
# enough to iterate over this list.
self._series = []
self._series.extend(args)
self._series.extend(_create_generic_data_series(
annotations=annotations, markers=markers, rectangles=rectangles,
fill=fill))
is_real = \
lambda lim: all(getattr(i, 'is_real', True) for i in lim)
is_finite = \
lambda lim: all(getattr(i, 'is_finite', True) for i in lim)
# reduce code repetition
def check_and_set(t_name, t):
if t:
if not is_real(t):
raise ValueError(
"All numbers from {}={} must be real".format(t_name, t))
if not is_finite(t):
raise ValueError(
"All numbers from {}={} must be finite".format(t_name, t))
setattr(self, t_name, (float(t[0]), float(t[1])))
self.xlim = None
check_and_set("xlim", xlim)
self.ylim = None
check_and_set("ylim", ylim)
self.size = None
check_and_set("size", size)
@property
def _backend(self):
return self
@property
def backend(self):
return type(self)
def __str__(self):
series_strs = [('[%d]: ' % i) + str(s)
for i, s in enumerate(self._series)]
return 'Plot object containing:\n' + '\n'.join(series_strs)
def __getitem__(self, index):
return self._series[index]
def __setitem__(self, index, *args):
if len(args) == 1 and isinstance(args[0], BaseSeries):
self._series[index] = args
def __delitem__(self, index):
del self._series[index]
def append(self, arg):
"""Adds an element from a plot's series to an existing plot.
Examples
========
Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the
second plot's first series object to the first, use the
``append`` method, like so:
.. plot::
:format: doctest
:include-source: True
>>> from sympy import symbols
>>> from sympy.plotting import plot
>>> x = symbols('x')
>>> p1 = plot(x*x, show=False)
>>> p2 = plot(x, show=False)
>>> p1.append(p2[0])
>>> p1
Plot object containing:
[0]: cartesian line: x**2 for x over (-10.0, 10.0)
[1]: cartesian line: x for x over (-10.0, 10.0)
>>> p1.show()
See Also
========
extend
"""
if isinstance(arg, BaseSeries):
self._series.append(arg)
else:
raise TypeError('Must specify element of plot to append.')
def extend(self, arg):
"""Adds all series from another plot.
Examples
========
Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the
second plot to the first, use the ``extend`` method, like so:
.. plot::
:format: doctest
:include-source: True
>>> from sympy import symbols
>>> from sympy.plotting import plot
>>> x = symbols('x')
>>> p1 = plot(x**2, show=False)
>>> p2 = plot(x, -x, show=False)
>>> p1.extend(p2)
>>> p1
Plot object containing:
[0]: cartesian line: x**2 for x over (-10.0, 10.0)
[1]: cartesian line: x for x over (-10.0, 10.0)
[2]: cartesian line: -x for x over (-10.0, 10.0)
>>> p1.show()
"""
if isinstance(arg, Plot):
self._series.extend(arg._series)
elif is_sequence(arg):
self._series.extend(arg)
else:
raise TypeError('Expecting Plot or sequence of BaseSeries')
def show(self):
raise NotImplementedError
def save(self, path):
raise NotImplementedError
def close(self):
raise NotImplementedError
# deprecations
@property
def markers(self):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("markers")
return self._markers
@markers.setter
def markers(self, v):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("markers")
self._series.extend(_create_generic_data_series(markers=v))
self._markers = v
@property
def annotations(self):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("annotations")
return self._annotations
@annotations.setter
def annotations(self, v):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("annotations")
self._series.extend(_create_generic_data_series(annotations=v))
self._annotations = v
@property
def rectangles(self):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("rectangles")
return self._rectangles
@rectangles.setter
def rectangles(self, v):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("rectangles")
self._series.extend(_create_generic_data_series(rectangles=v))
self._rectangles = v
@property
def fill(self):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("fill")
return self._fill
@fill.setter
def fill(self, v):
""".. deprecated:: 1.13"""
_deprecation_msg_m_a_r_f("fill")
self._series.extend(_create_generic_data_series(fill=v))
self._fill = v

View File

@ -0,0 +1,5 @@
from sympy.plotting.backends.matplotlibbackend.matplotlib import (
MatplotlibBackend, _matplotlib_list
)
__all__ = ["MatplotlibBackend", "_matplotlib_list"]

View File

@ -0,0 +1,318 @@
from collections.abc import Callable
from sympy.core.basic import Basic
from sympy.external import import_module
import sympy.plotting.backends.base_backend as base_backend
from sympy.printing.latex import latex
# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py`
def _str_or_latex(label):
if isinstance(label, Basic):
return latex(label, mode='inline')
return str(label)
def _matplotlib_list(interval_list):
"""
Returns lists for matplotlib ``fill`` command from a list of bounding
rectangular intervals
"""
xlist = []
ylist = []
if len(interval_list):
for intervals in interval_list:
intervalx = intervals[0]
intervaly = intervals[1]
xlist.extend([intervalx.start, intervalx.start,
intervalx.end, intervalx.end, None])
ylist.extend([intervaly.start, intervaly.end,
intervaly.end, intervaly.start, None])
else:
#XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
xlist.extend((None, None, None, None))
ylist.extend((None, None, None, None))
return xlist, ylist
# Don't have to check for the success of importing matplotlib in each case;
# we will only be using this backend if we can successfully import matploblib
class MatplotlibBackend(base_backend.Plot):
""" This class implements the functionalities to use Matplotlib with SymPy
plotting functions.
"""
def __init__(self, *series, **kwargs):
super().__init__(*series, **kwargs)
self.matplotlib = import_module('matplotlib',
import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
min_module_version='1.1.0', catch=(RuntimeError,))
self.plt = self.matplotlib.pyplot
self.cm = self.matplotlib.cm
self.LineCollection = self.matplotlib.collections.LineCollection
self.aspect = kwargs.get('aspect_ratio', 'auto')
if self.aspect != 'auto':
self.aspect = float(self.aspect[1]) / self.aspect[0]
# PlotGrid can provide its figure and axes to be populated with
# the data from the series.
self._plotgrid_fig = kwargs.pop("fig", None)
self._plotgrid_ax = kwargs.pop("ax", None)
def _create_figure(self):
def set_spines(ax):
ax.spines['left'].set_position('zero')
ax.spines['right'].set_color('none')
ax.spines['bottom'].set_position('zero')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
if self._plotgrid_fig is not None:
self.fig = self._plotgrid_fig
self.ax = self._plotgrid_ax
if not any(s.is_3D for s in self._series):
set_spines(self.ax)
else:
self.fig = self.plt.figure(figsize=self.size)
if any(s.is_3D for s in self._series):
self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
else:
self.ax = self.fig.add_subplot(1, 1, 1)
set_spines(self.ax)
@staticmethod
def get_segments(x, y, z=None):
""" Convert two list of coordinates to a list of segments to be used
with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.
Parameters
==========
x : list
List of x-coordinates
y : list
List of y-coordinates
z : list
List of z-coordinates for a 3D line.
"""
np = import_module('numpy')
if z is not None:
dim = 3
points = (x, y, z)
else:
dim = 2
points = (x, y)
points = np.ma.array(points).T.reshape(-1, 1, dim)
return np.ma.concatenate([points[:-1], points[1:]], axis=1)
def _process_series(self, series, ax):
np = import_module('numpy')
mpl_toolkits = import_module(
'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})
# XXX Workaround for matplotlib issue
# https://github.com/matplotlib/matplotlib/issues/17130
xlims, ylims, zlims = [], [], []
for s in series:
# Create the collections
if s.is_2Dline:
if s.is_parametric:
x, y, param = s.get_data()
else:
x, y = s.get_data()
if (isinstance(s.line_color, (int, float)) or
callable(s.line_color)):
segments = self.get_segments(x, y)
collection = self.LineCollection(segments)
collection.set_array(s.get_color_array())
ax.add_collection(collection)
else:
lbl = _str_or_latex(s.label)
line, = ax.plot(x, y, label=lbl, color=s.line_color)
elif s.is_contour:
ax.contour(*s.get_data())
elif s.is_3Dline:
x, y, z, param = s.get_data()
if (isinstance(s.line_color, (int, float)) or
callable(s.line_color)):
art3d = mpl_toolkits.mplot3d.art3d
segments = self.get_segments(x, y, z)
collection = art3d.Line3DCollection(segments)
collection.set_array(s.get_color_array())
ax.add_collection(collection)
else:
lbl = _str_or_latex(s.label)
ax.plot(x, y, z, label=lbl, color=s.line_color)
xlims.append(s._xlim)
ylims.append(s._ylim)
zlims.append(s._zlim)
elif s.is_3Dsurface:
if s.is_parametric:
x, y, z, u, v = s.get_data()
else:
x, y, z = s.get_data()
collection = ax.plot_surface(x, y, z,
cmap=getattr(self.cm, 'viridis', self.cm.jet),
rstride=1, cstride=1, linewidth=0.1)
if isinstance(s.surface_color, (float, int, Callable)):
color_array = s.get_color_array()
color_array = color_array.reshape(color_array.size)
collection.set_array(color_array)
else:
collection.set_color(s.surface_color)
xlims.append(s._xlim)
ylims.append(s._ylim)
zlims.append(s._zlim)
elif s.is_implicit:
points = s.get_data()
if len(points) == 2:
# interval math plotting
x, y = _matplotlib_list(points[0])
ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
else:
# use contourf or contour depending on whether it is
# an inequality or equality.
# XXX: ``contour`` plots multiple lines. Should be fixed.
ListedColormap = self.matplotlib.colors.ListedColormap
colormap = ListedColormap(["white", s.line_color])
xarray, yarray, zarray, plot_type = points
if plot_type == 'contour':
ax.contour(xarray, yarray, zarray, cmap=colormap)
else:
ax.contourf(xarray, yarray, zarray, cmap=colormap)
elif s.is_generic:
if s.type == "markers":
# s.rendering_kw["color"] = s.line_color
ax.plot(*s.args, **s.rendering_kw)
elif s.type == "annotations":
ax.annotate(*s.args, **s.rendering_kw)
elif s.type == "fill":
# s.rendering_kw["color"] = s.line_color
ax.fill_between(*s.args, **s.rendering_kw)
elif s.type == "rectangles":
# s.rendering_kw["color"] = s.line_color
ax.add_patch(
self.matplotlib.patches.Rectangle(
*s.args, **s.rendering_kw))
else:
raise NotImplementedError(
'{} is not supported in the SymPy plotting module '
'with matplotlib backend. Please report this issue.'
.format(ax))
Axes3D = mpl_toolkits.mplot3d.Axes3D
if not isinstance(ax, Axes3D):
ax.autoscale_view(
scalex=ax.get_autoscalex_on(),
scaley=ax.get_autoscaley_on())
else:
# XXX Workaround for matplotlib issue
# https://github.com/matplotlib/matplotlib/issues/17130
if xlims:
xlims = np.array(xlims)
xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1]))
ax.set_xlim(xlim)
else:
ax.set_xlim([0, 1])
if ylims:
ylims = np.array(ylims)
ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1]))
ax.set_ylim(ylim)
else:
ax.set_ylim([0, 1])
if zlims:
zlims = np.array(zlims)
zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1]))
ax.set_zlim(zlim)
else:
ax.set_zlim([0, 1])
# Set global options.
# TODO The 3D stuff
# XXX The order of those is important.
if self.xscale and not isinstance(ax, Axes3D):
ax.set_xscale(self.xscale)
if self.yscale and not isinstance(ax, Axes3D):
ax.set_yscale(self.yscale)
if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check
ax.set_autoscale_on(self.autoscale)
if self.axis_center:
val = self.axis_center
if isinstance(ax, Axes3D):
pass
elif val == 'center':
ax.spines['left'].set_position('center')
ax.spines['bottom'].set_position('center')
elif val == 'auto':
xl, xh = ax.get_xlim()
yl, yh = ax.get_ylim()
pos_left = ('data', 0) if xl*xh <= 0 else 'center'
pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
ax.spines['left'].set_position(pos_left)
ax.spines['bottom'].set_position(pos_bottom)
else:
ax.spines['left'].set_position(('data', val[0]))
ax.spines['bottom'].set_position(('data', val[1]))
if not self.axis:
ax.set_axis_off()
if self.legend:
if ax.legend():
ax.legend_.set_visible(self.legend)
if self.margin:
ax.set_xmargin(self.margin)
ax.set_ymargin(self.margin)
if self.title:
ax.set_title(self.title)
if self.xlabel:
xlbl = _str_or_latex(self.xlabel)
ax.set_xlabel(xlbl, position=(1, 0))
if self.ylabel:
ylbl = _str_or_latex(self.ylabel)
ax.set_ylabel(ylbl, position=(0, 1))
if isinstance(ax, Axes3D) and self.zlabel:
zlbl = _str_or_latex(self.zlabel)
ax.set_zlabel(zlbl, position=(0, 1))
# xlim and ylim should always be set at last so that plot limits
# doesn't get altered during the process.
if self.xlim:
ax.set_xlim(self.xlim)
if self.ylim:
ax.set_ylim(self.ylim)
self.ax.set_aspect(self.aspect)
def process_series(self):
"""
Iterates over every ``Plot`` object and further calls
_process_series()
"""
self._create_figure()
self._process_series(self._series, self.ax)
def show(self):
self.process_series()
#TODO after fixing https://github.com/ipython/ipython/issues/1255
# you can uncomment the next line and remove the pyplot.show() call
#self.fig.show()
if base_backend._show:
self.fig.tight_layout()
self.plt.show()
else:
self.close()
def save(self, path):
self.process_series()
self.fig.savefig(path)
def close(self):
self.plt.close(self.fig)

View File

@ -0,0 +1,3 @@
from sympy.plotting.backends.textbackend.text import TextBackend
__all__ = ["TextBackend"]

View File

@ -0,0 +1,24 @@
import sympy.plotting.backends.base_backend as base_backend
from sympy.plotting.series import LineOver1DRangeSeries
from sympy.plotting.textplot import textplot
class TextBackend(base_backend.Plot):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def show(self):
if not base_backend._show:
return
if len(self._series) != 1:
raise ValueError(
'The TextBackend supports only one graph per Plot.')
elif not isinstance(self._series[0], LineOver1DRangeSeries):
raise ValueError(
'The TextBackend supports only expressions over a 1D range')
else:
ser = self._series[0]
textplot(ser.expr, ser.start, ser.end)
def close(self):
pass

View File

@ -0,0 +1,641 @@
""" rewrite of lambdify - This stuff is not stable at all.
It is for internal use in the new plotting module.
It may (will! see the Q'n'A in the source) be rewritten.
It's completely self contained. Especially it does not use lambdarepr.
It does not aim to replace the current lambdify. Most importantly it will never
ever support anything else than SymPy expressions (no Matrices, dictionaries
and so on).
"""
import re
from sympy.core.numbers import (I, NumberSymbol, oo, zoo)
from sympy.core.symbol import Symbol
from sympy.utilities.iterables import numbered_symbols
# We parse the expression string into a tree that identifies functions. Then
# we translate the names of the functions and we translate also some strings
# that are not names of functions (all this according to translation
# dictionaries).
# If the translation goes to another module (like numpy) the
# module is imported and 'func' is translated to 'module.func'.
# If a function can not be translated, the inner nodes of that part of the
# tree are not translated. So if we have Integral(sqrt(x)), sqrt is not
# translated to np.sqrt and the Integral does not crash.
# A namespace for all this is generated by crawling the (func, args) tree of
# the expression. The creation of this namespace involves many ugly
# workarounds.
# The namespace consists of all the names needed for the SymPy expression and
# all the name of modules used for translation. Those modules are imported only
# as a name (import numpy as np) in order to keep the namespace small and
# manageable.
# Please, if there is a bug, do not try to fix it here! Rewrite this by using
# the method proposed in the last Q'n'A below. That way the new function will
# work just as well, be just as simple, but it wont need any new workarounds.
# If you insist on fixing it here, look at the workarounds in the function
# sympy_expression_namespace and in lambdify.
# Q: Why are you not using Python abstract syntax tree?
# A: Because it is more complicated and not much more powerful in this case.
# Q: What if I have Symbol('sin') or g=Function('f')?
# A: You will break the algorithm. We should use srepr to defend against this?
# The problem with Symbol('sin') is that it will be printed as 'sin'. The
# parser will distinguish it from the function 'sin' because functions are
# detected thanks to the opening parenthesis, but the lambda expression won't
# understand the difference if we have also the sin function.
# The solution (complicated) is to use srepr and maybe ast.
# The problem with the g=Function('f') is that it will be printed as 'f' but in
# the global namespace we have only 'g'. But as the same printer is used in the
# constructor of the namespace there will be no problem.
# Q: What if some of the printers are not printing as expected?
# A: The algorithm wont work. You must use srepr for those cases. But even
# srepr may not print well. All problems with printers should be considered
# bugs.
# Q: What about _imp_ functions?
# A: Those are taken care for by evalf. A special case treatment will work
# faster but it's not worth the code complexity.
# Q: Will ast fix all possible problems?
# A: No. You will always have to use some printer. Even srepr may not work in
# some cases. But if the printer does not work, that should be considered a
# bug.
# Q: Is there same way to fix all possible problems?
# A: Probably by constructing our strings ourself by traversing the (func,
# args) tree and creating the namespace at the same time. That actually sounds
# good.
from sympy.external import import_module
import warnings
#TODO debugging output
class vectorized_lambdify:
""" Return a sufficiently smart, vectorized and lambdified function.
Returns only reals.
Explanation
===========
This function uses experimental_lambdify to created a lambdified
expression ready to be used with numpy. Many of the functions in SymPy
are not implemented in numpy so in some cases we resort to Python cmath or
even to evalf.
The following translations are tried:
only numpy complex
- on errors raised by SymPy trying to work with ndarray:
only Python cmath and then vectorize complex128
When using Python cmath there is no need for evalf or float/complex
because Python cmath calls those.
This function never tries to mix numpy directly with evalf because numpy
does not understand SymPy Float. If this is needed one can use the
float_wrap_evalf/complex_wrap_evalf options of experimental_lambdify or
better one can be explicit about the dtypes that numpy works with.
Check numpy bug http://projects.scipy.org/numpy/ticket/1013 to know what
types of errors to expect.
"""
def __init__(self, args, expr):
self.args = args
self.expr = expr
self.np = import_module('numpy')
self.lambda_func_1 = experimental_lambdify(
args, expr, use_np=True)
self.vector_func_1 = self.lambda_func_1
self.lambda_func_2 = experimental_lambdify(
args, expr, use_python_cmath=True)
self.vector_func_2 = self.np.vectorize(
self.lambda_func_2, otypes=[complex])
self.vector_func = self.vector_func_1
self.failure = False
def __call__(self, *args):
np = self.np
try:
temp_args = (np.array(a, dtype=complex) for a in args)
results = self.vector_func(*temp_args)
results = np.ma.masked_where(
np.abs(results.imag) > 1e-7 * np.abs(results),
results.real, copy=False)
return results
except ValueError:
if self.failure:
raise
self.failure = True
self.vector_func = self.vector_func_2
warnings.warn(
'The evaluation of the expression is problematic. '
'We are trying a failback method that may still work. '
'Please report this as a bug.')
return self.__call__(*args)
class lambdify:
"""Returns the lambdified function.
Explanation
===========
This function uses experimental_lambdify to create a lambdified
expression. It uses cmath to lambdify the expression. If the function
is not implemented in Python cmath, Python cmath calls evalf on those
functions.
"""
def __init__(self, args, expr):
self.args = args
self.expr = expr
self.lambda_func_1 = experimental_lambdify(
args, expr, use_python_cmath=True, use_evalf=True)
self.lambda_func_2 = experimental_lambdify(
args, expr, use_python_math=True, use_evalf=True)
self.lambda_func_3 = experimental_lambdify(
args, expr, use_evalf=True, complex_wrap_evalf=True)
self.lambda_func = self.lambda_func_1
self.failure = False
def __call__(self, args):
try:
#The result can be sympy.Float. Hence wrap it with complex type.
result = complex(self.lambda_func(args))
if abs(result.imag) > 1e-7 * abs(result):
return None
return result.real
except (ZeroDivisionError, OverflowError):
return None
except TypeError as e:
if self.failure:
raise e
if self.lambda_func == self.lambda_func_1:
self.lambda_func = self.lambda_func_2
return self.__call__(args)
self.failure = True
self.lambda_func = self.lambda_func_3
warnings.warn(
'The evaluation of the expression is problematic. '
'We are trying a failback method that may still work. '
'Please report this as a bug.', stacklevel=2)
return self.__call__(args)
def experimental_lambdify(*args, **kwargs):
l = Lambdifier(*args, **kwargs)
return l
class Lambdifier:
def __init__(self, args, expr, print_lambda=False, use_evalf=False,
float_wrap_evalf=False, complex_wrap_evalf=False,
use_np=False, use_python_math=False, use_python_cmath=False,
use_interval=False):
self.print_lambda = print_lambda
self.use_evalf = use_evalf
self.float_wrap_evalf = float_wrap_evalf
self.complex_wrap_evalf = complex_wrap_evalf
self.use_np = use_np
self.use_python_math = use_python_math
self.use_python_cmath = use_python_cmath
self.use_interval = use_interval
# Constructing the argument string
# - check
if not all(isinstance(a, Symbol) for a in args):
raise ValueError('The arguments must be Symbols.')
# - use numbered symbols
syms = numbered_symbols(exclude=expr.free_symbols)
newargs = [next(syms) for _ in args]
expr = expr.xreplace(dict(zip(args, newargs)))
argstr = ', '.join([str(a) for a in newargs])
del syms, newargs, args
# Constructing the translation dictionaries and making the translation
self.dict_str = self.get_dict_str()
self.dict_fun = self.get_dict_fun()
exprstr = str(expr)
newexpr = self.tree2str_translate(self.str2tree(exprstr))
# Constructing the namespaces
namespace = {}
namespace.update(self.sympy_atoms_namespace(expr))
namespace.update(self.sympy_expression_namespace(expr))
# XXX Workaround
# Ugly workaround because Pow(a,Half) prints as sqrt(a)
# and sympy_expression_namespace can not catch it.
from sympy.functions.elementary.miscellaneous import sqrt
namespace.update({'sqrt': sqrt})
namespace.update({'Eq': lambda x, y: x == y})
namespace.update({'Ne': lambda x, y: x != y})
# End workaround.
if use_python_math:
namespace.update({'math': __import__('math')})
if use_python_cmath:
namespace.update({'cmath': __import__('cmath')})
if use_np:
try:
namespace.update({'np': __import__('numpy')})
except ImportError:
raise ImportError(
'experimental_lambdify failed to import numpy.')
if use_interval:
namespace.update({'imath': __import__(
'sympy.plotting.intervalmath', fromlist=['intervalmath'])})
namespace.update({'math': __import__('math')})
# Construct the lambda
if self.print_lambda:
print(newexpr)
eval_str = 'lambda %s : ( %s )' % (argstr, newexpr)
self.eval_str = eval_str
exec("MYNEWLAMBDA = %s" % eval_str, namespace)
self.lambda_func = namespace['MYNEWLAMBDA']
def __call__(self, *args, **kwargs):
return self.lambda_func(*args, **kwargs)
##############################################################################
# Dicts for translating from SymPy to other modules
##############################################################################
###
# builtins
###
# Functions with different names in builtins
builtin_functions_different = {
'Min': 'min',
'Max': 'max',
'Abs': 'abs',
}
# Strings that should be translated
builtin_not_functions = {
'I': '1j',
# 'oo': '1e400',
}
###
# numpy
###
# Functions that are the same in numpy
numpy_functions_same = [
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'exp', 'log',
'sqrt', 'floor', 'conjugate', 'sign',
]
# Functions with different names in numpy
numpy_functions_different = {
"acos": "arccos",
"acosh": "arccosh",
"arg": "angle",
"asin": "arcsin",
"asinh": "arcsinh",
"atan": "arctan",
"atan2": "arctan2",
"atanh": "arctanh",
"ceiling": "ceil",
"im": "imag",
"ln": "log",
"Max": "amax",
"Min": "amin",
"re": "real",
"Abs": "abs",
}
# Strings that should be translated
numpy_not_functions = {
'pi': 'np.pi',
'oo': 'np.inf',
'E': 'np.e',
}
###
# Python math
###
# Functions that are the same in math
math_functions_same = [
'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
'exp', 'log', 'erf', 'sqrt', 'floor', 'factorial', 'gamma',
]
# Functions with different names in math
math_functions_different = {
'ceiling': 'ceil',
'ln': 'log',
'loggamma': 'lgamma'
}
# Strings that should be translated
math_not_functions = {
'pi': 'math.pi',
'E': 'math.e',
}
###
# Python cmath
###
# Functions that are the same in cmath
cmath_functions_same = [
'sin', 'cos', 'tan', 'asin', 'acos', 'atan',
'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
'exp', 'log', 'sqrt',
]
# Functions with different names in cmath
cmath_functions_different = {
'ln': 'log',
'arg': 'phase',
}
# Strings that should be translated
cmath_not_functions = {
'pi': 'cmath.pi',
'E': 'cmath.e',
}
###
# intervalmath
###
interval_not_functions = {
'pi': 'math.pi',
'E': 'math.e'
}
interval_functions_same = [
'sin', 'cos', 'exp', 'tan', 'atan', 'log',
'sqrt', 'cosh', 'sinh', 'tanh', 'floor',
'acos', 'asin', 'acosh', 'asinh', 'atanh',
'Abs', 'And', 'Or'
]
interval_functions_different = {
'Min': 'imin',
'Max': 'imax',
'ceiling': 'ceil',
}
###
# mpmath, etc
###
#TODO
###
# Create the final ordered tuples of dictionaries
###
# For strings
def get_dict_str(self):
dict_str = dict(self.builtin_not_functions)
if self.use_np:
dict_str.update(self.numpy_not_functions)
if self.use_python_math:
dict_str.update(self.math_not_functions)
if self.use_python_cmath:
dict_str.update(self.cmath_not_functions)
if self.use_interval:
dict_str.update(self.interval_not_functions)
return dict_str
# For functions
def get_dict_fun(self):
dict_fun = dict(self.builtin_functions_different)
if self.use_np:
for s in self.numpy_functions_same:
dict_fun[s] = 'np.' + s
for k, v in self.numpy_functions_different.items():
dict_fun[k] = 'np.' + v
if self.use_python_math:
for s in self.math_functions_same:
dict_fun[s] = 'math.' + s
for k, v in self.math_functions_different.items():
dict_fun[k] = 'math.' + v
if self.use_python_cmath:
for s in self.cmath_functions_same:
dict_fun[s] = 'cmath.' + s
for k, v in self.cmath_functions_different.items():
dict_fun[k] = 'cmath.' + v
if self.use_interval:
for s in self.interval_functions_same:
dict_fun[s] = 'imath.' + s
for k, v in self.interval_functions_different.items():
dict_fun[k] = 'imath.' + v
return dict_fun
##############################################################################
# The translator functions, tree parsers, etc.
##############################################################################
def str2tree(self, exprstr):
"""Converts an expression string to a tree.
Explanation
===========
Functions are represented by ('func_name(', tree_of_arguments).
Other expressions are (head_string, mid_tree, tail_str).
Expressions that do not contain functions are directly returned.
Examples
========
>>> from sympy.abc import x, y, z
>>> from sympy import Integral, sin
>>> from sympy.plotting.experimental_lambdify import Lambdifier
>>> str2tree = Lambdifier([x], x).str2tree
>>> str2tree(str(Integral(x, (x, 1, y))))
('', ('Integral(', 'x, (x, 1, y)'), ')')
>>> str2tree(str(x+y))
'x + y'
>>> str2tree(str(x+y*sin(z)+1))
('x + y*', ('sin(', 'z'), ') + 1')
>>> str2tree('sin(y*(y + 1.1) + (sin(y)))')
('', ('sin(', ('y*(y + 1.1) + (', ('sin(', 'y'), '))')), ')')
"""
#matches the first 'function_name('
first_par = re.search(r'(\w+\()', exprstr)
if first_par is None:
return exprstr
else:
start = first_par.start()
end = first_par.end()
head = exprstr[:start]
func = exprstr[start:end]
tail = exprstr[end:]
count = 0
for i, c in enumerate(tail):
if c == '(':
count += 1
elif c == ')':
count -= 1
if count == -1:
break
func_tail = self.str2tree(tail[:i])
tail = self.str2tree(tail[i:])
return (head, (func, func_tail), tail)
@classmethod
def tree2str(cls, tree):
"""Converts a tree to string without translations.
Examples
========
>>> from sympy.abc import x, y, z
>>> from sympy import sin
>>> from sympy.plotting.experimental_lambdify import Lambdifier
>>> str2tree = Lambdifier([x], x).str2tree
>>> tree2str = Lambdifier([x], x).tree2str
>>> tree2str(str2tree(str(x+y*sin(z)+1)))
'x + y*sin(z) + 1'
"""
if isinstance(tree, str):
return tree
else:
return ''.join(map(cls.tree2str, tree))
def tree2str_translate(self, tree):
"""Converts a tree to string with translations.
Explanation
===========
Function names are translated by translate_func.
Other strings are translated by translate_str.
"""
if isinstance(tree, str):
return self.translate_str(tree)
elif isinstance(tree, tuple) and len(tree) == 2:
return self.translate_func(tree[0][:-1], tree[1])
else:
return ''.join([self.tree2str_translate(t) for t in tree])
def translate_str(self, estr):
"""Translate substrings of estr using in order the dictionaries in
dict_tuple_str."""
for pattern, repl in self.dict_str.items():
estr = re.sub(pattern, repl, estr)
return estr
def translate_func(self, func_name, argtree):
"""Translate function names and the tree of arguments.
Explanation
===========
If the function name is not in the dictionaries of dict_tuple_fun then the
function is surrounded by a float((...).evalf()).
The use of float is necessary as np.<function>(sympy.Float(..)) raises an
error."""
if func_name in self.dict_fun:
new_name = self.dict_fun[func_name]
argstr = self.tree2str_translate(argtree)
return new_name + '(' + argstr
elif func_name in ['Eq', 'Ne']:
op = {'Eq': '==', 'Ne': '!='}
return "(lambda x, y: x {} y)({}".format(op[func_name], self.tree2str_translate(argtree))
else:
template = '(%s(%s)).evalf(' if self.use_evalf else '%s(%s'
if self.float_wrap_evalf:
template = 'float(%s)' % template
elif self.complex_wrap_evalf:
template = 'complex(%s)' % template
# Wrapping should only happen on the outermost expression, which
# is the only thing we know will be a number.
float_wrap_evalf = self.float_wrap_evalf
complex_wrap_evalf = self.complex_wrap_evalf
self.float_wrap_evalf = False
self.complex_wrap_evalf = False
ret = template % (func_name, self.tree2str_translate(argtree))
self.float_wrap_evalf = float_wrap_evalf
self.complex_wrap_evalf = complex_wrap_evalf
return ret
##############################################################################
# The namespace constructors
##############################################################################
@classmethod
def sympy_expression_namespace(cls, expr):
"""Traverses the (func, args) tree of an expression and creates a SymPy
namespace. All other modules are imported only as a module name. That way
the namespace is not polluted and rests quite small. It probably causes much
more variable lookups and so it takes more time, but there are no tests on
that for the moment."""
if expr is None:
return {}
else:
funcname = str(expr.func)
# XXX Workaround
# Here we add an ugly workaround because str(func(x))
# is not always the same as str(func). Eg
# >>> str(Integral(x))
# "Integral(x)"
# >>> str(Integral)
# "<class 'sympy.integrals.integrals.Integral'>"
# >>> str(sqrt(x))
# "sqrt(x)"
# >>> str(sqrt)
# "<function sqrt at 0x3d92de8>"
# >>> str(sin(x))
# "sin(x)"
# >>> str(sin)
# "sin"
# Either one of those can be used but not all at the same time.
# The code considers the sin example as the right one.
regexlist = [
r'<class \'sympy[\w.]*?.([\w]*)\'>$',
# the example Integral
r'<function ([\w]*) at 0x[\w]*>$', # the example sqrt
]
for r in regexlist:
m = re.match(r, funcname)
if m is not None:
funcname = m.groups()[0]
# End of the workaround
# XXX debug: print funcname
args_dict = {}
for a in expr.args:
if (isinstance(a, (Symbol, NumberSymbol)) or a in [I, zoo, oo]):
continue
else:
args_dict.update(cls.sympy_expression_namespace(a))
args_dict.update({funcname: expr.func})
return args_dict
@staticmethod
def sympy_atoms_namespace(expr):
"""For no real reason this function is separated from
sympy_expression_namespace. It can be moved to it."""
atoms = expr.atoms(Symbol, NumberSymbol, I, zoo, oo)
d = {}
for a in atoms:
# XXX debug: print 'atom:' + str(a)
d[str(a)] = a
return d

View File

@ -0,0 +1,12 @@
from .interval_arithmetic import interval
from .lib_interval import (Abs, exp, log, log10, sin, cos, tan, sqrt,
imin, imax, sinh, cosh, tanh, acosh, asinh, atanh,
asin, acos, atan, ceil, floor, And, Or)
__all__ = [
'interval',
'Abs', 'exp', 'log', 'log10', 'sin', 'cos', 'tan', 'sqrt', 'imin', 'imax',
'sinh', 'cosh', 'tanh', 'acosh', 'asinh', 'atanh', 'asin', 'acos', 'atan',
'ceil', 'floor', 'And', 'Or',
]

View File

@ -0,0 +1,413 @@
"""
Interval Arithmetic for plotting.
This module does not implement interval arithmetic accurately and
hence cannot be used for purposes other than plotting. If you want
to use interval arithmetic, use mpmath's interval arithmetic.
The module implements interval arithmetic using numpy and
python floating points. The rounding up and down is not handled
and hence this is not an accurate implementation of interval
arithmetic.
The module uses numpy for speed which cannot be achieved with mpmath.
"""
# Q: Why use numpy? Why not simply use mpmath's interval arithmetic?
# A: mpmath's interval arithmetic simulates a floating point unit
# and hence is slow, while numpy evaluations are orders of magnitude
# faster.
# Q: Why create a separate class for intervals? Why not use SymPy's
# Interval Sets?
# A: The functionalities that will be required for plotting is quite
# different from what Interval Sets implement.
# Q: Why is rounding up and down according to IEEE754 not handled?
# A: It is not possible to do it in both numpy and python. An external
# library has to used, which defeats the whole purpose i.e., speed. Also
# rounding is handled for very few functions in those libraries.
# Q Will my plots be affected?
# A It will not affect most of the plots. The interval arithmetic
# module based suffers the same problems as that of floating point
# arithmetic.
from sympy.core.numbers import int_valued
from sympy.core.logic import fuzzy_and
from sympy.simplify.simplify import nsimplify
from .interval_membership import intervalMembership
class interval:
""" Represents an interval containing floating points as start and
end of the interval
The is_valid variable tracks whether the interval obtained as the
result of the function is in the domain and is continuous.
- True: Represents the interval result of a function is continuous and
in the domain of the function.
- False: The interval argument of the function was not in the domain of
the function, hence the is_valid of the result interval is False
- None: The function was not continuous over the interval or
the function's argument interval is partly in the domain of the
function
A comparison between an interval and a real number, or a
comparison between two intervals may return ``intervalMembership``
of two 3-valued logic values.
"""
def __init__(self, *args, is_valid=True, **kwargs):
self.is_valid = is_valid
if len(args) == 1:
if isinstance(args[0], interval):
self.start, self.end = args[0].start, args[0].end
else:
self.start = float(args[0])
self.end = float(args[0])
elif len(args) == 2:
if args[0] < args[1]:
self.start = float(args[0])
self.end = float(args[1])
else:
self.start = float(args[1])
self.end = float(args[0])
else:
raise ValueError("interval takes a maximum of two float values "
"as arguments")
@property
def mid(self):
return (self.start + self.end) / 2.0
@property
def width(self):
return self.end - self.start
def __repr__(self):
return "interval(%f, %f)" % (self.start, self.end)
def __str__(self):
return "[%f, %f]" % (self.start, self.end)
def __lt__(self, other):
if isinstance(other, (int, float)):
if self.end < other:
return intervalMembership(True, self.is_valid)
elif self.start > other:
return intervalMembership(False, self.is_valid)
else:
return intervalMembership(None, self.is_valid)
elif isinstance(other, interval):
valid = fuzzy_and([self.is_valid, other.is_valid])
if self.end < other. start:
return intervalMembership(True, valid)
if self.start > other.end:
return intervalMembership(False, valid)
return intervalMembership(None, valid)
else:
return NotImplemented
def __gt__(self, other):
if isinstance(other, (int, float)):
if self.start > other:
return intervalMembership(True, self.is_valid)
elif self.end < other:
return intervalMembership(False, self.is_valid)
else:
return intervalMembership(None, self.is_valid)
elif isinstance(other, interval):
return other.__lt__(self)
else:
return NotImplemented
def __eq__(self, other):
if isinstance(other, (int, float)):
if self.start == other and self.end == other:
return intervalMembership(True, self.is_valid)
if other in self:
return intervalMembership(None, self.is_valid)
else:
return intervalMembership(False, self.is_valid)
if isinstance(other, interval):
valid = fuzzy_and([self.is_valid, other.is_valid])
if self.start == other.start and self.end == other.end:
return intervalMembership(True, valid)
elif self.__lt__(other)[0] is not None:
return intervalMembership(False, valid)
else:
return intervalMembership(None, valid)
else:
return NotImplemented
def __ne__(self, other):
if isinstance(other, (int, float)):
if self.start == other and self.end == other:
return intervalMembership(False, self.is_valid)
if other in self:
return intervalMembership(None, self.is_valid)
else:
return intervalMembership(True, self.is_valid)
if isinstance(other, interval):
valid = fuzzy_and([self.is_valid, other.is_valid])
if self.start == other.start and self.end == other.end:
return intervalMembership(False, valid)
if not self.__lt__(other)[0] is None:
return intervalMembership(True, valid)
return intervalMembership(None, valid)
else:
return NotImplemented
def __le__(self, other):
if isinstance(other, (int, float)):
if self.end <= other:
return intervalMembership(True, self.is_valid)
if self.start > other:
return intervalMembership(False, self.is_valid)
else:
return intervalMembership(None, self.is_valid)
if isinstance(other, interval):
valid = fuzzy_and([self.is_valid, other.is_valid])
if self.end <= other.start:
return intervalMembership(True, valid)
if self.start > other.end:
return intervalMembership(False, valid)
return intervalMembership(None, valid)
else:
return NotImplemented
def __ge__(self, other):
if isinstance(other, (int, float)):
if self.start >= other:
return intervalMembership(True, self.is_valid)
elif self.end < other:
return intervalMembership(False, self.is_valid)
else:
return intervalMembership(None, self.is_valid)
elif isinstance(other, interval):
return other.__le__(self)
def __add__(self, other):
if isinstance(other, (int, float)):
if self.is_valid:
return interval(self.start + other, self.end + other)
else:
start = self.start + other
end = self.end + other
return interval(start, end, is_valid=self.is_valid)
elif isinstance(other, interval):
start = self.start + other.start
end = self.end + other.end
valid = fuzzy_and([self.is_valid, other.is_valid])
return interval(start, end, is_valid=valid)
else:
return NotImplemented
__radd__ = __add__
def __sub__(self, other):
if isinstance(other, (int, float)):
start = self.start - other
end = self.end - other
return interval(start, end, is_valid=self.is_valid)
elif isinstance(other, interval):
start = self.start - other.end
end = self.end - other.start
valid = fuzzy_and([self.is_valid, other.is_valid])
return interval(start, end, is_valid=valid)
else:
return NotImplemented
def __rsub__(self, other):
if isinstance(other, (int, float)):
start = other - self.end
end = other - self.start
return interval(start, end, is_valid=self.is_valid)
elif isinstance(other, interval):
return other.__sub__(self)
else:
return NotImplemented
def __neg__(self):
if self.is_valid:
return interval(-self.end, -self.start)
else:
return interval(-self.end, -self.start, is_valid=self.is_valid)
def __mul__(self, other):
if isinstance(other, interval):
if self.is_valid is False or other.is_valid is False:
return interval(-float('inf'), float('inf'), is_valid=False)
elif self.is_valid is None or other.is_valid is None:
return interval(-float('inf'), float('inf'), is_valid=None)
else:
inters = []
inters.append(self.start * other.start)
inters.append(self.end * other.start)
inters.append(self.start * other.end)
inters.append(self.end * other.end)
start = min(inters)
end = max(inters)
return interval(start, end)
elif isinstance(other, (int, float)):
return interval(self.start*other, self.end*other, is_valid=self.is_valid)
else:
return NotImplemented
__rmul__ = __mul__
def __contains__(self, other):
if isinstance(other, (int, float)):
return self.start <= other and self.end >= other
else:
return self.start <= other.start and other.end <= self.end
def __rtruediv__(self, other):
if isinstance(other, (int, float)):
other = interval(other)
return other.__truediv__(self)
elif isinstance(other, interval):
return other.__truediv__(self)
else:
return NotImplemented
def __truediv__(self, other):
# Both None and False are handled
if not self.is_valid:
# Don't divide as the value is not valid
return interval(-float('inf'), float('inf'), is_valid=self.is_valid)
if isinstance(other, (int, float)):
if other == 0:
# Divide by zero encountered. valid nowhere
return interval(-float('inf'), float('inf'), is_valid=False)
else:
return interval(self.start / other, self.end / other)
elif isinstance(other, interval):
if other.is_valid is False or self.is_valid is False:
return interval(-float('inf'), float('inf'), is_valid=False)
elif other.is_valid is None or self.is_valid is None:
return interval(-float('inf'), float('inf'), is_valid=None)
else:
# denominator contains both signs, i.e. being divided by zero
# return the whole real line with is_valid = None
if 0 in other:
return interval(-float('inf'), float('inf'), is_valid=None)
# denominator negative
this = self
if other.end < 0:
this = -this
other = -other
# denominator positive
inters = []
inters.append(this.start / other.start)
inters.append(this.end / other.start)
inters.append(this.start / other.end)
inters.append(this.end / other.end)
start = max(inters)
end = min(inters)
return interval(start, end)
else:
return NotImplemented
def __pow__(self, other):
# Implements only power to an integer.
from .lib_interval import exp, log
if not self.is_valid:
return self
if isinstance(other, interval):
return exp(other * log(self))
elif isinstance(other, (float, int)):
if other < 0:
return 1 / self.__pow__(abs(other))
else:
if int_valued(other):
return _pow_int(self, other)
else:
return _pow_float(self, other)
else:
return NotImplemented
def __rpow__(self, other):
if isinstance(other, (float, int)):
if not self.is_valid:
#Don't do anything
return self
elif other < 0:
if self.width > 0:
return interval(-float('inf'), float('inf'), is_valid=False)
else:
power_rational = nsimplify(self.start)
num, denom = power_rational.as_numer_denom()
if denom % 2 == 0:
return interval(-float('inf'), float('inf'),
is_valid=False)
else:
start = -abs(other)**self.start
end = start
return interval(start, end)
else:
return interval(other**self.start, other**self.end)
elif isinstance(other, interval):
return other.__pow__(self)
else:
return NotImplemented
def __hash__(self):
return hash((self.is_valid, self.start, self.end))
def _pow_float(inter, power):
"""Evaluates an interval raised to a floating point."""
power_rational = nsimplify(power)
num, denom = power_rational.as_numer_denom()
if num % 2 == 0:
start = abs(inter.start)**power
end = abs(inter.end)**power
if start < 0:
ret = interval(0, max(start, end))
else:
ret = interval(start, end)
return ret
elif denom % 2 == 0:
if inter.end < 0:
return interval(-float('inf'), float('inf'), is_valid=False)
elif inter.start < 0:
return interval(0, inter.end**power, is_valid=None)
else:
return interval(inter.start**power, inter.end**power)
else:
if inter.start < 0:
start = -abs(inter.start)**power
else:
start = inter.start**power
if inter.end < 0:
end = -abs(inter.end)**power
else:
end = inter.end**power
return interval(start, end, is_valid=inter.is_valid)
def _pow_int(inter, power):
"""Evaluates an interval raised to an integer power"""
power = int(power)
if power & 1:
return interval(inter.start**power, inter.end**power)
else:
if inter.start < 0 and inter.end > 0:
start = 0
end = max(inter.start**power, inter.end**power)
return interval(start, end)
else:
return interval(inter.start**power, inter.end**power)

View File

@ -0,0 +1,78 @@
from sympy.core.logic import fuzzy_and, fuzzy_or, fuzzy_not, fuzzy_xor
class intervalMembership:
"""Represents a boolean expression returned by the comparison of
the interval object.
Parameters
==========
(a, b) : (bool, bool)
The first value determines the comparison as follows:
- True: If the comparison is True throughout the intervals.
- False: If the comparison is False throughout the intervals.
- None: If the comparison is True for some part of the intervals.
The second value is determined as follows:
- True: If both the intervals in comparison are valid.
- False: If at least one of the intervals is False, else
- None
"""
def __init__(self, a, b):
self._wrapped = (a, b)
def __getitem__(self, i):
try:
return self._wrapped[i]
except IndexError:
raise IndexError(
"{} must be a valid indexing for the 2-tuple."
.format(i))
def __len__(self):
return 2
def __iter__(self):
return iter(self._wrapped)
def __str__(self):
return "intervalMembership({}, {})".format(*self)
__repr__ = __str__
def __and__(self, other):
if not isinstance(other, intervalMembership):
raise ValueError(
"The comparison is not supported for {}.".format(other))
a1, b1 = self
a2, b2 = other
return intervalMembership(fuzzy_and([a1, a2]), fuzzy_and([b1, b2]))
def __or__(self, other):
if not isinstance(other, intervalMembership):
raise ValueError(
"The comparison is not supported for {}.".format(other))
a1, b1 = self
a2, b2 = other
return intervalMembership(fuzzy_or([a1, a2]), fuzzy_and([b1, b2]))
def __invert__(self):
a, b = self
return intervalMembership(fuzzy_not(a), b)
def __xor__(self, other):
if not isinstance(other, intervalMembership):
raise ValueError(
"The comparison is not supported for {}.".format(other))
a1, b1 = self
a2, b2 = other
return intervalMembership(fuzzy_xor([a1, a2]), fuzzy_and([b1, b2]))
def __eq__(self, other):
return self._wrapped == other
def __ne__(self, other):
return self._wrapped != other

View File

@ -0,0 +1,452 @@
""" The module contains implemented functions for interval arithmetic."""
from functools import reduce
from sympy.plotting.intervalmath import interval
from sympy.external import import_module
def Abs(x):
if isinstance(x, (int, float)):
return interval(abs(x))
elif isinstance(x, interval):
if x.start < 0 and x.end > 0:
return interval(0, max(abs(x.start), abs(x.end)), is_valid=x.is_valid)
else:
return interval(abs(x.start), abs(x.end))
else:
raise NotImplementedError
#Monotonic
def exp(x):
"""evaluates the exponential of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.exp(x), np.exp(x))
elif isinstance(x, interval):
return interval(np.exp(x.start), np.exp(x.end), is_valid=x.is_valid)
else:
raise NotImplementedError
#Monotonic
def log(x):
"""evaluates the natural logarithm of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
if x <= 0:
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(np.log(x))
elif isinstance(x, interval):
if not x.is_valid:
return interval(-np.inf, np.inf, is_valid=x.is_valid)
elif x.end <= 0:
return interval(-np.inf, np.inf, is_valid=False)
elif x.start <= 0:
return interval(-np.inf, np.inf, is_valid=None)
return interval(np.log(x.start), np.log(x.end))
else:
raise NotImplementedError
#Monotonic
def log10(x):
"""evaluates the logarithm to the base 10 of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
if x <= 0:
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(np.log10(x))
elif isinstance(x, interval):
if not x.is_valid:
return interval(-np.inf, np.inf, is_valid=x.is_valid)
elif x.end <= 0:
return interval(-np.inf, np.inf, is_valid=False)
elif x.start <= 0:
return interval(-np.inf, np.inf, is_valid=None)
return interval(np.log10(x.start), np.log10(x.end))
else:
raise NotImplementedError
#Monotonic
def atan(x):
"""evaluates the tan inverse of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.arctan(x))
elif isinstance(x, interval):
start = np.arctan(x.start)
end = np.arctan(x.end)
return interval(start, end, is_valid=x.is_valid)
else:
raise NotImplementedError
#periodic
def sin(x):
"""evaluates the sine of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.sin(x))
elif isinstance(x, interval):
if not x.is_valid:
return interval(-1, 1, is_valid=x.is_valid)
na, __ = divmod(x.start, np.pi / 2.0)
nb, __ = divmod(x.end, np.pi / 2.0)
start = min(np.sin(x.start), np.sin(x.end))
end = max(np.sin(x.start), np.sin(x.end))
if nb - na > 4:
return interval(-1, 1, is_valid=x.is_valid)
elif na == nb:
return interval(start, end, is_valid=x.is_valid)
else:
if (na - 1) // 4 != (nb - 1) // 4:
#sin has max
end = 1
if (na - 3) // 4 != (nb - 3) // 4:
#sin has min
start = -1
return interval(start, end)
else:
raise NotImplementedError
#periodic
def cos(x):
"""Evaluates the cos of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.sin(x))
elif isinstance(x, interval):
if not (np.isfinite(x.start) and np.isfinite(x.end)):
return interval(-1, 1, is_valid=x.is_valid)
na, __ = divmod(x.start, np.pi / 2.0)
nb, __ = divmod(x.end, np.pi / 2.0)
start = min(np.cos(x.start), np.cos(x.end))
end = max(np.cos(x.start), np.cos(x.end))
if nb - na > 4:
#differ more than 2*pi
return interval(-1, 1, is_valid=x.is_valid)
elif na == nb:
#in the same quadarant
return interval(start, end, is_valid=x.is_valid)
else:
if (na) // 4 != (nb) // 4:
#cos has max
end = 1
if (na - 2) // 4 != (nb - 2) // 4:
#cos has min
start = -1
return interval(start, end, is_valid=x.is_valid)
else:
raise NotImplementedError
def tan(x):
"""Evaluates the tan of an interval"""
return sin(x) / cos(x)
#Monotonic
def sqrt(x):
"""Evaluates the square root of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
if x > 0:
return interval(np.sqrt(x))
else:
return interval(-np.inf, np.inf, is_valid=False)
elif isinstance(x, interval):
#Outside the domain
if x.end < 0:
return interval(-np.inf, np.inf, is_valid=False)
#Partially outside the domain
elif x.start < 0:
return interval(-np.inf, np.inf, is_valid=None)
else:
return interval(np.sqrt(x.start), np.sqrt(x.end),
is_valid=x.is_valid)
else:
raise NotImplementedError
def imin(*args):
"""Evaluates the minimum of a list of intervals"""
np = import_module('numpy')
if not all(isinstance(arg, (int, float, interval)) for arg in args):
return NotImplementedError
else:
new_args = [a for a in args if isinstance(a, (int, float))
or a.is_valid]
if len(new_args) == 0:
if all(a.is_valid is False for a in args):
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(-np.inf, np.inf, is_valid=None)
start_array = [a if isinstance(a, (int, float)) else a.start
for a in new_args]
end_array = [a if isinstance(a, (int, float)) else a.end
for a in new_args]
return interval(min(start_array), min(end_array))
def imax(*args):
"""Evaluates the maximum of a list of intervals"""
np = import_module('numpy')
if not all(isinstance(arg, (int, float, interval)) for arg in args):
return NotImplementedError
else:
new_args = [a for a in args if isinstance(a, (int, float))
or a.is_valid]
if len(new_args) == 0:
if all(a.is_valid is False for a in args):
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(-np.inf, np.inf, is_valid=None)
start_array = [a if isinstance(a, (int, float)) else a.start
for a in new_args]
end_array = [a if isinstance(a, (int, float)) else a.end
for a in new_args]
return interval(max(start_array), max(end_array))
#Monotonic
def sinh(x):
"""Evaluates the hyperbolic sine of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.sinh(x), np.sinh(x))
elif isinstance(x, interval):
return interval(np.sinh(x.start), np.sinh(x.end), is_valid=x.is_valid)
else:
raise NotImplementedError
def cosh(x):
"""Evaluates the hyperbolic cos of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.cosh(x), np.cosh(x))
elif isinstance(x, interval):
#both signs
if x.start < 0 and x.end > 0:
end = max(np.cosh(x.start), np.cosh(x.end))
return interval(1, end, is_valid=x.is_valid)
else:
#Monotonic
start = np.cosh(x.start)
end = np.cosh(x.end)
return interval(start, end, is_valid=x.is_valid)
else:
raise NotImplementedError
#Monotonic
def tanh(x):
"""Evaluates the hyperbolic tan of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.tanh(x), np.tanh(x))
elif isinstance(x, interval):
return interval(np.tanh(x.start), np.tanh(x.end), is_valid=x.is_valid)
else:
raise NotImplementedError
def asin(x):
"""Evaluates the inverse sine of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
#Outside the domain
if abs(x) > 1:
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(np.arcsin(x), np.arcsin(x))
elif isinstance(x, interval):
#Outside the domain
if x.is_valid is False or x.start > 1 or x.end < -1:
return interval(-np.inf, np.inf, is_valid=False)
#Partially outside the domain
elif x.start < -1 or x.end > 1:
return interval(-np.inf, np.inf, is_valid=None)
else:
start = np.arcsin(x.start)
end = np.arcsin(x.end)
return interval(start, end, is_valid=x.is_valid)
def acos(x):
"""Evaluates the inverse cos of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
if abs(x) > 1:
#Outside the domain
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(np.arccos(x), np.arccos(x))
elif isinstance(x, interval):
#Outside the domain
if x.is_valid is False or x.start > 1 or x.end < -1:
return interval(-np.inf, np.inf, is_valid=False)
#Partially outside the domain
elif x.start < -1 or x.end > 1:
return interval(-np.inf, np.inf, is_valid=None)
else:
start = np.arccos(x.start)
end = np.arccos(x.end)
return interval(start, end, is_valid=x.is_valid)
def ceil(x):
"""Evaluates the ceiling of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.ceil(x))
elif isinstance(x, interval):
if x.is_valid is False:
return interval(-np.inf, np.inf, is_valid=False)
else:
start = np.ceil(x.start)
end = np.ceil(x.end)
#Continuous over the interval
if start == end:
return interval(start, end, is_valid=x.is_valid)
else:
#Not continuous over the interval
return interval(start, end, is_valid=None)
else:
return NotImplementedError
def floor(x):
"""Evaluates the floor of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.floor(x))
elif isinstance(x, interval):
if x.is_valid is False:
return interval(-np.inf, np.inf, is_valid=False)
else:
start = np.floor(x.start)
end = np.floor(x.end)
#continuous over the argument
if start == end:
return interval(start, end, is_valid=x.is_valid)
else:
#not continuous over the interval
return interval(start, end, is_valid=None)
else:
return NotImplementedError
def acosh(x):
"""Evaluates the inverse hyperbolic cosine of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
#Outside the domain
if x < 1:
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(np.arccosh(x))
elif isinstance(x, interval):
#Outside the domain
if x.end < 1:
return interval(-np.inf, np.inf, is_valid=False)
#Partly outside the domain
elif x.start < 1:
return interval(-np.inf, np.inf, is_valid=None)
else:
start = np.arccosh(x.start)
end = np.arccosh(x.end)
return interval(start, end, is_valid=x.is_valid)
else:
return NotImplementedError
#Monotonic
def asinh(x):
"""Evaluates the inverse hyperbolic sine of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
return interval(np.arcsinh(x))
elif isinstance(x, interval):
start = np.arcsinh(x.start)
end = np.arcsinh(x.end)
return interval(start, end, is_valid=x.is_valid)
else:
return NotImplementedError
def atanh(x):
"""Evaluates the inverse hyperbolic tangent of an interval"""
np = import_module('numpy')
if isinstance(x, (int, float)):
#Outside the domain
if abs(x) >= 1:
return interval(-np.inf, np.inf, is_valid=False)
else:
return interval(np.arctanh(x))
elif isinstance(x, interval):
#outside the domain
if x.is_valid is False or x.start >= 1 or x.end <= -1:
return interval(-np.inf, np.inf, is_valid=False)
#partly outside the domain
elif x.start <= -1 or x.end >= 1:
return interval(-np.inf, np.inf, is_valid=None)
else:
start = np.arctanh(x.start)
end = np.arctanh(x.end)
return interval(start, end, is_valid=x.is_valid)
else:
return NotImplementedError
#Three valued logic for interval plotting.
def And(*args):
"""Defines the three valued ``And`` behaviour for a 2-tuple of
three valued logic values"""
def reduce_and(cmp_intervala, cmp_intervalb):
if cmp_intervala[0] is False or cmp_intervalb[0] is False:
first = False
elif cmp_intervala[0] is None or cmp_intervalb[0] is None:
first = None
else:
first = True
if cmp_intervala[1] is False or cmp_intervalb[1] is False:
second = False
elif cmp_intervala[1] is None or cmp_intervalb[1] is None:
second = None
else:
second = True
return (first, second)
return reduce(reduce_and, args)
def Or(*args):
"""Defines the three valued ``Or`` behaviour for a 2-tuple of
three valued logic values"""
def reduce_or(cmp_intervala, cmp_intervalb):
if cmp_intervala[0] is True or cmp_intervalb[0] is True:
first = True
elif cmp_intervala[0] is None or cmp_intervalb[0] is None:
first = None
else:
first = False
if cmp_intervala[1] is True or cmp_intervalb[1] is True:
second = True
elif cmp_intervala[1] is None or cmp_intervalb[1] is None:
second = None
else:
second = False
return (first, second)
return reduce(reduce_or, args)

View File

@ -0,0 +1,415 @@
from sympy.external import import_module
from sympy.plotting.intervalmath import (
Abs, acos, acosh, And, asin, asinh, atan, atanh, ceil, cos, cosh,
exp, floor, imax, imin, interval, log, log10, Or, sin, sinh, sqrt,
tan, tanh,
)
np = import_module('numpy')
if not np:
disabled = True
#requires Numpy. Hence included in interval_functions
def test_interval_pow():
a = 2**interval(1, 2) == interval(2, 4)
assert a == (True, True)
a = interval(1, 2)**interval(1, 2) == interval(1, 4)
assert a == (True, True)
a = interval(-1, 1)**interval(0.5, 2)
assert a.is_valid is None
a = interval(-2, -1) ** interval(1, 2)
assert a.is_valid is False
a = interval(-2, -1) ** (1.0 / 2)
assert a.is_valid is False
a = interval(-1, 1)**(1.0 / 2)
assert a.is_valid is None
a = interval(-1, 1)**(1.0 / 3) == interval(-1, 1)
assert a == (True, True)
a = interval(-1, 1)**2 == interval(0, 1)
assert a == (True, True)
a = interval(-1, 1) ** (1.0 / 29) == interval(-1, 1)
assert a == (True, True)
a = -2**interval(1, 1) == interval(-2, -2)
assert a == (True, True)
a = interval(1, 2, is_valid=False)**2
assert a.is_valid is False
a = (-3)**interval(1, 2)
assert a.is_valid is False
a = (-4)**interval(0.5, 0.5)
assert a.is_valid is False
assert ((-3)**interval(1, 1) == interval(-3, -3)) == (True, True)
a = interval(8, 64)**(2.0 / 3)
assert abs(a.start - 4) < 1e-10 # eps
assert abs(a.end - 16) < 1e-10
a = interval(-8, 64)**(2.0 / 3)
assert abs(a.start - 4) < 1e-10 # eps
assert abs(a.end - 16) < 1e-10
def test_exp():
a = exp(interval(-np.inf, 0))
assert a.start == np.exp(-np.inf)
assert a.end == np.exp(0)
a = exp(interval(1, 2))
assert a.start == np.exp(1)
assert a.end == np.exp(2)
a = exp(1)
assert a.start == np.exp(1)
assert a.end == np.exp(1)
def test_log():
a = log(interval(1, 2))
assert a.start == 0
assert a.end == np.log(2)
a = log(interval(-1, 1))
assert a.is_valid is None
a = log(interval(-3, -1))
assert a.is_valid is False
a = log(-3)
assert a.is_valid is False
a = log(2)
assert a.start == np.log(2)
assert a.end == np.log(2)
def test_log10():
a = log10(interval(1, 2))
assert a.start == 0
assert a.end == np.log10(2)
a = log10(interval(-1, 1))
assert a.is_valid is None
a = log10(interval(-3, -1))
assert a.is_valid is False
a = log10(-3)
assert a.is_valid is False
a = log10(2)
assert a.start == np.log10(2)
assert a.end == np.log10(2)
def test_atan():
a = atan(interval(0, 1))
assert a.start == np.arctan(0)
assert a.end == np.arctan(1)
a = atan(1)
assert a.start == np.arctan(1)
assert a.end == np.arctan(1)
def test_sin():
a = sin(interval(0, np.pi / 4))
assert a.start == np.sin(0)
assert a.end == np.sin(np.pi / 4)
a = sin(interval(-np.pi / 4, np.pi / 4))
assert a.start == np.sin(-np.pi / 4)
assert a.end == np.sin(np.pi / 4)
a = sin(interval(np.pi / 4, 3 * np.pi / 4))
assert a.start == np.sin(np.pi / 4)
assert a.end == 1
a = sin(interval(7 * np.pi / 6, 7 * np.pi / 4))
assert a.start == -1
assert a.end == np.sin(7 * np.pi / 6)
a = sin(interval(0, 3 * np.pi))
assert a.start == -1
assert a.end == 1
a = sin(interval(np.pi / 3, 7 * np.pi / 4))
assert a.start == -1
assert a.end == 1
a = sin(np.pi / 4)
assert a.start == np.sin(np.pi / 4)
assert a.end == np.sin(np.pi / 4)
a = sin(interval(1, 2, is_valid=False))
assert a.is_valid is False
def test_cos():
a = cos(interval(0, np.pi / 4))
assert a.start == np.cos(np.pi / 4)
assert a.end == 1
a = cos(interval(-np.pi / 4, np.pi / 4))
assert a.start == np.cos(-np.pi / 4)
assert a.end == 1
a = cos(interval(np.pi / 4, 3 * np.pi / 4))
assert a.start == np.cos(3 * np.pi / 4)
assert a.end == np.cos(np.pi / 4)
a = cos(interval(3 * np.pi / 4, 5 * np.pi / 4))
assert a.start == -1
assert a.end == np.cos(3 * np.pi / 4)
a = cos(interval(0, 3 * np.pi))
assert a.start == -1
assert a.end == 1
a = cos(interval(- np.pi / 3, 5 * np.pi / 4))
assert a.start == -1
assert a.end == 1
a = cos(interval(1, 2, is_valid=False))
assert a.is_valid is False
def test_tan():
a = tan(interval(0, np.pi / 4))
assert a.start == 0
# must match lib_interval definition of tan:
assert a.end == np.sin(np.pi / 4)/np.cos(np.pi / 4)
a = tan(interval(np.pi / 4, 3 * np.pi / 4))
#discontinuity
assert a.is_valid is None
def test_sqrt():
a = sqrt(interval(1, 4))
assert a.start == 1
assert a.end == 2
a = sqrt(interval(0.01, 1))
assert a.start == np.sqrt(0.01)
assert a.end == 1
a = sqrt(interval(-1, 1))
assert a.is_valid is None
a = sqrt(interval(-3, -1))
assert a.is_valid is False
a = sqrt(4)
assert (a == interval(2, 2)) == (True, True)
a = sqrt(-3)
assert a.is_valid is False
def test_imin():
a = imin(interval(1, 3), interval(2, 5), interval(-1, 3))
assert a.start == -1
assert a.end == 3
a = imin(-2, interval(1, 4))
assert a.start == -2
assert a.end == -2
a = imin(5, interval(3, 4), interval(-2, 2, is_valid=False))
assert a.start == 3
assert a.end == 4
def test_imax():
a = imax(interval(-2, 2), interval(2, 7), interval(-3, 9))
assert a.start == 2
assert a.end == 9
a = imax(8, interval(1, 4))
assert a.start == 8
assert a.end == 8
a = imax(interval(1, 2), interval(3, 4), interval(-2, 2, is_valid=False))
assert a.start == 3
assert a.end == 4
def test_sinh():
a = sinh(interval(-1, 1))
assert a.start == np.sinh(-1)
assert a.end == np.sinh(1)
a = sinh(1)
assert a.start == np.sinh(1)
assert a.end == np.sinh(1)
def test_cosh():
a = cosh(interval(1, 2))
assert a.start == np.cosh(1)
assert a.end == np.cosh(2)
a = cosh(interval(-2, -1))
assert a.start == np.cosh(-1)
assert a.end == np.cosh(-2)
a = cosh(interval(-2, 1))
assert a.start == 1
assert a.end == np.cosh(-2)
a = cosh(1)
assert a.start == np.cosh(1)
assert a.end == np.cosh(1)
def test_tanh():
a = tanh(interval(-3, 3))
assert a.start == np.tanh(-3)
assert a.end == np.tanh(3)
a = tanh(3)
assert a.start == np.tanh(3)
assert a.end == np.tanh(3)
def test_asin():
a = asin(interval(-0.5, 0.5))
assert a.start == np.arcsin(-0.5)
assert a.end == np.arcsin(0.5)
a = asin(interval(-1.5, 1.5))
assert a.is_valid is None
a = asin(interval(-2, -1.5))
assert a.is_valid is False
a = asin(interval(0, 2))
assert a.is_valid is None
a = asin(interval(2, 5))
assert a.is_valid is False
a = asin(0.5)
assert a.start == np.arcsin(0.5)
assert a.end == np.arcsin(0.5)
a = asin(1.5)
assert a.is_valid is False
def test_acos():
a = acos(interval(-0.5, 0.5))
assert a.start == np.arccos(0.5)
assert a.end == np.arccos(-0.5)
a = acos(interval(-1.5, 1.5))
assert a.is_valid is None
a = acos(interval(-2, -1.5))
assert a.is_valid is False
a = acos(interval(0, 2))
assert a.is_valid is None
a = acos(interval(2, 5))
assert a.is_valid is False
a = acos(0.5)
assert a.start == np.arccos(0.5)
assert a.end == np.arccos(0.5)
a = acos(1.5)
assert a.is_valid is False
def test_ceil():
a = ceil(interval(0.2, 0.5))
assert a.start == 1
assert a.end == 1
a = ceil(interval(0.5, 1.5))
assert a.start == 1
assert a.end == 2
assert a.is_valid is None
a = ceil(interval(-5, 5))
assert a.is_valid is None
a = ceil(5.4)
assert a.start == 6
assert a.end == 6
def test_floor():
a = floor(interval(0.2, 0.5))
assert a.start == 0
assert a.end == 0
a = floor(interval(0.5, 1.5))
assert a.start == 0
assert a.end == 1
assert a.is_valid is None
a = floor(interval(-5, 5))
assert a.is_valid is None
a = floor(5.4)
assert a.start == 5
assert a.end == 5
def test_asinh():
a = asinh(interval(1, 2))
assert a.start == np.arcsinh(1)
assert a.end == np.arcsinh(2)
a = asinh(0.5)
assert a.start == np.arcsinh(0.5)
assert a.end == np.arcsinh(0.5)
def test_acosh():
a = acosh(interval(3, 5))
assert a.start == np.arccosh(3)
assert a.end == np.arccosh(5)
a = acosh(interval(0, 3))
assert a.is_valid is None
a = acosh(interval(-3, 0.5))
assert a.is_valid is False
a = acosh(0.5)
assert a.is_valid is False
a = acosh(2)
assert a.start == np.arccosh(2)
assert a.end == np.arccosh(2)
def test_atanh():
a = atanh(interval(-0.5, 0.5))
assert a.start == np.arctanh(-0.5)
assert a.end == np.arctanh(0.5)
a = atanh(interval(0, 3))
assert a.is_valid is None
a = atanh(interval(-3, -2))
assert a.is_valid is False
a = atanh(0.5)
assert a.start == np.arctanh(0.5)
assert a.end == np.arctanh(0.5)
a = atanh(1.5)
assert a.is_valid is False
def test_Abs():
assert (Abs(interval(-0.5, 0.5)) == interval(0, 0.5)) == (True, True)
assert (Abs(interval(-3, -2)) == interval(2, 3)) == (True, True)
assert (Abs(-3) == interval(3, 3)) == (True, True)
def test_And():
args = [(True, True), (True, False), (True, None)]
assert And(*args) == (True, False)
args = [(False, True), (None, None), (True, True)]
assert And(*args) == (False, None)
def test_Or():
args = [(True, True), (True, False), (False, None)]
assert Or(*args) == (True, True)
args = [(None, None), (False, None), (False, False)]
assert Or(*args) == (None, None)

View File

@ -0,0 +1,150 @@
from sympy.core.symbol import Symbol
from sympy.plotting.intervalmath import interval
from sympy.plotting.intervalmath.interval_membership import intervalMembership
from sympy.plotting.experimental_lambdify import experimental_lambdify
from sympy.testing.pytest import raises
def test_creation():
assert intervalMembership(True, True)
raises(TypeError, lambda: intervalMembership(True))
raises(TypeError, lambda: intervalMembership(True, True, True))
def test_getitem():
a = intervalMembership(True, False)
assert a[0] is True
assert a[1] is False
raises(IndexError, lambda: a[2])
def test_str():
a = intervalMembership(True, False)
assert str(a) == 'intervalMembership(True, False)'
assert repr(a) == 'intervalMembership(True, False)'
def test_equivalence():
a = intervalMembership(True, True)
b = intervalMembership(True, False)
assert (a == b) is False
assert (a != b) is True
a = intervalMembership(True, False)
b = intervalMembership(True, False)
assert (a == b) is True
assert (a != b) is False
def test_not():
x = Symbol('x')
r1 = x > -1
r2 = x <= -1
i = interval
f1 = experimental_lambdify((x,), r1)
f2 = experimental_lambdify((x,), r2)
tt = i(-0.1, 0.1, is_valid=True)
tn = i(-0.1, 0.1, is_valid=None)
tf = i(-0.1, 0.1, is_valid=False)
assert f1(tt) == ~f2(tt)
assert f1(tn) == ~f2(tn)
assert f1(tf) == ~f2(tf)
nt = i(0.9, 1.1, is_valid=True)
nn = i(0.9, 1.1, is_valid=None)
nf = i(0.9, 1.1, is_valid=False)
assert f1(nt) == ~f2(nt)
assert f1(nn) == ~f2(nn)
assert f1(nf) == ~f2(nf)
ft = i(1.9, 2.1, is_valid=True)
fn = i(1.9, 2.1, is_valid=None)
ff = i(1.9, 2.1, is_valid=False)
assert f1(ft) == ~f2(ft)
assert f1(fn) == ~f2(fn)
assert f1(ff) == ~f2(ff)
def test_boolean():
# There can be 9*9 test cases in full mapping of the cartesian product.
# But we only consider 3*3 cases for simplicity.
s = [
intervalMembership(False, False),
intervalMembership(None, None),
intervalMembership(True, True)
]
# Reduced tests for 'And'
a1 = [
intervalMembership(False, False),
intervalMembership(False, False),
intervalMembership(False, False),
intervalMembership(False, False),
intervalMembership(None, None),
intervalMembership(None, None),
intervalMembership(False, False),
intervalMembership(None, None),
intervalMembership(True, True)
]
a1_iter = iter(a1)
for i in range(len(s)):
for j in range(len(s)):
assert s[i] & s[j] == next(a1_iter)
# Reduced tests for 'Or'
a1 = [
intervalMembership(False, False),
intervalMembership(None, False),
intervalMembership(True, False),
intervalMembership(None, False),
intervalMembership(None, None),
intervalMembership(True, None),
intervalMembership(True, False),
intervalMembership(True, None),
intervalMembership(True, True)
]
a1_iter = iter(a1)
for i in range(len(s)):
for j in range(len(s)):
assert s[i] | s[j] == next(a1_iter)
# Reduced tests for 'Xor'
a1 = [
intervalMembership(False, False),
intervalMembership(None, False),
intervalMembership(True, False),
intervalMembership(None, False),
intervalMembership(None, None),
intervalMembership(None, None),
intervalMembership(True, False),
intervalMembership(None, None),
intervalMembership(False, True)
]
a1_iter = iter(a1)
for i in range(len(s)):
for j in range(len(s)):
assert s[i] ^ s[j] == next(a1_iter)
# Reduced tests for 'Not'
a1 = [
intervalMembership(True, False),
intervalMembership(None, None),
intervalMembership(False, True)
]
a1_iter = iter(a1)
for i in range(len(s)):
assert ~s[i] == next(a1_iter)
def test_boolean_errors():
a = intervalMembership(True, True)
raises(ValueError, lambda: a & 1)
raises(ValueError, lambda: a | 1)
raises(ValueError, lambda: a ^ 1)

View File

@ -0,0 +1,213 @@
from sympy.plotting.intervalmath import interval
from sympy.testing.pytest import raises
def test_interval():
assert (interval(1, 1) == interval(1, 1, is_valid=True)) == (True, True)
assert (interval(1, 1) == interval(1, 1, is_valid=False)) == (True, False)
assert (interval(1, 1) == interval(1, 1, is_valid=None)) == (True, None)
assert (interval(1, 1.5) == interval(1, 2)) == (None, True)
assert (interval(0, 1) == interval(2, 3)) == (False, True)
assert (interval(0, 1) == interval(1, 2)) == (None, True)
assert (interval(1, 2) != interval(1, 2)) == (False, True)
assert (interval(1, 3) != interval(2, 3)) == (None, True)
assert (interval(1, 3) != interval(-5, -3)) == (True, True)
assert (
interval(1, 3, is_valid=False) != interval(-5, -3)) == (True, False)
assert (interval(1, 3, is_valid=None) != interval(-5, 3)) == (None, None)
assert (interval(4, 4) != 4) == (False, True)
assert (interval(1, 1) == 1) == (True, True)
assert (interval(1, 3, is_valid=False) == interval(1, 3)) == (True, False)
assert (interval(1, 3, is_valid=None) == interval(1, 3)) == (True, None)
inter = interval(-5, 5)
assert (interval(inter) == interval(-5, 5)) == (True, True)
assert inter.width == 10
assert 0 in inter
assert -5 in inter
assert 5 in inter
assert interval(0, 3) in inter
assert interval(-6, 2) not in inter
assert -5.05 not in inter
assert 5.3 not in inter
interb = interval(-float('inf'), float('inf'))
assert 0 in inter
assert inter in interb
assert interval(0, float('inf')) in interb
assert interval(-float('inf'), 5) in interb
assert interval(-1e50, 1e50) in interb
assert (
-interval(-1, -2, is_valid=False) == interval(1, 2)) == (True, False)
raises(ValueError, lambda: interval(1, 2, 3))
def test_interval_add():
assert (interval(1, 2) + interval(2, 3) == interval(3, 5)) == (True, True)
assert (1 + interval(1, 2) == interval(2, 3)) == (True, True)
assert (interval(1, 2) + 1 == interval(2, 3)) == (True, True)
compare = (1 + interval(0, float('inf')) == interval(1, float('inf')))
assert compare == (True, True)
a = 1 + interval(2, 5, is_valid=False)
assert a.is_valid is False
a = 1 + interval(2, 5, is_valid=None)
assert a.is_valid is None
a = interval(2, 5, is_valid=False) + interval(3, 5, is_valid=None)
assert a.is_valid is False
a = interval(3, 5) + interval(-1, 1, is_valid=None)
assert a.is_valid is None
a = interval(2, 5, is_valid=False) + 1
assert a.is_valid is False
def test_interval_sub():
assert (interval(1, 2) - interval(1, 5) == interval(-4, 1)) == (True, True)
assert (interval(1, 2) - 1 == interval(0, 1)) == (True, True)
assert (1 - interval(1, 2) == interval(-1, 0)) == (True, True)
a = 1 - interval(1, 2, is_valid=False)
assert a.is_valid is False
a = interval(1, 4, is_valid=None) - 1
assert a.is_valid is None
a = interval(1, 3, is_valid=False) - interval(1, 3)
assert a.is_valid is False
a = interval(1, 3, is_valid=None) - interval(1, 3)
assert a.is_valid is None
def test_interval_inequality():
assert (interval(1, 2) < interval(3, 4)) == (True, True)
assert (interval(1, 2) < interval(2, 4)) == (None, True)
assert (interval(1, 2) < interval(-2, 0)) == (False, True)
assert (interval(1, 2) <= interval(2, 4)) == (True, True)
assert (interval(1, 2) <= interval(1.5, 6)) == (None, True)
assert (interval(2, 3) <= interval(1, 2)) == (None, True)
assert (interval(2, 3) <= interval(1, 1.5)) == (False, True)
assert (
interval(1, 2, is_valid=False) <= interval(-2, 0)) == (False, False)
assert (interval(1, 2, is_valid=None) <= interval(-2, 0)) == (False, None)
assert (interval(1, 2) <= 1.5) == (None, True)
assert (interval(1, 2) <= 3) == (True, True)
assert (interval(1, 2) <= 0) == (False, True)
assert (interval(5, 8) > interval(2, 3)) == (True, True)
assert (interval(2, 5) > interval(1, 3)) == (None, True)
assert (interval(2, 3) > interval(3.1, 5)) == (False, True)
assert (interval(-1, 1) == 0) == (None, True)
assert (interval(-1, 1) == 2) == (False, True)
assert (interval(-1, 1) != 0) == (None, True)
assert (interval(-1, 1) != 2) == (True, True)
assert (interval(3, 5) > 2) == (True, True)
assert (interval(3, 5) < 2) == (False, True)
assert (interval(1, 5) < 2) == (None, True)
assert (interval(1, 5) > 2) == (None, True)
assert (interval(0, 1) > 2) == (False, True)
assert (interval(1, 2) >= interval(0, 1)) == (True, True)
assert (interval(1, 2) >= interval(0, 1.5)) == (None, True)
assert (interval(1, 2) >= interval(3, 4)) == (False, True)
assert (interval(1, 2) >= 0) == (True, True)
assert (interval(1, 2) >= 1.2) == (None, True)
assert (interval(1, 2) >= 3) == (False, True)
assert (2 > interval(0, 1)) == (True, True)
a = interval(-1, 1, is_valid=False) < interval(2, 5, is_valid=None)
assert a == (True, False)
a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=False)
assert a == (True, False)
a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=None)
assert a == (True, None)
a = interval(-1, 1, is_valid=False) > interval(-5, -2, is_valid=None)
assert a == (True, False)
a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=False)
assert a == (True, False)
a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=None)
assert a == (True, None)
def test_interval_mul():
assert (
interval(1, 5) * interval(2, 10) == interval(2, 50)) == (True, True)
a = interval(-1, 1) * interval(2, 10) == interval(-10, 10)
assert a == (True, True)
a = interval(-1, 1) * interval(-5, 3) == interval(-5, 5)
assert a == (True, True)
assert (interval(1, 3) * 2 == interval(2, 6)) == (True, True)
assert (3 * interval(-1, 2) == interval(-3, 6)) == (True, True)
a = 3 * interval(1, 2, is_valid=False)
assert a.is_valid is False
a = 3 * interval(1, 2, is_valid=None)
assert a.is_valid is None
a = interval(1, 5, is_valid=False) * interval(1, 2, is_valid=None)
assert a.is_valid is False
def test_interval_div():
div = interval(1, 2, is_valid=False) / 3
assert div == interval(-float('inf'), float('inf'), is_valid=False)
div = interval(1, 2, is_valid=None) / 3
assert div == interval(-float('inf'), float('inf'), is_valid=None)
div = 3 / interval(1, 2, is_valid=None)
assert div == interval(-float('inf'), float('inf'), is_valid=None)
a = interval(1, 2) / 0
assert a.is_valid is False
a = interval(0.5, 1) / interval(-1, 0)
assert a.is_valid is None
a = interval(0, 1) / interval(0, 1)
assert a.is_valid is None
a = interval(-1, 1) / interval(-1, 1)
assert a.is_valid is None
a = interval(-1, 2) / interval(0.5, 1) == interval(-2.0, 4.0)
assert a == (True, True)
a = interval(0, 1) / interval(0.5, 1) == interval(0.0, 2.0)
assert a == (True, True)
a = interval(-1, 0) / interval(0.5, 1) == interval(-2.0, 0.0)
assert a == (True, True)
a = interval(-0.5, -0.25) / interval(0.5, 1) == interval(-1.0, -0.25)
assert a == (True, True)
a = interval(0.5, 1) / interval(0.5, 1) == interval(0.5, 2.0)
assert a == (True, True)
a = interval(0.5, 4) / interval(0.5, 1) == interval(0.5, 8.0)
assert a == (True, True)
a = interval(-1, -0.5) / interval(0.5, 1) == interval(-2.0, -0.5)
assert a == (True, True)
a = interval(-4, -0.5) / interval(0.5, 1) == interval(-8.0, -0.5)
assert a == (True, True)
a = interval(-1, 2) / interval(-2, -0.5) == interval(-4.0, 2.0)
assert a == (True, True)
a = interval(0, 1) / interval(-2, -0.5) == interval(-2.0, 0.0)
assert a == (True, True)
a = interval(-1, 0) / interval(-2, -0.5) == interval(0.0, 2.0)
assert a == (True, True)
a = interval(-0.5, -0.25) / interval(-2, -0.5) == interval(0.125, 1.0)
assert a == (True, True)
a = interval(0.5, 1) / interval(-2, -0.5) == interval(-2.0, -0.25)
assert a == (True, True)
a = interval(0.5, 4) / interval(-2, -0.5) == interval(-8.0, -0.25)
assert a == (True, True)
a = interval(-1, -0.5) / interval(-2, -0.5) == interval(0.25, 2.0)
assert a == (True, True)
a = interval(-4, -0.5) / interval(-2, -0.5) == interval(0.25, 8.0)
assert a == (True, True)
a = interval(-5, 5, is_valid=False) / 2
assert a.is_valid is False
def test_hashable():
'''
test that interval objects are hashable.
this is required in order to be able to put them into the cache, which
appears to be necessary for plotting in py3k. For details, see:
https://github.com/sympy/sympy/pull/2101
https://github.com/sympy/sympy/issues/6533
'''
hash(interval(1, 1))
hash(interval(1, 1, is_valid=True))
hash(interval(-4, -0.5))
hash(interval(-2, -0.5))
hash(interval(0.25, 8.0))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,233 @@
"""Implicit plotting module for SymPy.
Explanation
===========
The module implements a data series called ImplicitSeries which is used by
``Plot`` class to plot implicit plots for different backends. The module,
by default, implements plotting using interval arithmetic. It switches to a
fall back algorithm if the expression cannot be plotted using interval arithmetic.
It is also possible to specify to use the fall back algorithm for all plots.
Boolean combinations of expressions cannot be plotted by the fall back
algorithm.
See Also
========
sympy.plotting.plot
References
==========
.. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for
Mathematical Formulae with Two Free Variables.
.. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval
Arithmetic. Master's thesis. University of Toronto, 1996
"""
from sympy.core.containers import Tuple
from sympy.core.symbol import (Dummy, Symbol)
from sympy.polys.polyutils import _sort_gens
from sympy.plotting.series import ImplicitSeries, _set_discretization_points
from sympy.plotting.plot import plot_factory
from sympy.utilities.decorator import doctest_depends_on
from sympy.utilities.iterables import flatten
__doctest_requires__ = {'plot_implicit': ['matplotlib']}
@doctest_depends_on(modules=('matplotlib',))
def plot_implicit(expr, x_var=None, y_var=None, adaptive=True, depth=0,
n=300, line_color="blue", show=True, **kwargs):
"""A plot function to plot implicit equations / inequalities.
Arguments
=========
- expr : The equation / inequality that is to be plotted.
- x_var (optional) : symbol to plot on x-axis or tuple giving symbol
and range as ``(symbol, xmin, xmax)``
- y_var (optional) : symbol to plot on y-axis or tuple giving symbol
and range as ``(symbol, ymin, ymax)``
If neither ``x_var`` nor ``y_var`` are given then the free symbols in the
expression will be assigned in the order they are sorted.
The following keyword arguments can also be used:
- ``adaptive`` Boolean. The default value is set to True. It has to be
set to False if you want to use a mesh grid.
- ``depth`` integer. The depth of recursion for adaptive mesh grid.
Default value is 0. Takes value in the range (0, 4).
- ``n`` integer. The number of points if adaptive mesh grid is not
used. Default value is 300. This keyword argument replaces ``points``,
which should be considered deprecated.
- ``show`` Boolean. Default value is True. If set to False, the plot will
not be shown. See ``Plot`` for further information.
- ``title`` string. The title for the plot.
- ``xlabel`` string. The label for the x-axis
- ``ylabel`` string. The label for the y-axis
Aesthetics options:
- ``line_color``: float or string. Specifies the color for the plot.
See ``Plot`` to see how to set color for the plots.
Default value is "Blue"
plot_implicit, by default, uses interval arithmetic to plot functions. If
the expression cannot be plotted using interval arithmetic, it defaults to
a generating a contour using a mesh grid of fixed number of points. By
setting adaptive to False, you can force plot_implicit to use the mesh
grid. The mesh grid method can be effective when adaptive plotting using
interval arithmetic, fails to plot with small line width.
Examples
========
Plot expressions:
.. plot::
:context: reset
:format: doctest
:include-source: True
>>> from sympy import plot_implicit, symbols, Eq, And
>>> x, y = symbols('x y')
Without any ranges for the symbols in the expression:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p1 = plot_implicit(Eq(x**2 + y**2, 5))
With the range for the symbols:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p2 = plot_implicit(
... Eq(x**2 + y**2, 3), (x, -3, 3), (y, -3, 3))
With depth of recursion as argument:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p3 = plot_implicit(
... Eq(x**2 + y**2, 5), (x, -4, 4), (y, -4, 4), depth = 2)
Using mesh grid and not using adaptive meshing:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p4 = plot_implicit(
... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2),
... adaptive=False)
Using mesh grid without using adaptive meshing with number of points
specified:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p5 = plot_implicit(
... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2),
... adaptive=False, n=400)
Plotting regions:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p6 = plot_implicit(y > x**2)
Plotting Using boolean conjunctions:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p7 = plot_implicit(And(y > x, y > -x))
When plotting an expression with a single variable (y - 1, for example),
specify the x or the y variable explicitly:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> p8 = plot_implicit(y - 1, y_var=y)
>>> p9 = plot_implicit(x - 1, x_var=x)
"""
xyvar = [i for i in (x_var, y_var) if i is not None]
free_symbols = expr.free_symbols
range_symbols = Tuple(*flatten(xyvar)).free_symbols
undeclared = free_symbols - range_symbols
if len(free_symbols & range_symbols) > 2:
raise NotImplementedError("Implicit plotting is not implemented for "
"more than 2 variables")
#Create default ranges if the range is not provided.
default_range = Tuple(-5, 5)
def _range_tuple(s):
if isinstance(s, Symbol):
return Tuple(s) + default_range
if len(s) == 3:
return Tuple(*s)
raise ValueError('symbol or `(symbol, min, max)` expected but got %s' % s)
if len(xyvar) == 0:
xyvar = list(_sort_gens(free_symbols))
var_start_end_x = _range_tuple(xyvar[0])
x = var_start_end_x[0]
if len(xyvar) != 2:
if x in undeclared or not undeclared:
xyvar.append(Dummy('f(%s)' % x.name))
else:
xyvar.append(undeclared.pop())
var_start_end_y = _range_tuple(xyvar[1])
kwargs = _set_discretization_points(kwargs, ImplicitSeries)
series_argument = ImplicitSeries(
expr, var_start_end_x, var_start_end_y,
adaptive=adaptive, depth=depth,
n=n, line_color=line_color)
#set the x and y limits
kwargs['xlim'] = tuple(float(x) for x in var_start_end_x[1:])
kwargs['ylim'] = tuple(float(y) for y in var_start_end_y[1:])
# set the x and y labels
kwargs.setdefault('xlabel', var_start_end_x[0])
kwargs.setdefault('ylabel', var_start_end_y[0])
p = plot_factory(series_argument, **kwargs)
if show:
p.show()
return p

View File

@ -0,0 +1,188 @@
from sympy.external import import_module
import sympy.plotting.backends.base_backend as base_backend
# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py`
__doctest_requires__ = {
("PlotGrid",): ["matplotlib"],
}
class PlotGrid:
"""This class helps to plot subplots from already created SymPy plots
in a single figure.
Examples
========
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> from sympy import symbols
>>> from sympy.plotting import plot, plot3d, PlotGrid
>>> x, y = symbols('x, y')
>>> p1 = plot(x, x**2, x**3, (x, -5, 5))
>>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5)))
>>> p3 = plot(x**3, (x, -5, 5))
>>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5))
Plotting vertically in a single line:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> PlotGrid(2, 1, p1, p2)
PlotGrid object containing:
Plot[0]:Plot object containing:
[0]: cartesian line: x for x over (-5.0, 5.0)
[1]: cartesian line: x**2 for x over (-5.0, 5.0)
[2]: cartesian line: x**3 for x over (-5.0, 5.0)
Plot[1]:Plot object containing:
[0]: cartesian line: x**2 for x over (-6.0, 6.0)
[1]: cartesian line: x for x over (-5.0, 5.0)
Plotting horizontally in a single line:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> PlotGrid(1, 3, p2, p3, p4)
PlotGrid object containing:
Plot[0]:Plot object containing:
[0]: cartesian line: x**2 for x over (-6.0, 6.0)
[1]: cartesian line: x for x over (-5.0, 5.0)
Plot[1]:Plot object containing:
[0]: cartesian line: x**3 for x over (-5.0, 5.0)
Plot[2]:Plot object containing:
[0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
Plotting in a grid form:
.. plot::
:context: close-figs
:format: doctest
:include-source: True
>>> PlotGrid(2, 2, p1, p2, p3, p4)
PlotGrid object containing:
Plot[0]:Plot object containing:
[0]: cartesian line: x for x over (-5.0, 5.0)
[1]: cartesian line: x**2 for x over (-5.0, 5.0)
[2]: cartesian line: x**3 for x over (-5.0, 5.0)
Plot[1]:Plot object containing:
[0]: cartesian line: x**2 for x over (-6.0, 6.0)
[1]: cartesian line: x for x over (-5.0, 5.0)
Plot[2]:Plot object containing:
[0]: cartesian line: x**3 for x over (-5.0, 5.0)
Plot[3]:Plot object containing:
[0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
"""
def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs):
"""
Parameters
==========
nrows :
The number of rows that should be in the grid of the
required subplot.
ncolumns :
The number of columns that should be in the grid
of the required subplot.
nrows and ncolumns together define the required grid.
Arguments
=========
A list of predefined plot objects entered in a row-wise sequence
i.e. plot objects which are to be in the top row of the required
grid are written first, then the second row objects and so on
Keyword arguments
=================
show : Boolean
The default value is set to ``True``. Set show to ``False`` and
the function will not display the subplot. The returned instance
of the ``PlotGrid`` class can then be used to save or display the
plot by calling the ``save()`` and ``show()`` methods
respectively.
size : (float, float), optional
A tuple in the form (width, height) in inches to specify the size of
the overall figure. The default value is set to ``None``, meaning
the size will be set by the default backend.
"""
self.matplotlib = import_module('matplotlib',
import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
min_module_version='1.1.0', catch=(RuntimeError,))
self.nrows = nrows
self.ncolumns = ncolumns
self._series = []
self._fig = None
self.args = args
for arg in args:
self._series.append(arg._series)
self.size = size
if show and self.matplotlib:
self.show()
def _create_figure(self):
gs = self.matplotlib.gridspec.GridSpec(self.nrows, self.ncolumns)
mapping = {}
c = 0
for i in range(self.nrows):
for j in range(self.ncolumns):
if c < len(self.args):
mapping[gs[i, j]] = self.args[c]
c += 1
kw = {} if not self.size else {"figsize": self.size}
self._fig = self.matplotlib.pyplot.figure(**kw)
for spec, p in mapping.items():
kw = ({"projection": "3d"} if (len(p._series) > 0 and
p._series[0].is_3D) else {})
cur_ax = self._fig.add_subplot(spec, **kw)
p._plotgrid_fig = self._fig
p._plotgrid_ax = cur_ax
p.process_series()
@property
def fig(self):
if not self._fig:
self._create_figure()
return self._fig
@property
def _backend(self):
return self
def close(self):
self.matplotlib.pyplot.close(self.fig)
def show(self):
if base_backend._show:
self.fig.tight_layout()
self.matplotlib.pyplot.show()
else:
self.close()
def save(self, path):
self.fig.savefig(path)
def __str__(self):
plot_strs = [('Plot[%d]:' % i) + str(plot)
for i, plot in enumerate(self.args)]
return 'PlotGrid object containing:\n' + '\n'.join(plot_strs)

View File

@ -0,0 +1,138 @@
"""Plotting module that can plot 2D and 3D functions
"""
from sympy.utilities.decorator import doctest_depends_on
@doctest_depends_on(modules=('pyglet',))
def PygletPlot(*args, **kwargs):
"""
Plot Examples
=============
See examples/advanced/pyglet_plotting.py for many more examples.
>>> from sympy.plotting.pygletplot import PygletPlot as Plot
>>> from sympy.abc import x, y, z
>>> Plot(x*y**3-y*x**3)
[0]: -x**3*y + x*y**3, 'mode=cartesian'
>>> p = Plot()
>>> p[1] = x*y
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
>>> p = Plot()
>>> p[1] = x**2+y**2
>>> p[2] = -x**2-y**2
Variable Intervals
==================
The basic format is [var, min, max, steps], but the
syntax is flexible and arguments left out are taken
from the defaults for the current coordinate mode:
>>> Plot(x**2) # implies [x,-5,5,100]
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40]
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100]
[0]: x**2 - y**2, 'mode=cartesian'
>>> Plot(x**2, [x,-13,13,100])
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2, [-13,13]) # [x,-13,13,100]
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2, [x,-13,13]) # [x,-13,13,100]
[0]: x**2, 'mode=cartesian'
>>> Plot(1*x, [], [x], mode='cylindrical')
... # [unbound_theta,0,2*Pi,40], [x,-1,1,20]
[0]: x, 'mode=cartesian'
Coordinate Modes
================
Plot supports several curvilinear coordinate modes, and
they independent for each plotted function. You can specify
a coordinate mode explicitly with the 'mode' named argument,
but it can be automatically determined for Cartesian or
parametric plots, and therefore must only be specified for
polar, cylindrical, and spherical modes.
Specifically, Plot(function arguments) and Plot[n] =
(function arguments) will interpret your arguments as a
Cartesian plot if you provide one function and a parametric
plot if you provide two or three functions. Similarly, the
arguments will be interpreted as a curve if one variable is
used, and a surface if two are used.
Supported mode names by number of variables:
1: parametric, cartesian, polar
2: parametric, cartesian, cylindrical = polar, spherical
>>> Plot(1, mode='spherical')
Calculator-like Interface
=========================
>>> p = Plot(visible=False)
>>> f = x**2
>>> p[1] = f
>>> p[2] = f.diff(x)
>>> p[3] = f.diff(x).diff(x)
>>> p
[1]: x**2, 'mode=cartesian'
[2]: 2*x, 'mode=cartesian'
[3]: 2, 'mode=cartesian'
>>> p.show()
>>> p.clear()
>>> p
<blank plot>
>>> p[1] = x**2+y**2
>>> p[1].style = 'solid'
>>> p[2] = -x**2-y**2
>>> p[2].style = 'wireframe'
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
>>> p[1].style = 'both'
>>> p[2].style = 'both'
>>> p.close()
Plot Window Keyboard Controls
=============================
Screen Rotation:
X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2
Z axis Q,E, Numpad 7,9
Model Rotation:
Z axis Z,C, Numpad 1,3
Zoom: R,F, PgUp,PgDn, Numpad +,-
Reset Camera: X, Numpad 5
Camera Presets:
XY F1
XZ F2
YZ F3
Perspective F4
Sensitivity Modifier: SHIFT
Axes Toggle:
Visible F5
Colors F6
Close Window: ESCAPE
=============================
"""
from sympy.plotting.pygletplot.plot import PygletPlot
return PygletPlot(*args, **kwargs)

View File

@ -0,0 +1,336 @@
from sympy.core.basic import Basic
from sympy.core.symbol import (Symbol, symbols)
from sympy.utilities.lambdify import lambdify
from .util import interpolate, rinterpolate, create_bounds, update_bounds
from sympy.utilities.iterables import sift
class ColorGradient:
colors = [0.4, 0.4, 0.4], [0.9, 0.9, 0.9]
intervals = 0.0, 1.0
def __init__(self, *args):
if len(args) == 2:
self.colors = list(args)
self.intervals = [0.0, 1.0]
elif len(args) > 0:
if len(args) % 2 != 0:
raise ValueError("len(args) should be even")
self.colors = [args[i] for i in range(1, len(args), 2)]
self.intervals = [args[i] for i in range(0, len(args), 2)]
assert len(self.colors) == len(self.intervals)
def copy(self):
c = ColorGradient()
c.colors = [e[::] for e in self.colors]
c.intervals = self.intervals[::]
return c
def _find_interval(self, v):
m = len(self.intervals)
i = 0
while i < m - 1 and self.intervals[i] <= v:
i += 1
return i
def _interpolate_axis(self, axis, v):
i = self._find_interval(v)
v = rinterpolate(self.intervals[i - 1], self.intervals[i], v)
return interpolate(self.colors[i - 1][axis], self.colors[i][axis], v)
def __call__(self, r, g, b):
c = self._interpolate_axis
return c(0, r), c(1, g), c(2, b)
default_color_schemes = {} # defined at the bottom of this file
class ColorScheme:
def __init__(self, *args, **kwargs):
self.args = args
self.f, self.gradient = None, ColorGradient()
if len(args) == 1 and not isinstance(args[0], Basic) and callable(args[0]):
self.f = args[0]
elif len(args) == 1 and isinstance(args[0], str):
if args[0] in default_color_schemes:
cs = default_color_schemes[args[0]]
self.f, self.gradient = cs.f, cs.gradient.copy()
else:
self.f = lambdify('x,y,z,u,v', args[0])
else:
self.f, self.gradient = self._interpret_args(args)
self._test_color_function()
if not isinstance(self.gradient, ColorGradient):
raise ValueError("Color gradient not properly initialized. "
"(Not a ColorGradient instance.)")
def _interpret_args(self, args):
f, gradient = None, self.gradient
atoms, lists = self._sort_args(args)
s = self._pop_symbol_list(lists)
s = self._fill_in_vars(s)
# prepare the error message for lambdification failure
f_str = ', '.join(str(fa) for fa in atoms)
s_str = (str(sa) for sa in s)
s_str = ', '.join(sa for sa in s_str if sa.find('unbound') < 0)
f_error = ValueError("Could not interpret arguments "
"%s as functions of %s." % (f_str, s_str))
# try to lambdify args
if len(atoms) == 1:
fv = atoms[0]
try:
f = lambdify(s, [fv, fv, fv])
except TypeError:
raise f_error
elif len(atoms) == 3:
fr, fg, fb = atoms
try:
f = lambdify(s, [fr, fg, fb])
except TypeError:
raise f_error
else:
raise ValueError("A ColorScheme must provide 1 or 3 "
"functions in x, y, z, u, and/or v.")
# try to intrepret any given color information
if len(lists) == 0:
gargs = []
elif len(lists) == 1:
gargs = lists[0]
elif len(lists) == 2:
try:
(r1, g1, b1), (r2, g2, b2) = lists
except TypeError:
raise ValueError("If two color arguments are given, "
"they must be given in the format "
"(r1, g1, b1), (r2, g2, b2).")
gargs = lists
elif len(lists) == 3:
try:
(r1, r2), (g1, g2), (b1, b2) = lists
except Exception:
raise ValueError("If three color arguments are given, "
"they must be given in the format "
"(r1, r2), (g1, g2), (b1, b2). To create "
"a multi-step gradient, use the syntax "
"[0, colorStart, step1, color1, ..., 1, "
"colorEnd].")
gargs = [[r1, g1, b1], [r2, g2, b2]]
else:
raise ValueError("Don't know what to do with collection "
"arguments %s." % (', '.join(str(l) for l in lists)))
if gargs:
try:
gradient = ColorGradient(*gargs)
except Exception as ex:
raise ValueError(("Could not initialize a gradient "
"with arguments %s. Inner "
"exception: %s") % (gargs, str(ex)))
return f, gradient
def _pop_symbol_list(self, lists):
symbol_lists = []
for l in lists:
mark = True
for s in l:
if s is not None and not isinstance(s, Symbol):
mark = False
break
if mark:
lists.remove(l)
symbol_lists.append(l)
if len(symbol_lists) == 1:
return symbol_lists[0]
elif len(symbol_lists) == 0:
return []
else:
raise ValueError("Only one list of Symbols "
"can be given for a color scheme.")
def _fill_in_vars(self, args):
defaults = symbols('x,y,z,u,v')
v_error = ValueError("Could not find what to plot.")
if len(args) == 0:
return defaults
if not isinstance(args, (tuple, list)):
raise v_error
if len(args) == 0:
return defaults
for s in args:
if s is not None and not isinstance(s, Symbol):
raise v_error
# when vars are given explicitly, any vars
# not given are marked 'unbound' as to not
# be accidentally used in an expression
vars = [Symbol('unbound%i' % (i)) for i in range(1, 6)]
# interpret as t
if len(args) == 1:
vars[3] = args[0]
# interpret as u,v
elif len(args) == 2:
if args[0] is not None:
vars[3] = args[0]
if args[1] is not None:
vars[4] = args[1]
# interpret as x,y,z
elif len(args) >= 3:
# allow some of x,y,z to be
# left unbound if not given
if args[0] is not None:
vars[0] = args[0]
if args[1] is not None:
vars[1] = args[1]
if args[2] is not None:
vars[2] = args[2]
# interpret the rest as t
if len(args) >= 4:
vars[3] = args[3]
# ...or u,v
if len(args) >= 5:
vars[4] = args[4]
return vars
def _sort_args(self, args):
lists, atoms = sift(args,
lambda a: isinstance(a, (tuple, list)), binary=True)
return atoms, lists
def _test_color_function(self):
if not callable(self.f):
raise ValueError("Color function is not callable.")
try:
result = self.f(0, 0, 0, 0, 0)
if len(result) != 3:
raise ValueError("length should be equal to 3")
except TypeError:
raise ValueError("Color function needs to accept x,y,z,u,v, "
"as arguments even if it doesn't use all of them.")
except AssertionError:
raise ValueError("Color function needs to return 3-tuple r,g,b.")
except Exception:
pass # color function probably not valid at 0,0,0,0,0
def __call__(self, x, y, z, u, v):
try:
return self.f(x, y, z, u, v)
except Exception:
return None
def apply_to_curve(self, verts, u_set, set_len=None, inc_pos=None):
"""
Apply this color scheme to a
set of vertices over a single
independent variable u.
"""
bounds = create_bounds()
cverts = []
if callable(set_len):
set_len(len(u_set)*2)
# calculate f() = r,g,b for each vert
# and find the min and max for r,g,b
for _u in range(len(u_set)):
if verts[_u] is None:
cverts.append(None)
else:
x, y, z = verts[_u]
u, v = u_set[_u], None
c = self(x, y, z, u, v)
if c is not None:
c = list(c)
update_bounds(bounds, c)
cverts.append(c)
if callable(inc_pos):
inc_pos()
# scale and apply gradient
for _u in range(len(u_set)):
if cverts[_u] is not None:
for _c in range(3):
# scale from [f_min, f_max] to [0,1]
cverts[_u][_c] = rinterpolate(bounds[_c][0], bounds[_c][1],
cverts[_u][_c])
# apply gradient
cverts[_u] = self.gradient(*cverts[_u])
if callable(inc_pos):
inc_pos()
return cverts
def apply_to_surface(self, verts, u_set, v_set, set_len=None, inc_pos=None):
"""
Apply this color scheme to a
set of vertices over two
independent variables u and v.
"""
bounds = create_bounds()
cverts = []
if callable(set_len):
set_len(len(u_set)*len(v_set)*2)
# calculate f() = r,g,b for each vert
# and find the min and max for r,g,b
for _u in range(len(u_set)):
column = []
for _v in range(len(v_set)):
if verts[_u][_v] is None:
column.append(None)
else:
x, y, z = verts[_u][_v]
u, v = u_set[_u], v_set[_v]
c = self(x, y, z, u, v)
if c is not None:
c = list(c)
update_bounds(bounds, c)
column.append(c)
if callable(inc_pos):
inc_pos()
cverts.append(column)
# scale and apply gradient
for _u in range(len(u_set)):
for _v in range(len(v_set)):
if cverts[_u][_v] is not None:
# scale from [f_min, f_max] to [0,1]
for _c in range(3):
cverts[_u][_v][_c] = rinterpolate(bounds[_c][0],
bounds[_c][1], cverts[_u][_v][_c])
# apply gradient
cverts[_u][_v] = self.gradient(*cverts[_u][_v])
if callable(inc_pos):
inc_pos()
return cverts
def str_base(self):
return ", ".join(str(a) for a in self.args)
def __repr__(self):
return "%s" % (self.str_base())
x, y, z, t, u, v = symbols('x,y,z,t,u,v')
default_color_schemes['rainbow'] = ColorScheme(z, y, x)
default_color_schemes['zfade'] = ColorScheme(z, (0.4, 0.4, 0.97),
(0.97, 0.4, 0.4), (None, None, z))
default_color_schemes['zfade3'] = ColorScheme(z, (None, None, z),
[0.00, (0.2, 0.2, 1.0),
0.35, (0.2, 0.8, 0.4),
0.50, (0.3, 0.9, 0.3),
0.65, (0.4, 0.8, 0.2),
1.00, (1.0, 0.2, 0.2)])
default_color_schemes['zfade4'] = ColorScheme(z, (None, None, z),
[0.0, (0.3, 0.3, 1.0),
0.30, (0.3, 1.0, 0.3),
0.55, (0.95, 1.0, 0.2),
0.65, (1.0, 0.95, 0.2),
0.85, (1.0, 0.7, 0.2),
1.0, (1.0, 0.3, 0.2)])

View File

@ -0,0 +1,106 @@
from pyglet.window import Window
from pyglet.clock import Clock
from threading import Thread, Lock
gl_lock = Lock()
class ManagedWindow(Window):
"""
A pyglet window with an event loop which executes automatically
in a separate thread. Behavior is added by creating a subclass
which overrides setup, update, and/or draw.
"""
fps_limit = 30
default_win_args = {"width": 600,
"height": 500,
"vsync": False,
"resizable": True}
def __init__(self, **win_args):
"""
It is best not to override this function in the child
class, unless you need to take additional arguments.
Do any OpenGL initialization calls in setup().
"""
# check if this is run from the doctester
if win_args.get('runfromdoctester', False):
return
self.win_args = dict(self.default_win_args, **win_args)
self.Thread = Thread(target=self.__event_loop__)
self.Thread.start()
def __event_loop__(self, **win_args):
"""
The event loop thread function. Do not override or call
directly (it is called by __init__).
"""
gl_lock.acquire()
try:
try:
super().__init__(**self.win_args)
self.switch_to()
self.setup()
except Exception as e:
print("Window initialization failed: %s" % (str(e)))
self.has_exit = True
finally:
gl_lock.release()
clock = Clock()
clock.fps_limit = self.fps_limit
while not self.has_exit:
dt = clock.tick()
gl_lock.acquire()
try:
try:
self.switch_to()
self.dispatch_events()
self.clear()
self.update(dt)
self.draw()
self.flip()
except Exception as e:
print("Uncaught exception in event loop: %s" % str(e))
self.has_exit = True
finally:
gl_lock.release()
super().close()
def close(self):
"""
Closes the window.
"""
self.has_exit = True
def setup(self):
"""
Called once before the event loop begins.
Override this method in a child class. This
is the best place to put things like OpenGL
initialization calls.
"""
pass
def update(self, dt):
"""
Called before draw during each iteration of
the event loop. dt is the elapsed time in
seconds since the last update. OpenGL rendering
calls are best put in draw() rather than here.
"""
pass
def draw(self):
"""
Called after update during each iteration of
the event loop. Put OpenGL rendering calls
here.
"""
pass
if __name__ == '__main__':
ManagedWindow()

View File

@ -0,0 +1,464 @@
from threading import RLock
# it is sufficient to import "pyglet" here once
try:
import pyglet.gl as pgl
except ImportError:
raise ImportError("pyglet is required for plotting.\n "
"visit https://pyglet.org/")
from sympy.core.numbers import Integer
from sympy.external.gmpy import SYMPY_INTS
from sympy.geometry.entity import GeometryEntity
from sympy.plotting.pygletplot.plot_axes import PlotAxes
from sympy.plotting.pygletplot.plot_mode import PlotMode
from sympy.plotting.pygletplot.plot_object import PlotObject
from sympy.plotting.pygletplot.plot_window import PlotWindow
from sympy.plotting.pygletplot.util import parse_option_string
from sympy.utilities.decorator import doctest_depends_on
from sympy.utilities.iterables import is_sequence
from time import sleep
from os import getcwd, listdir
import ctypes
@doctest_depends_on(modules=('pyglet',))
class PygletPlot:
"""
Plot Examples
=============
See examples/advanced/pyglet_plotting.py for many more examples.
>>> from sympy.plotting.pygletplot import PygletPlot as Plot
>>> from sympy.abc import x, y, z
>>> Plot(x*y**3-y*x**3)
[0]: -x**3*y + x*y**3, 'mode=cartesian'
>>> p = Plot()
>>> p[1] = x*y
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
>>> p = Plot()
>>> p[1] = x**2+y**2
>>> p[2] = -x**2-y**2
Variable Intervals
==================
The basic format is [var, min, max, steps], but the
syntax is flexible and arguments left out are taken
from the defaults for the current coordinate mode:
>>> Plot(x**2) # implies [x,-5,5,100]
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40]
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100]
[0]: x**2 - y**2, 'mode=cartesian'
>>> Plot(x**2, [x,-13,13,100])
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2, [-13,13]) # [x,-13,13,100]
[0]: x**2, 'mode=cartesian'
>>> Plot(x**2, [x,-13,13]) # [x,-13,13,10]
[0]: x**2, 'mode=cartesian'
>>> Plot(1*x, [], [x], mode='cylindrical')
... # [unbound_theta,0,2*Pi,40], [x,-1,1,20]
[0]: x, 'mode=cartesian'
Coordinate Modes
================
Plot supports several curvilinear coordinate modes, and
they independent for each plotted function. You can specify
a coordinate mode explicitly with the 'mode' named argument,
but it can be automatically determined for Cartesian or
parametric plots, and therefore must only be specified for
polar, cylindrical, and spherical modes.
Specifically, Plot(function arguments) and Plot[n] =
(function arguments) will interpret your arguments as a
Cartesian plot if you provide one function and a parametric
plot if you provide two or three functions. Similarly, the
arguments will be interpreted as a curve if one variable is
used, and a surface if two are used.
Supported mode names by number of variables:
1: parametric, cartesian, polar
2: parametric, cartesian, cylindrical = polar, spherical
>>> Plot(1, mode='spherical')
Calculator-like Interface
=========================
>>> p = Plot(visible=False)
>>> f = x**2
>>> p[1] = f
>>> p[2] = f.diff(x)
>>> p[3] = f.diff(x).diff(x)
>>> p
[1]: x**2, 'mode=cartesian'
[2]: 2*x, 'mode=cartesian'
[3]: 2, 'mode=cartesian'
>>> p.show()
>>> p.clear()
>>> p
<blank plot>
>>> p[1] = x**2+y**2
>>> p[1].style = 'solid'
>>> p[2] = -x**2-y**2
>>> p[2].style = 'wireframe'
>>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4)
>>> p[1].style = 'both'
>>> p[2].style = 'both'
>>> p.close()
Plot Window Keyboard Controls
=============================
Screen Rotation:
X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2
Z axis Q,E, Numpad 7,9
Model Rotation:
Z axis Z,C, Numpad 1,3
Zoom: R,F, PgUp,PgDn, Numpad +,-
Reset Camera: X, Numpad 5
Camera Presets:
XY F1
XZ F2
YZ F3
Perspective F4
Sensitivity Modifier: SHIFT
Axes Toggle:
Visible F5
Colors F6
Close Window: ESCAPE
=============================
"""
@doctest_depends_on(modules=('pyglet',))
def __init__(self, *fargs, **win_args):
"""
Positional Arguments
====================
Any given positional arguments are used to
initialize a plot function at index 1. In
other words...
>>> from sympy.plotting.pygletplot import PygletPlot as Plot
>>> from sympy.abc import x
>>> p = Plot(x**2, visible=False)
...is equivalent to...
>>> p = Plot(visible=False)
>>> p[1] = x**2
Note that in earlier versions of the plotting
module, you were able to specify multiple
functions in the initializer. This functionality
has been dropped in favor of better automatic
plot plot_mode detection.
Named Arguments
===============
axes
An option string of the form
"key1=value1; key2 = value2" which
can use the following options:
style = ordinate
none OR frame OR box OR ordinate
stride = 0.25
val OR (val_x, val_y, val_z)
overlay = True (draw on top of plot)
True OR False
colored = False (False uses Black,
True uses colors
R,G,B = X,Y,Z)
True OR False
label_axes = False (display axis names
at endpoints)
True OR False
visible = True (show immediately
True OR False
The following named arguments are passed as
arguments to window initialization:
antialiasing = True
True OR False
ortho = False
True OR False
invert_mouse_zoom = False
True OR False
"""
# Register the plot modes
from . import plot_modes # noqa
self._win_args = win_args
self._window = None
self._render_lock = RLock()
self._functions = {}
self._pobjects = []
self._screenshot = ScreenShot(self)
axe_options = parse_option_string(win_args.pop('axes', ''))
self.axes = PlotAxes(**axe_options)
self._pobjects.append(self.axes)
self[0] = fargs
if win_args.get('visible', True):
self.show()
## Window Interfaces
def show(self):
"""
Creates and displays a plot window, or activates it
(gives it focus) if it has already been created.
"""
if self._window and not self._window.has_exit:
self._window.activate()
else:
self._win_args['visible'] = True
self.axes.reset_resources()
#if hasattr(self, '_doctest_depends_on'):
# self._win_args['runfromdoctester'] = True
self._window = PlotWindow(self, **self._win_args)
def close(self):
"""
Closes the plot window.
"""
if self._window:
self._window.close()
def saveimage(self, outfile=None, format='', size=(600, 500)):
"""
Saves a screen capture of the plot window to an
image file.
If outfile is given, it can either be a path
or a file object. Otherwise a png image will
be saved to the current working directory.
If the format is omitted, it is determined from
the filename extension.
"""
self._screenshot.save(outfile, format, size)
## Function List Interfaces
def clear(self):
"""
Clears the function list of this plot.
"""
self._render_lock.acquire()
self._functions = {}
self.adjust_all_bounds()
self._render_lock.release()
def __getitem__(self, i):
"""
Returns the function at position i in the
function list.
"""
return self._functions[i]
def __setitem__(self, i, args):
"""
Parses and adds a PlotMode to the function
list.
"""
if not (isinstance(i, (SYMPY_INTS, Integer)) and i >= 0):
raise ValueError("Function index must "
"be an integer >= 0.")
if isinstance(args, PlotObject):
f = args
else:
if (not is_sequence(args)) or isinstance(args, GeometryEntity):
args = [args]
if len(args) == 0:
return # no arguments given
kwargs = {"bounds_callback": self.adjust_all_bounds}
f = PlotMode(*args, **kwargs)
if f:
self._render_lock.acquire()
self._functions[i] = f
self._render_lock.release()
else:
raise ValueError("Failed to parse '%s'."
% ', '.join(str(a) for a in args))
def __delitem__(self, i):
"""
Removes the function in the function list at
position i.
"""
self._render_lock.acquire()
del self._functions[i]
self.adjust_all_bounds()
self._render_lock.release()
def firstavailableindex(self):
"""
Returns the first unused index in the function list.
"""
i = 0
self._render_lock.acquire()
while i in self._functions:
i += 1
self._render_lock.release()
return i
def append(self, *args):
"""
Parses and adds a PlotMode to the function
list at the first available index.
"""
self.__setitem__(self.firstavailableindex(), args)
def __len__(self):
"""
Returns the number of functions in the function list.
"""
return len(self._functions)
def __iter__(self):
"""
Allows iteration of the function list.
"""
return self._functions.itervalues()
def __repr__(self):
return str(self)
def __str__(self):
"""
Returns a string containing a new-line separated
list of the functions in the function list.
"""
s = ""
if len(self._functions) == 0:
s += "<blank plot>"
else:
self._render_lock.acquire()
s += "\n".join(["%s[%i]: %s" % ("", i, str(self._functions[i]))
for i in self._functions])
self._render_lock.release()
return s
def adjust_all_bounds(self):
self._render_lock.acquire()
self.axes.reset_bounding_box()
for f in self._functions:
self.axes.adjust_bounds(self._functions[f].bounds)
self._render_lock.release()
def wait_for_calculations(self):
sleep(0)
self._render_lock.acquire()
for f in self._functions:
a = self._functions[f]._get_calculating_verts
b = self._functions[f]._get_calculating_cverts
while a() or b():
sleep(0)
self._render_lock.release()
class ScreenShot:
def __init__(self, plot):
self._plot = plot
self.screenshot_requested = False
self.outfile = None
self.format = ''
self.invisibleMode = False
self.flag = 0
def __bool__(self):
return self.screenshot_requested
def _execute_saving(self):
if self.flag < 3:
self.flag += 1
return
size_x, size_y = self._plot._window.get_size()
size = size_x*size_y*4*ctypes.sizeof(ctypes.c_ubyte)
image = ctypes.create_string_buffer(size)
pgl.glReadPixels(0, 0, size_x, size_y, pgl.GL_RGBA, pgl.GL_UNSIGNED_BYTE, image)
from PIL import Image
im = Image.frombuffer('RGBA', (size_x, size_y),
image.raw, 'raw', 'RGBA', 0, 1)
im.transpose(Image.FLIP_TOP_BOTTOM).save(self.outfile, self.format)
self.flag = 0
self.screenshot_requested = False
if self.invisibleMode:
self._plot._window.close()
def save(self, outfile=None, format='', size=(600, 500)):
self.outfile = outfile
self.format = format
self.size = size
self.screenshot_requested = True
if not self._plot._window or self._plot._window.has_exit:
self._plot._win_args['visible'] = False
self._plot._win_args['width'] = size[0]
self._plot._win_args['height'] = size[1]
self._plot.axes.reset_resources()
self._plot._window = PlotWindow(self._plot, **self._plot._win_args)
self.invisibleMode = True
if self.outfile is None:
self.outfile = self._create_unique_path()
print(self.outfile)
def _create_unique_path(self):
cwd = getcwd()
l = listdir(cwd)
path = ''
i = 0
while True:
if not 'plot_%s.png' % i in l:
path = cwd + '/plot_%s.png' % i
break
i += 1
return path

View File

@ -0,0 +1,251 @@
import pyglet.gl as pgl
from pyglet import font
from sympy.core import S
from sympy.plotting.pygletplot.plot_object import PlotObject
from sympy.plotting.pygletplot.util import billboard_matrix, dot_product, \
get_direction_vectors, strided_range, vec_mag, vec_sub
from sympy.utilities.iterables import is_sequence
class PlotAxes(PlotObject):
def __init__(self, *args,
style='', none=None, frame=None, box=None, ordinate=None,
stride=0.25,
visible='', overlay='', colored='', label_axes='', label_ticks='',
tick_length=0.1,
font_face='Arial', font_size=28,
**kwargs):
# initialize style parameter
style = style.lower()
# allow alias kwargs to override style kwarg
if none is not None:
style = 'none'
if frame is not None:
style = 'frame'
if box is not None:
style = 'box'
if ordinate is not None:
style = 'ordinate'
if style in ['', 'ordinate']:
self._render_object = PlotAxesOrdinate(self)
elif style in ['frame', 'box']:
self._render_object = PlotAxesFrame(self)
elif style in ['none']:
self._render_object = None
else:
raise ValueError(("Unrecognized axes style %s.") % (style))
# initialize stride parameter
try:
stride = eval(stride)
except TypeError:
pass
if is_sequence(stride):
if len(stride) != 3:
raise ValueError("length should be equal to 3")
self._stride = stride
else:
self._stride = [stride, stride, stride]
self._tick_length = float(tick_length)
# setup bounding box and ticks
self._origin = [0, 0, 0]
self.reset_bounding_box()
def flexible_boolean(input, default):
if input in [True, False]:
return input
if input in ('f', 'F', 'false', 'False'):
return False
if input in ('t', 'T', 'true', 'True'):
return True
return default
# initialize remaining parameters
self.visible = flexible_boolean(kwargs, True)
self._overlay = flexible_boolean(overlay, True)
self._colored = flexible_boolean(colored, False)
self._label_axes = flexible_boolean(label_axes, False)
self._label_ticks = flexible_boolean(label_ticks, True)
# setup label font
self.font_face = font_face
self.font_size = font_size
# this is also used to reinit the
# font on window close/reopen
self.reset_resources()
def reset_resources(self):
self.label_font = None
def reset_bounding_box(self):
self._bounding_box = [[None, None], [None, None], [None, None]]
self._axis_ticks = [[], [], []]
def draw(self):
if self._render_object:
pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT | pgl.GL_DEPTH_BUFFER_BIT)
if self._overlay:
pgl.glDisable(pgl.GL_DEPTH_TEST)
self._render_object.draw()
pgl.glPopAttrib()
def adjust_bounds(self, child_bounds):
b = self._bounding_box
c = child_bounds
for i in range(3):
if abs(c[i][0]) is S.Infinity or abs(c[i][1]) is S.Infinity:
continue
b[i][0] = c[i][0] if b[i][0] is None else min([b[i][0], c[i][0]])
b[i][1] = c[i][1] if b[i][1] is None else max([b[i][1], c[i][1]])
self._bounding_box = b
self._recalculate_axis_ticks(i)
def _recalculate_axis_ticks(self, axis):
b = self._bounding_box
if b[axis][0] is None or b[axis][1] is None:
self._axis_ticks[axis] = []
else:
self._axis_ticks[axis] = strided_range(b[axis][0], b[axis][1],
self._stride[axis])
def toggle_visible(self):
self.visible = not self.visible
def toggle_colors(self):
self._colored = not self._colored
class PlotAxesBase(PlotObject):
def __init__(self, parent_axes):
self._p = parent_axes
def draw(self):
color = [([0.2, 0.1, 0.3], [0.2, 0.1, 0.3], [0.2, 0.1, 0.3]),
([0.9, 0.3, 0.5], [0.5, 1.0, 0.5], [0.3, 0.3, 0.9])][self._p._colored]
self.draw_background(color)
self.draw_axis(2, color[2])
self.draw_axis(1, color[1])
self.draw_axis(0, color[0])
def draw_background(self, color):
pass # optional
def draw_axis(self, axis, color):
raise NotImplementedError()
def draw_text(self, text, position, color, scale=1.0):
if len(color) == 3:
color = (color[0], color[1], color[2], 1.0)
if self._p.label_font is None:
self._p.label_font = font.load(self._p.font_face,
self._p.font_size,
bold=True, italic=False)
label = font.Text(self._p.label_font, text,
color=color,
valign=font.Text.BASELINE,
halign=font.Text.CENTER)
pgl.glPushMatrix()
pgl.glTranslatef(*position)
billboard_matrix()
scale_factor = 0.005 * scale
pgl.glScalef(scale_factor, scale_factor, scale_factor)
pgl.glColor4f(0, 0, 0, 0)
label.draw()
pgl.glPopMatrix()
def draw_line(self, v, color):
o = self._p._origin
pgl.glBegin(pgl.GL_LINES)
pgl.glColor3f(*color)
pgl.glVertex3f(v[0][0] + o[0], v[0][1] + o[1], v[0][2] + o[2])
pgl.glVertex3f(v[1][0] + o[0], v[1][1] + o[1], v[1][2] + o[2])
pgl.glEnd()
class PlotAxesOrdinate(PlotAxesBase):
def __init__(self, parent_axes):
super().__init__(parent_axes)
def draw_axis(self, axis, color):
ticks = self._p._axis_ticks[axis]
radius = self._p._tick_length / 2.0
if len(ticks) < 2:
return
# calculate the vector for this axis
axis_lines = [[0, 0, 0], [0, 0, 0]]
axis_lines[0][axis], axis_lines[1][axis] = ticks[0], ticks[-1]
axis_vector = vec_sub(axis_lines[1], axis_lines[0])
# calculate angle to the z direction vector
pos_z = get_direction_vectors()[2]
d = abs(dot_product(axis_vector, pos_z))
d = d / vec_mag(axis_vector)
# don't draw labels if we're looking down the axis
labels_visible = abs(d - 1.0) > 0.02
# draw the ticks and labels
for tick in ticks:
self.draw_tick_line(axis, color, radius, tick, labels_visible)
# draw the axis line and labels
self.draw_axis_line(axis, color, ticks[0], ticks[-1], labels_visible)
def draw_axis_line(self, axis, color, a_min, a_max, labels_visible):
axis_line = [[0, 0, 0], [0, 0, 0]]
axis_line[0][axis], axis_line[1][axis] = a_min, a_max
self.draw_line(axis_line, color)
if labels_visible:
self.draw_axis_line_labels(axis, color, axis_line)
def draw_axis_line_labels(self, axis, color, axis_line):
if not self._p._label_axes:
return
axis_labels = [axis_line[0][::], axis_line[1][::]]
axis_labels[0][axis] -= 0.3
axis_labels[1][axis] += 0.3
a_str = ['X', 'Y', 'Z'][axis]
self.draw_text("-" + a_str, axis_labels[0], color)
self.draw_text("+" + a_str, axis_labels[1], color)
def draw_tick_line(self, axis, color, radius, tick, labels_visible):
tick_axis = {0: 1, 1: 0, 2: 1}[axis]
tick_line = [[0, 0, 0], [0, 0, 0]]
tick_line[0][axis] = tick_line[1][axis] = tick
tick_line[0][tick_axis], tick_line[1][tick_axis] = -radius, radius
self.draw_line(tick_line, color)
if labels_visible:
self.draw_tick_line_label(axis, color, radius, tick)
def draw_tick_line_label(self, axis, color, radius, tick):
if not self._p._label_axes:
return
tick_label_vector = [0, 0, 0]
tick_label_vector[axis] = tick
tick_label_vector[{0: 1, 1: 0, 2: 1}[axis]] = [-1, 1, 1][
axis] * radius * 3.5
self.draw_text(str(tick), tick_label_vector, color, scale=0.5)
class PlotAxesFrame(PlotAxesBase):
def __init__(self, parent_axes):
super().__init__(parent_axes)
def draw_background(self, color):
pass
def draw_axis(self, axis, color):
raise NotImplementedError()

View File

@ -0,0 +1,124 @@
import pyglet.gl as pgl
from sympy.plotting.pygletplot.plot_rotation import get_spherical_rotatation
from sympy.plotting.pygletplot.util import get_model_matrix, model_to_screen, \
screen_to_model, vec_subs
class PlotCamera:
min_dist = 0.05
max_dist = 500.0
min_ortho_dist = 100.0
max_ortho_dist = 10000.0
_default_dist = 6.0
_default_ortho_dist = 600.0
rot_presets = {
'xy': (0, 0, 0),
'xz': (-90, 0, 0),
'yz': (0, 90, 0),
'perspective': (-45, 0, -45)
}
def __init__(self, window, ortho=False):
self.window = window
self.axes = self.window.plot.axes
self.ortho = ortho
self.reset()
def init_rot_matrix(self):
pgl.glPushMatrix()
pgl.glLoadIdentity()
self._rot = get_model_matrix()
pgl.glPopMatrix()
def set_rot_preset(self, preset_name):
self.init_rot_matrix()
if preset_name not in self.rot_presets:
raise ValueError(
"%s is not a valid rotation preset." % preset_name)
r = self.rot_presets[preset_name]
self.euler_rotate(r[0], 1, 0, 0)
self.euler_rotate(r[1], 0, 1, 0)
self.euler_rotate(r[2], 0, 0, 1)
def reset(self):
self._dist = 0.0
self._x, self._y = 0.0, 0.0
self._rot = None
if self.ortho:
self._dist = self._default_ortho_dist
else:
self._dist = self._default_dist
self.init_rot_matrix()
def mult_rot_matrix(self, rot):
pgl.glPushMatrix()
pgl.glLoadMatrixf(rot)
pgl.glMultMatrixf(self._rot)
self._rot = get_model_matrix()
pgl.glPopMatrix()
def setup_projection(self):
pgl.glMatrixMode(pgl.GL_PROJECTION)
pgl.glLoadIdentity()
if self.ortho:
# yep, this is pseudo ortho (don't tell anyone)
pgl.gluPerspective(
0.3, float(self.window.width)/float(self.window.height),
self.min_ortho_dist - 0.01, self.max_ortho_dist + 0.01)
else:
pgl.gluPerspective(
30.0, float(self.window.width)/float(self.window.height),
self.min_dist - 0.01, self.max_dist + 0.01)
pgl.glMatrixMode(pgl.GL_MODELVIEW)
def _get_scale(self):
return 1.0, 1.0, 1.0
def apply_transformation(self):
pgl.glLoadIdentity()
pgl.glTranslatef(self._x, self._y, -self._dist)
if self._rot is not None:
pgl.glMultMatrixf(self._rot)
pgl.glScalef(*self._get_scale())
def spherical_rotate(self, p1, p2, sensitivity=1.0):
mat = get_spherical_rotatation(p1, p2, self.window.width,
self.window.height, sensitivity)
if mat is not None:
self.mult_rot_matrix(mat)
def euler_rotate(self, angle, x, y, z):
pgl.glPushMatrix()
pgl.glLoadMatrixf(self._rot)
pgl.glRotatef(angle, x, y, z)
self._rot = get_model_matrix()
pgl.glPopMatrix()
def zoom_relative(self, clicks, sensitivity):
if self.ortho:
dist_d = clicks * sensitivity * 50.0
min_dist = self.min_ortho_dist
max_dist = self.max_ortho_dist
else:
dist_d = clicks * sensitivity
min_dist = self.min_dist
max_dist = self.max_dist
new_dist = (self._dist - dist_d)
if (clicks < 0 and new_dist < max_dist) or new_dist > min_dist:
self._dist = new_dist
def mouse_translate(self, x, y, dx, dy):
pgl.glPushMatrix()
pgl.glLoadIdentity()
pgl.glTranslatef(0, 0, -self._dist)
z = model_to_screen(0, 0, 0)[2]
d = vec_subs(screen_to_model(x, y, z), screen_to_model(x - dx, y - dy, z))
pgl.glPopMatrix()
self._x += d[0]
self._y += d[1]

View File

@ -0,0 +1,218 @@
from pyglet.window import key
from pyglet.window.mouse import LEFT, RIGHT, MIDDLE
from sympy.plotting.pygletplot.util import get_direction_vectors, get_basis_vectors
class PlotController:
normal_mouse_sensitivity = 4.0
modified_mouse_sensitivity = 1.0
normal_key_sensitivity = 160.0
modified_key_sensitivity = 40.0
keymap = {
key.LEFT: 'left',
key.A: 'left',
key.NUM_4: 'left',
key.RIGHT: 'right',
key.D: 'right',
key.NUM_6: 'right',
key.UP: 'up',
key.W: 'up',
key.NUM_8: 'up',
key.DOWN: 'down',
key.S: 'down',
key.NUM_2: 'down',
key.Z: 'rotate_z_neg',
key.NUM_1: 'rotate_z_neg',
key.C: 'rotate_z_pos',
key.NUM_3: 'rotate_z_pos',
key.Q: 'spin_left',
key.NUM_7: 'spin_left',
key.E: 'spin_right',
key.NUM_9: 'spin_right',
key.X: 'reset_camera',
key.NUM_5: 'reset_camera',
key.NUM_ADD: 'zoom_in',
key.PAGEUP: 'zoom_in',
key.R: 'zoom_in',
key.NUM_SUBTRACT: 'zoom_out',
key.PAGEDOWN: 'zoom_out',
key.F: 'zoom_out',
key.RSHIFT: 'modify_sensitivity',
key.LSHIFT: 'modify_sensitivity',
key.F1: 'rot_preset_xy',
key.F2: 'rot_preset_xz',
key.F3: 'rot_preset_yz',
key.F4: 'rot_preset_perspective',
key.F5: 'toggle_axes',
key.F6: 'toggle_axe_colors',
key.F8: 'save_image'
}
def __init__(self, window, *, invert_mouse_zoom=False, **kwargs):
self.invert_mouse_zoom = invert_mouse_zoom
self.window = window
self.camera = window.camera
self.action = {
# Rotation around the view Y (up) vector
'left': False,
'right': False,
# Rotation around the view X vector
'up': False,
'down': False,
# Rotation around the view Z vector
'spin_left': False,
'spin_right': False,
# Rotation around the model Z vector
'rotate_z_neg': False,
'rotate_z_pos': False,
# Reset to the default rotation
'reset_camera': False,
# Performs camera z-translation
'zoom_in': False,
'zoom_out': False,
# Use alternative sensitivity (speed)
'modify_sensitivity': False,
# Rotation presets
'rot_preset_xy': False,
'rot_preset_xz': False,
'rot_preset_yz': False,
'rot_preset_perspective': False,
# axes
'toggle_axes': False,
'toggle_axe_colors': False,
# screenshot
'save_image': False
}
def update(self, dt):
z = 0
if self.action['zoom_out']:
z -= 1
if self.action['zoom_in']:
z += 1
if z != 0:
self.camera.zoom_relative(z/10.0, self.get_key_sensitivity()/10.0)
dx, dy, dz = 0, 0, 0
if self.action['left']:
dx -= 1
if self.action['right']:
dx += 1
if self.action['up']:
dy -= 1
if self.action['down']:
dy += 1
if self.action['spin_left']:
dz += 1
if self.action['spin_right']:
dz -= 1
if not self.is_2D():
if dx != 0:
self.camera.euler_rotate(dx*dt*self.get_key_sensitivity(),
*(get_direction_vectors()[1]))
if dy != 0:
self.camera.euler_rotate(dy*dt*self.get_key_sensitivity(),
*(get_direction_vectors()[0]))
if dz != 0:
self.camera.euler_rotate(dz*dt*self.get_key_sensitivity(),
*(get_direction_vectors()[2]))
else:
self.camera.mouse_translate(0, 0, dx*dt*self.get_key_sensitivity(),
-dy*dt*self.get_key_sensitivity())
rz = 0
if self.action['rotate_z_neg'] and not self.is_2D():
rz -= 1
if self.action['rotate_z_pos'] and not self.is_2D():
rz += 1
if rz != 0:
self.camera.euler_rotate(rz*dt*self.get_key_sensitivity(),
*(get_basis_vectors()[2]))
if self.action['reset_camera']:
self.camera.reset()
if self.action['rot_preset_xy']:
self.camera.set_rot_preset('xy')
if self.action['rot_preset_xz']:
self.camera.set_rot_preset('xz')
if self.action['rot_preset_yz']:
self.camera.set_rot_preset('yz')
if self.action['rot_preset_perspective']:
self.camera.set_rot_preset('perspective')
if self.action['toggle_axes']:
self.action['toggle_axes'] = False
self.camera.axes.toggle_visible()
if self.action['toggle_axe_colors']:
self.action['toggle_axe_colors'] = False
self.camera.axes.toggle_colors()
if self.action['save_image']:
self.action['save_image'] = False
self.window.plot.saveimage()
return True
def get_mouse_sensitivity(self):
if self.action['modify_sensitivity']:
return self.modified_mouse_sensitivity
else:
return self.normal_mouse_sensitivity
def get_key_sensitivity(self):
if self.action['modify_sensitivity']:
return self.modified_key_sensitivity
else:
return self.normal_key_sensitivity
def on_key_press(self, symbol, modifiers):
if symbol in self.keymap:
self.action[self.keymap[symbol]] = True
def on_key_release(self, symbol, modifiers):
if symbol in self.keymap:
self.action[self.keymap[symbol]] = False
def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers):
if buttons & LEFT:
if self.is_2D():
self.camera.mouse_translate(x, y, dx, dy)
else:
self.camera.spherical_rotate((x - dx, y - dy), (x, y),
self.get_mouse_sensitivity())
if buttons & MIDDLE:
self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy,
self.get_mouse_sensitivity()/20.0)
if buttons & RIGHT:
self.camera.mouse_translate(x, y, dx, dy)
def on_mouse_scroll(self, x, y, dx, dy):
self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy,
self.get_mouse_sensitivity())
def is_2D(self):
functions = self.window.plot._functions
for i in functions:
if len(functions[i].i_vars) > 1 or len(functions[i].d_vars) > 2:
return False
return True

View File

@ -0,0 +1,82 @@
import pyglet.gl as pgl
from sympy.core import S
from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase
class PlotCurve(PlotModeBase):
style_override = 'wireframe'
def _on_calculate_verts(self):
self.t_interval = self.intervals[0]
self.t_set = list(self.t_interval.frange())
self.bounds = [[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0]]
evaluate = self._get_evaluator()
self._calculating_verts_pos = 0.0
self._calculating_verts_len = float(self.t_interval.v_len)
self.verts = []
b = self.bounds
for t in self.t_set:
try:
_e = evaluate(t) # calculate vertex
except (NameError, ZeroDivisionError):
_e = None
if _e is not None: # update bounding box
for axis in range(3):
b[axis][0] = min([b[axis][0], _e[axis]])
b[axis][1] = max([b[axis][1], _e[axis]])
self.verts.append(_e)
self._calculating_verts_pos += 1.0
for axis in range(3):
b[axis][2] = b[axis][1] - b[axis][0]
if b[axis][2] == 0.0:
b[axis][2] = 1.0
self.push_wireframe(self.draw_verts(False))
def _on_calculate_cverts(self):
if not self.verts or not self.color:
return
def set_work_len(n):
self._calculating_cverts_len = float(n)
def inc_work_pos():
self._calculating_cverts_pos += 1.0
set_work_len(1)
self._calculating_cverts_pos = 0
self.cverts = self.color.apply_to_curve(self.verts,
self.t_set,
set_len=set_work_len,
inc_pos=inc_work_pos)
self.push_wireframe(self.draw_verts(True))
def calculate_one_cvert(self, t):
vert = self.verts[t]
return self.color(vert[0], vert[1], vert[2],
self.t_set[t], None)
def draw_verts(self, use_cverts):
def f():
pgl.glBegin(pgl.GL_LINE_STRIP)
for t in range(len(self.t_set)):
p = self.verts[t]
if p is None:
pgl.glEnd()
pgl.glBegin(pgl.GL_LINE_STRIP)
continue
if use_cverts:
c = self.cverts[t]
if c is None:
c = (0, 0, 0)
pgl.glColor3f(*c)
else:
pgl.glColor3f(*self.default_wireframe_color)
pgl.glVertex3f(*p)
pgl.glEnd()
return f

View File

@ -0,0 +1,181 @@
from sympy.core.singleton import S
from sympy.core.symbol import Symbol
from sympy.core.sympify import sympify
from sympy.core.numbers import Integer
class PlotInterval:
"""
"""
_v, _v_min, _v_max, _v_steps = None, None, None, None
def require_all_args(f):
def check(self, *args, **kwargs):
for g in [self._v, self._v_min, self._v_max, self._v_steps]:
if g is None:
raise ValueError("PlotInterval is incomplete.")
return f(self, *args, **kwargs)
return check
def __init__(self, *args):
if len(args) == 1:
if isinstance(args[0], PlotInterval):
self.fill_from(args[0])
return
elif isinstance(args[0], str):
try:
args = eval(args[0])
except TypeError:
s_eval_error = "Could not interpret string %s."
raise ValueError(s_eval_error % (args[0]))
elif isinstance(args[0], (tuple, list)):
args = args[0]
else:
raise ValueError("Not an interval.")
if not isinstance(args, (tuple, list)) or len(args) > 4:
f_error = "PlotInterval must be a tuple or list of length 4 or less."
raise ValueError(f_error)
args = list(args)
if len(args) > 0 and (args[0] is None or isinstance(args[0], Symbol)):
self.v = args.pop(0)
if len(args) in [2, 3]:
self.v_min = args.pop(0)
self.v_max = args.pop(0)
if len(args) == 1:
self.v_steps = args.pop(0)
elif len(args) == 1:
self.v_steps = args.pop(0)
def get_v(self):
return self._v
def set_v(self, v):
if v is None:
self._v = None
return
if not isinstance(v, Symbol):
raise ValueError("v must be a SymPy Symbol.")
self._v = v
def get_v_min(self):
return self._v_min
def set_v_min(self, v_min):
if v_min is None:
self._v_min = None
return
try:
self._v_min = sympify(v_min)
float(self._v_min.evalf())
except TypeError:
raise ValueError("v_min could not be interpreted as a number.")
def get_v_max(self):
return self._v_max
def set_v_max(self, v_max):
if v_max is None:
self._v_max = None
return
try:
self._v_max = sympify(v_max)
float(self._v_max.evalf())
except TypeError:
raise ValueError("v_max could not be interpreted as a number.")
def get_v_steps(self):
return self._v_steps
def set_v_steps(self, v_steps):
if v_steps is None:
self._v_steps = None
return
if isinstance(v_steps, int):
v_steps = Integer(v_steps)
elif not isinstance(v_steps, Integer):
raise ValueError("v_steps must be an int or SymPy Integer.")
if v_steps <= S.Zero:
raise ValueError("v_steps must be positive.")
self._v_steps = v_steps
@require_all_args
def get_v_len(self):
return self.v_steps + 1
v = property(get_v, set_v)
v_min = property(get_v_min, set_v_min)
v_max = property(get_v_max, set_v_max)
v_steps = property(get_v_steps, set_v_steps)
v_len = property(get_v_len)
def fill_from(self, b):
if b.v is not None:
self.v = b.v
if b.v_min is not None:
self.v_min = b.v_min
if b.v_max is not None:
self.v_max = b.v_max
if b.v_steps is not None:
self.v_steps = b.v_steps
@staticmethod
def try_parse(*args):
"""
Returns a PlotInterval if args can be interpreted
as such, otherwise None.
"""
if len(args) == 1 and isinstance(args[0], PlotInterval):
return args[0]
try:
return PlotInterval(*args)
except ValueError:
return None
def _str_base(self):
return ",".join([str(self.v), str(self.v_min),
str(self.v_max), str(self.v_steps)])
def __repr__(self):
"""
A string representing the interval in class constructor form.
"""
return "PlotInterval(%s)" % (self._str_base())
def __str__(self):
"""
A string representing the interval in list form.
"""
return "[%s]" % (self._str_base())
@require_all_args
def assert_complete(self):
pass
@require_all_args
def vrange(self):
"""
Yields v_steps+1 SymPy numbers ranging from
v_min to v_max.
"""
d = (self.v_max - self.v_min) / self.v_steps
for i in range(self.v_steps + 1):
a = self.v_min + (d * Integer(i))
yield a
@require_all_args
def vrange2(self):
"""
Yields v_steps pairs of SymPy numbers ranging from
(v_min, v_min + step) to (v_max - step, v_max).
"""
d = (self.v_max - self.v_min) / self.v_steps
a = self.v_min + (d * S.Zero)
for i in range(self.v_steps):
b = self.v_min + (d * Integer(i + 1))
yield a, b
a = b
def frange(self):
for i in self.vrange():
yield float(i.evalf())

View File

@ -0,0 +1,400 @@
from .plot_interval import PlotInterval
from .plot_object import PlotObject
from .util import parse_option_string
from sympy.core.symbol import Symbol
from sympy.core.sympify import sympify
from sympy.geometry.entity import GeometryEntity
from sympy.utilities.iterables import is_sequence
class PlotMode(PlotObject):
"""
Grandparent class for plotting
modes. Serves as interface for
registration, lookup, and init
of modes.
To create a new plot mode,
inherit from PlotModeBase
or one of its children, such
as PlotSurface or PlotCurve.
"""
## Class-level attributes
## used to register and lookup
## plot modes. See PlotModeBase
## for descriptions and usage.
i_vars, d_vars = '', ''
intervals = []
aliases = []
is_default = False
## Draw is the only method here which
## is meant to be overridden in child
## classes, and PlotModeBase provides
## a base implementation.
def draw(self):
raise NotImplementedError()
## Everything else in this file has to
## do with registration and retrieval
## of plot modes. This is where I've
## hidden much of the ugliness of automatic
## plot mode divination...
## Plot mode registry data structures
_mode_alias_list = []
_mode_map = {
1: {1: {}, 2: {}},
2: {1: {}, 2: {}},
3: {1: {}, 2: {}},
} # [d][i][alias_str]: class
_mode_default_map = {
1: {},
2: {},
3: {},
} # [d][i]: class
_i_var_max, _d_var_max = 2, 3
def __new__(cls, *args, **kwargs):
"""
This is the function which interprets
arguments given to Plot.__init__ and
Plot.__setattr__. Returns an initialized
instance of the appropriate child class.
"""
newargs, newkwargs = PlotMode._extract_options(args, kwargs)
mode_arg = newkwargs.get('mode', '')
# Interpret the arguments
d_vars, intervals = PlotMode._interpret_args(newargs)
i_vars = PlotMode._find_i_vars(d_vars, intervals)
i, d = max([len(i_vars), len(intervals)]), len(d_vars)
# Find the appropriate mode
subcls = PlotMode._get_mode(mode_arg, i, d)
# Create the object
o = object.__new__(subcls)
# Do some setup for the mode instance
o.d_vars = d_vars
o._fill_i_vars(i_vars)
o._fill_intervals(intervals)
o.options = newkwargs
return o
@staticmethod
def _get_mode(mode_arg, i_var_count, d_var_count):
"""
Tries to return an appropriate mode class.
Intended to be called only by __new__.
mode_arg
Can be a string or a class. If it is a
PlotMode subclass, it is simply returned.
If it is a string, it can an alias for
a mode or an empty string. In the latter
case, we try to find a default mode for
the i_var_count and d_var_count.
i_var_count
The number of independent variables
needed to evaluate the d_vars.
d_var_count
The number of dependent variables;
usually the number of functions to
be evaluated in plotting.
For example, a Cartesian function y = f(x) has
one i_var (x) and one d_var (y). A parametric
form x,y,z = f(u,v), f(u,v), f(u,v) has two
two i_vars (u,v) and three d_vars (x,y,z).
"""
# if the mode_arg is simply a PlotMode class,
# check that the mode supports the numbers
# of independent and dependent vars, then
# return it
try:
m = None
if issubclass(mode_arg, PlotMode):
m = mode_arg
except TypeError:
pass
if m:
if not m._was_initialized:
raise ValueError(("To use unregistered plot mode %s "
"you must first call %s._init_mode().")
% (m.__name__, m.__name__))
if d_var_count != m.d_var_count:
raise ValueError(("%s can only plot functions "
"with %i dependent variables.")
% (m.__name__,
m.d_var_count))
if i_var_count > m.i_var_count:
raise ValueError(("%s cannot plot functions "
"with more than %i independent "
"variables.")
% (m.__name__,
m.i_var_count))
return m
# If it is a string, there are two possibilities.
if isinstance(mode_arg, str):
i, d = i_var_count, d_var_count
if i > PlotMode._i_var_max:
raise ValueError(var_count_error(True, True))
if d > PlotMode._d_var_max:
raise ValueError(var_count_error(False, True))
# If the string is '', try to find a suitable
# default mode
if not mode_arg:
return PlotMode._get_default_mode(i, d)
# Otherwise, interpret the string as a mode
# alias (e.g. 'cartesian', 'parametric', etc)
else:
return PlotMode._get_aliased_mode(mode_arg, i, d)
else:
raise ValueError("PlotMode argument must be "
"a class or a string")
@staticmethod
def _get_default_mode(i, d, i_vars=-1):
if i_vars == -1:
i_vars = i
try:
return PlotMode._mode_default_map[d][i]
except KeyError:
# Keep looking for modes in higher i var counts
# which support the given d var count until we
# reach the max i_var count.
if i < PlotMode._i_var_max:
return PlotMode._get_default_mode(i + 1, d, i_vars)
else:
raise ValueError(("Couldn't find a default mode "
"for %i independent and %i "
"dependent variables.") % (i_vars, d))
@staticmethod
def _get_aliased_mode(alias, i, d, i_vars=-1):
if i_vars == -1:
i_vars = i
if alias not in PlotMode._mode_alias_list:
raise ValueError(("Couldn't find a mode called"
" %s. Known modes: %s.")
% (alias, ", ".join(PlotMode._mode_alias_list)))
try:
return PlotMode._mode_map[d][i][alias]
except TypeError:
# Keep looking for modes in higher i var counts
# which support the given d var count and alias
# until we reach the max i_var count.
if i < PlotMode._i_var_max:
return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars)
else:
raise ValueError(("Couldn't find a %s mode "
"for %i independent and %i "
"dependent variables.")
% (alias, i_vars, d))
@classmethod
def _register(cls):
"""
Called once for each user-usable plot mode.
For Cartesian2D, it is invoked after the
class definition: Cartesian2D._register()
"""
name = cls.__name__
cls._init_mode()
try:
i, d = cls.i_var_count, cls.d_var_count
# Add the mode to _mode_map under all
# given aliases
for a in cls.aliases:
if a not in PlotMode._mode_alias_list:
# Also track valid aliases, so
# we can quickly know when given
# an invalid one in _get_mode.
PlotMode._mode_alias_list.append(a)
PlotMode._mode_map[d][i][a] = cls
if cls.is_default:
# If this mode was marked as the
# default for this d,i combination,
# also set that.
PlotMode._mode_default_map[d][i] = cls
except Exception as e:
raise RuntimeError(("Failed to register "
"plot mode %s. Reason: %s")
% (name, (str(e))))
@classmethod
def _init_mode(cls):
"""
Initializes the plot mode based on
the 'mode-specific parameters' above.
Only intended to be called by
PlotMode._register(). To use a mode without
registering it, you can directly call
ModeSubclass._init_mode().
"""
def symbols_list(symbol_str):
return [Symbol(s) for s in symbol_str]
# Convert the vars strs into
# lists of symbols.
cls.i_vars = symbols_list(cls.i_vars)
cls.d_vars = symbols_list(cls.d_vars)
# Var count is used often, calculate
# it once here
cls.i_var_count = len(cls.i_vars)
cls.d_var_count = len(cls.d_vars)
if cls.i_var_count > PlotMode._i_var_max:
raise ValueError(var_count_error(True, False))
if cls.d_var_count > PlotMode._d_var_max:
raise ValueError(var_count_error(False, False))
# Try to use first alias as primary_alias
if len(cls.aliases) > 0:
cls.primary_alias = cls.aliases[0]
else:
cls.primary_alias = cls.__name__
di = cls.intervals
if len(di) != cls.i_var_count:
raise ValueError("Plot mode must provide a "
"default interval for each i_var.")
for i in range(cls.i_var_count):
# default intervals must be given [min,max,steps]
# (no var, but they must be in the same order as i_vars)
if len(di[i]) != 3:
raise ValueError("length should be equal to 3")
# Initialize an incomplete interval,
# to later be filled with a var when
# the mode is instantiated.
di[i] = PlotInterval(None, *di[i])
# To prevent people from using modes
# without these required fields set up.
cls._was_initialized = True
_was_initialized = False
## Initializer Helper Methods
@staticmethod
def _find_i_vars(functions, intervals):
i_vars = []
# First, collect i_vars in the
# order they are given in any
# intervals.
for i in intervals:
if i.v is None:
continue
elif i.v in i_vars:
raise ValueError(("Multiple intervals given "
"for %s.") % (str(i.v)))
i_vars.append(i.v)
# Then, find any remaining
# i_vars in given functions
# (aka d_vars)
for f in functions:
for a in f.free_symbols:
if a not in i_vars:
i_vars.append(a)
return i_vars
def _fill_i_vars(self, i_vars):
# copy default i_vars
self.i_vars = [Symbol(str(i)) for i in self.i_vars]
# replace with given i_vars
for i in range(len(i_vars)):
self.i_vars[i] = i_vars[i]
def _fill_intervals(self, intervals):
# copy default intervals
self.intervals = [PlotInterval(i) for i in self.intervals]
# track i_vars used so far
v_used = []
# fill copy of default
# intervals with given info
for i in range(len(intervals)):
self.intervals[i].fill_from(intervals[i])
if self.intervals[i].v is not None:
v_used.append(self.intervals[i].v)
# Find any orphan intervals and
# assign them i_vars
for i in range(len(self.intervals)):
if self.intervals[i].v is None:
u = [v for v in self.i_vars if v not in v_used]
if len(u) == 0:
raise ValueError("length should not be equal to 0")
self.intervals[i].v = u[0]
v_used.append(u[0])
@staticmethod
def _interpret_args(args):
interval_wrong_order = "PlotInterval %s was given before any function(s)."
interpret_error = "Could not interpret %s as a function or interval."
functions, intervals = [], []
if isinstance(args[0], GeometryEntity):
for coords in list(args[0].arbitrary_point()):
functions.append(coords)
intervals.append(PlotInterval.try_parse(args[0].plot_interval()))
else:
for a in args:
i = PlotInterval.try_parse(a)
if i is not None:
if len(functions) == 0:
raise ValueError(interval_wrong_order % (str(i)))
else:
intervals.append(i)
else:
if is_sequence(a, include=str):
raise ValueError(interpret_error % (str(a)))
try:
f = sympify(a)
functions.append(f)
except TypeError:
raise ValueError(interpret_error % str(a))
return functions, intervals
@staticmethod
def _extract_options(args, kwargs):
newkwargs, newargs = {}, []
for a in args:
if isinstance(a, str):
newkwargs = dict(newkwargs, **parse_option_string(a))
else:
newargs.append(a)
newkwargs = dict(newkwargs, **kwargs)
return newargs, newkwargs
def var_count_error(is_independent, is_plotting):
"""
Used to format an error message which differs
slightly in 4 places.
"""
if is_plotting:
v = "Plotting"
else:
v = "Registering plot modes"
if is_independent:
n, s = PlotMode._i_var_max, "independent"
else:
n, s = PlotMode._d_var_max, "dependent"
return ("%s with more than %i %s variables "
"is not supported.") % (v, n, s)

View File

@ -0,0 +1,378 @@
import pyglet.gl as pgl
from sympy.core import S
from sympy.plotting.pygletplot.color_scheme import ColorScheme
from sympy.plotting.pygletplot.plot_mode import PlotMode
from sympy.utilities.iterables import is_sequence
from time import sleep
from threading import Thread, Event, RLock
import warnings
class PlotModeBase(PlotMode):
"""
Intended parent class for plotting
modes. Provides base functionality
in conjunction with its parent,
PlotMode.
"""
##
## Class-Level Attributes
##
"""
The following attributes are meant
to be set at the class level, and serve
as parameters to the plot mode registry
(in PlotMode). See plot_modes.py for
concrete examples.
"""
"""
i_vars
'x' for Cartesian2D
'xy' for Cartesian3D
etc.
d_vars
'y' for Cartesian2D
'r' for Polar
etc.
"""
i_vars, d_vars = '', ''
"""
intervals
Default intervals for each i_var, and in the
same order. Specified [min, max, steps].
No variable can be given (it is bound later).
"""
intervals = []
"""
aliases
A list of strings which can be used to
access this mode.
'cartesian' for Cartesian2D and Cartesian3D
'polar' for Polar
'cylindrical', 'polar' for Cylindrical
Note that _init_mode chooses the first alias
in the list as the mode's primary_alias, which
will be displayed to the end user in certain
contexts.
"""
aliases = []
"""
is_default
Whether to set this mode as the default
for arguments passed to PlotMode() containing
the same number of d_vars as this mode and
at most the same number of i_vars.
"""
is_default = False
"""
All of the above attributes are defined in PlotMode.
The following ones are specific to PlotModeBase.
"""
"""
A list of the render styles. Do not modify.
"""
styles = {'wireframe': 1, 'solid': 2, 'both': 3}
"""
style_override
Always use this style if not blank.
"""
style_override = ''
"""
default_wireframe_color
default_solid_color
Can be used when color is None or being calculated.
Used by PlotCurve and PlotSurface, but not anywhere
in PlotModeBase.
"""
default_wireframe_color = (0.85, 0.85, 0.85)
default_solid_color = (0.6, 0.6, 0.9)
default_rot_preset = 'xy'
##
## Instance-Level Attributes
##
## 'Abstract' member functions
def _get_evaluator(self):
if self.use_lambda_eval:
try:
e = self._get_lambda_evaluator()
return e
except Exception:
warnings.warn("\nWarning: creating lambda evaluator failed. "
"Falling back on SymPy subs evaluator.")
return self._get_sympy_evaluator()
def _get_sympy_evaluator(self):
raise NotImplementedError()
def _get_lambda_evaluator(self):
raise NotImplementedError()
def _on_calculate_verts(self):
raise NotImplementedError()
def _on_calculate_cverts(self):
raise NotImplementedError()
## Base member functions
def __init__(self, *args, bounds_callback=None, **kwargs):
self.verts = []
self.cverts = []
self.bounds = [[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0]]
self.cbounds = [[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0]]
self._draw_lock = RLock()
self._calculating_verts = Event()
self._calculating_cverts = Event()
self._calculating_verts_pos = 0.0
self._calculating_verts_len = 0.0
self._calculating_cverts_pos = 0.0
self._calculating_cverts_len = 0.0
self._max_render_stack_size = 3
self._draw_wireframe = [-1]
self._draw_solid = [-1]
self._style = None
self._color = None
self.predraw = []
self.postdraw = []
self.use_lambda_eval = self.options.pop('use_sympy_eval', None) is None
self.style = self.options.pop('style', '')
self.color = self.options.pop('color', 'rainbow')
self.bounds_callback = bounds_callback
self._on_calculate()
def synchronized(f):
def w(self, *args, **kwargs):
self._draw_lock.acquire()
try:
r = f(self, *args, **kwargs)
return r
finally:
self._draw_lock.release()
return w
@synchronized
def push_wireframe(self, function):
"""
Push a function which performs gl commands
used to build a display list. (The list is
built outside of the function)
"""
assert callable(function)
self._draw_wireframe.append(function)
if len(self._draw_wireframe) > self._max_render_stack_size:
del self._draw_wireframe[1] # leave marker element
@synchronized
def push_solid(self, function):
"""
Push a function which performs gl commands
used to build a display list. (The list is
built outside of the function)
"""
assert callable(function)
self._draw_solid.append(function)
if len(self._draw_solid) > self._max_render_stack_size:
del self._draw_solid[1] # leave marker element
def _create_display_list(self, function):
dl = pgl.glGenLists(1)
pgl.glNewList(dl, pgl.GL_COMPILE)
function()
pgl.glEndList()
return dl
def _render_stack_top(self, render_stack):
top = render_stack[-1]
if top == -1:
return -1 # nothing to display
elif callable(top):
dl = self._create_display_list(top)
render_stack[-1] = (dl, top)
return dl # display newly added list
elif len(top) == 2:
if pgl.GL_TRUE == pgl.glIsList(top[0]):
return top[0] # display stored list
dl = self._create_display_list(top[1])
render_stack[-1] = (dl, top[1])
return dl # display regenerated list
def _draw_solid_display_list(self, dl):
pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT)
pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_FILL)
pgl.glCallList(dl)
pgl.glPopAttrib()
def _draw_wireframe_display_list(self, dl):
pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT)
pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_LINE)
pgl.glEnable(pgl.GL_POLYGON_OFFSET_LINE)
pgl.glPolygonOffset(-0.005, -50.0)
pgl.glCallList(dl)
pgl.glPopAttrib()
@synchronized
def draw(self):
for f in self.predraw:
if callable(f):
f()
if self.style_override:
style = self.styles[self.style_override]
else:
style = self.styles[self._style]
# Draw solid component if style includes solid
if style & 2:
dl = self._render_stack_top(self._draw_solid)
if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl):
self._draw_solid_display_list(dl)
# Draw wireframe component if style includes wireframe
if style & 1:
dl = self._render_stack_top(self._draw_wireframe)
if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl):
self._draw_wireframe_display_list(dl)
for f in self.postdraw:
if callable(f):
f()
def _on_change_color(self, color):
Thread(target=self._calculate_cverts).start()
def _on_calculate(self):
Thread(target=self._calculate_all).start()
def _calculate_all(self):
self._calculate_verts()
self._calculate_cverts()
def _calculate_verts(self):
if self._calculating_verts.is_set():
return
self._calculating_verts.set()
try:
self._on_calculate_verts()
finally:
self._calculating_verts.clear()
if callable(self.bounds_callback):
self.bounds_callback()
def _calculate_cverts(self):
if self._calculating_verts.is_set():
return
while self._calculating_cverts.is_set():
sleep(0) # wait for previous calculation
self._calculating_cverts.set()
try:
self._on_calculate_cverts()
finally:
self._calculating_cverts.clear()
def _get_calculating_verts(self):
return self._calculating_verts.is_set()
def _get_calculating_verts_pos(self):
return self._calculating_verts_pos
def _get_calculating_verts_len(self):
return self._calculating_verts_len
def _get_calculating_cverts(self):
return self._calculating_cverts.is_set()
def _get_calculating_cverts_pos(self):
return self._calculating_cverts_pos
def _get_calculating_cverts_len(self):
return self._calculating_cverts_len
## Property handlers
def _get_style(self):
return self._style
@synchronized
def _set_style(self, v):
if v is None:
return
if v == '':
step_max = 0
for i in self.intervals:
if i.v_steps is None:
continue
step_max = max([step_max, int(i.v_steps)])
v = ['both', 'solid'][step_max > 40]
if v not in self.styles:
raise ValueError("v should be there in self.styles")
if v == self._style:
return
self._style = v
def _get_color(self):
return self._color
@synchronized
def _set_color(self, v):
try:
if v is not None:
if is_sequence(v):
v = ColorScheme(*v)
else:
v = ColorScheme(v)
if repr(v) == repr(self._color):
return
self._on_change_color(v)
self._color = v
except Exception as e:
raise RuntimeError("Color change failed. "
"Reason: %s" % (str(e)))
style = property(_get_style, _set_style)
color = property(_get_color, _set_color)
calculating_verts = property(_get_calculating_verts)
calculating_verts_pos = property(_get_calculating_verts_pos)
calculating_verts_len = property(_get_calculating_verts_len)
calculating_cverts = property(_get_calculating_cverts)
calculating_cverts_pos = property(_get_calculating_cverts_pos)
calculating_cverts_len = property(_get_calculating_cverts_len)
## String representations
def __str__(self):
f = ", ".join(str(d) for d in self.d_vars)
o = "'mode=%s'" % (self.primary_alias)
return ", ".join([f, o])
def __repr__(self):
f = ", ".join(str(d) for d in self.d_vars)
i = ", ".join(str(i) for i in self.intervals)
d = [('mode', self.primary_alias),
('color', str(self.color)),
('style', str(self.style))]
o = "'%s'" % ("; ".join("%s=%s" % (k, v)
for k, v in d if v != 'None'))
return ", ".join([f, i, o])

View File

@ -0,0 +1,209 @@
from sympy.utilities.lambdify import lambdify
from sympy.core.numbers import pi
from sympy.functions import sin, cos
from sympy.plotting.pygletplot.plot_curve import PlotCurve
from sympy.plotting.pygletplot.plot_surface import PlotSurface
from math import sin as p_sin
from math import cos as p_cos
def float_vec3(f):
def inner(*args):
v = f(*args)
return float(v[0]), float(v[1]), float(v[2])
return inner
class Cartesian2D(PlotCurve):
i_vars, d_vars = 'x', 'y'
intervals = [[-5, 5, 100]]
aliases = ['cartesian']
is_default = True
def _get_sympy_evaluator(self):
fy = self.d_vars[0]
x = self.t_interval.v
@float_vec3
def e(_x):
return (_x, fy.subs(x, _x), 0.0)
return e
def _get_lambda_evaluator(self):
fy = self.d_vars[0]
x = self.t_interval.v
return lambdify([x], [x, fy, 0.0])
class Cartesian3D(PlotSurface):
i_vars, d_vars = 'xy', 'z'
intervals = [[-1, 1, 40], [-1, 1, 40]]
aliases = ['cartesian', 'monge']
is_default = True
def _get_sympy_evaluator(self):
fz = self.d_vars[0]
x = self.u_interval.v
y = self.v_interval.v
@float_vec3
def e(_x, _y):
return (_x, _y, fz.subs(x, _x).subs(y, _y))
return e
def _get_lambda_evaluator(self):
fz = self.d_vars[0]
x = self.u_interval.v
y = self.v_interval.v
return lambdify([x, y], [x, y, fz])
class ParametricCurve2D(PlotCurve):
i_vars, d_vars = 't', 'xy'
intervals = [[0, 2*pi, 100]]
aliases = ['parametric']
is_default = True
def _get_sympy_evaluator(self):
fx, fy = self.d_vars
t = self.t_interval.v
@float_vec3
def e(_t):
return (fx.subs(t, _t), fy.subs(t, _t), 0.0)
return e
def _get_lambda_evaluator(self):
fx, fy = self.d_vars
t = self.t_interval.v
return lambdify([t], [fx, fy, 0.0])
class ParametricCurve3D(PlotCurve):
i_vars, d_vars = 't', 'xyz'
intervals = [[0, 2*pi, 100]]
aliases = ['parametric']
is_default = True
def _get_sympy_evaluator(self):
fx, fy, fz = self.d_vars
t = self.t_interval.v
@float_vec3
def e(_t):
return (fx.subs(t, _t), fy.subs(t, _t), fz.subs(t, _t))
return e
def _get_lambda_evaluator(self):
fx, fy, fz = self.d_vars
t = self.t_interval.v
return lambdify([t], [fx, fy, fz])
class ParametricSurface(PlotSurface):
i_vars, d_vars = 'uv', 'xyz'
intervals = [[-1, 1, 40], [-1, 1, 40]]
aliases = ['parametric']
is_default = True
def _get_sympy_evaluator(self):
fx, fy, fz = self.d_vars
u = self.u_interval.v
v = self.v_interval.v
@float_vec3
def e(_u, _v):
return (fx.subs(u, _u).subs(v, _v),
fy.subs(u, _u).subs(v, _v),
fz.subs(u, _u).subs(v, _v))
return e
def _get_lambda_evaluator(self):
fx, fy, fz = self.d_vars
u = self.u_interval.v
v = self.v_interval.v
return lambdify([u, v], [fx, fy, fz])
class Polar(PlotCurve):
i_vars, d_vars = 't', 'r'
intervals = [[0, 2*pi, 100]]
aliases = ['polar']
is_default = False
def _get_sympy_evaluator(self):
fr = self.d_vars[0]
t = self.t_interval.v
def e(_t):
_r = float(fr.subs(t, _t))
return (_r*p_cos(_t), _r*p_sin(_t), 0.0)
return e
def _get_lambda_evaluator(self):
fr = self.d_vars[0]
t = self.t_interval.v
fx, fy = fr*cos(t), fr*sin(t)
return lambdify([t], [fx, fy, 0.0])
class Cylindrical(PlotSurface):
i_vars, d_vars = 'th', 'r'
intervals = [[0, 2*pi, 40], [-1, 1, 20]]
aliases = ['cylindrical', 'polar']
is_default = False
def _get_sympy_evaluator(self):
fr = self.d_vars[0]
t = self.u_interval.v
h = self.v_interval.v
def e(_t, _h):
_r = float(fr.subs(t, _t).subs(h, _h))
return (_r*p_cos(_t), _r*p_sin(_t), _h)
return e
def _get_lambda_evaluator(self):
fr = self.d_vars[0]
t = self.u_interval.v
h = self.v_interval.v
fx, fy = fr*cos(t), fr*sin(t)
return lambdify([t, h], [fx, fy, h])
class Spherical(PlotSurface):
i_vars, d_vars = 'tp', 'r'
intervals = [[0, 2*pi, 40], [0, pi, 20]]
aliases = ['spherical']
is_default = False
def _get_sympy_evaluator(self):
fr = self.d_vars[0]
t = self.u_interval.v
p = self.v_interval.v
def e(_t, _p):
_r = float(fr.subs(t, _t).subs(p, _p))
return (_r*p_cos(_t)*p_sin(_p),
_r*p_sin(_t)*p_sin(_p),
_r*p_cos(_p))
return e
def _get_lambda_evaluator(self):
fr = self.d_vars[0]
t = self.u_interval.v
p = self.v_interval.v
fx = fr * cos(t) * sin(p)
fy = fr * sin(t) * sin(p)
fz = fr * cos(p)
return lambdify([t, p], [fx, fy, fz])
Cartesian2D._register()
Cartesian3D._register()
ParametricCurve2D._register()
ParametricCurve3D._register()
ParametricSurface._register()
Polar._register()
Cylindrical._register()
Spherical._register()

View File

@ -0,0 +1,17 @@
class PlotObject:
"""
Base class for objects which can be displayed in
a Plot.
"""
visible = True
def _draw(self):
if self.visible:
self.draw()
def draw(self):
"""
OpenGL rendering code for the plot object.
Override in base class.
"""
pass

View File

@ -0,0 +1,68 @@
try:
from ctypes import c_float
except ImportError:
pass
import pyglet.gl as pgl
from math import sqrt as _sqrt, acos as _acos
def cross(a, b):
return (a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0])
def dot(a, b):
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
def mag(a):
return _sqrt(a[0]**2 + a[1]**2 + a[2]**2)
def norm(a):
m = mag(a)
return (a[0] / m, a[1] / m, a[2] / m)
def get_sphere_mapping(x, y, width, height):
x = min([max([x, 0]), width])
y = min([max([y, 0]), height])
sr = _sqrt((width/2)**2 + (height/2)**2)
sx = ((x - width / 2) / sr)
sy = ((y - height / 2) / sr)
sz = 1.0 - sx**2 - sy**2
if sz > 0.0:
sz = _sqrt(sz)
return (sx, sy, sz)
else:
sz = 0
return norm((sx, sy, sz))
rad2deg = 180.0 / 3.141592
def get_spherical_rotatation(p1, p2, width, height, theta_multiplier):
v1 = get_sphere_mapping(p1[0], p1[1], width, height)
v2 = get_sphere_mapping(p2[0], p2[1], width, height)
d = min(max([dot(v1, v2), -1]), 1)
if abs(d - 1.0) < 0.000001:
return None
raxis = norm( cross(v1, v2) )
rtheta = theta_multiplier * rad2deg * _acos(d)
pgl.glPushMatrix()
pgl.glLoadIdentity()
pgl.glRotatef(rtheta, *raxis)
mat = (c_float*16)()
pgl.glGetFloatv(pgl.GL_MODELVIEW_MATRIX, mat)
pgl.glPopMatrix()
return mat

View File

@ -0,0 +1,102 @@
import pyglet.gl as pgl
from sympy.core import S
from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase
class PlotSurface(PlotModeBase):
default_rot_preset = 'perspective'
def _on_calculate_verts(self):
self.u_interval = self.intervals[0]
self.u_set = list(self.u_interval.frange())
self.v_interval = self.intervals[1]
self.v_set = list(self.v_interval.frange())
self.bounds = [[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0]]
evaluate = self._get_evaluator()
self._calculating_verts_pos = 0.0
self._calculating_verts_len = float(
self.u_interval.v_len*self.v_interval.v_len)
verts = []
b = self.bounds
for u in self.u_set:
column = []
for v in self.v_set:
try:
_e = evaluate(u, v) # calculate vertex
except ZeroDivisionError:
_e = None
if _e is not None: # update bounding box
for axis in range(3):
b[axis][0] = min([b[axis][0], _e[axis]])
b[axis][1] = max([b[axis][1], _e[axis]])
column.append(_e)
self._calculating_verts_pos += 1.0
verts.append(column)
for axis in range(3):
b[axis][2] = b[axis][1] - b[axis][0]
if b[axis][2] == 0.0:
b[axis][2] = 1.0
self.verts = verts
self.push_wireframe(self.draw_verts(False, False))
self.push_solid(self.draw_verts(False, True))
def _on_calculate_cverts(self):
if not self.verts or not self.color:
return
def set_work_len(n):
self._calculating_cverts_len = float(n)
def inc_work_pos():
self._calculating_cverts_pos += 1.0
set_work_len(1)
self._calculating_cverts_pos = 0
self.cverts = self.color.apply_to_surface(self.verts,
self.u_set,
self.v_set,
set_len=set_work_len,
inc_pos=inc_work_pos)
self.push_solid(self.draw_verts(True, True))
def calculate_one_cvert(self, u, v):
vert = self.verts[u][v]
return self.color(vert[0], vert[1], vert[2],
self.u_set[u], self.v_set[v])
def draw_verts(self, use_cverts, use_solid_color):
def f():
for u in range(1, len(self.u_set)):
pgl.glBegin(pgl.GL_QUAD_STRIP)
for v in range(len(self.v_set)):
pa = self.verts[u - 1][v]
pb = self.verts[u][v]
if pa is None or pb is None:
pgl.glEnd()
pgl.glBegin(pgl.GL_QUAD_STRIP)
continue
if use_cverts:
ca = self.cverts[u - 1][v]
cb = self.cverts[u][v]
if ca is None:
ca = (0, 0, 0)
if cb is None:
cb = (0, 0, 0)
else:
if use_solid_color:
ca = cb = self.default_solid_color
else:
ca = cb = self.default_wireframe_color
pgl.glColor3f(*ca)
pgl.glVertex3f(*pa)
pgl.glColor3f(*cb)
pgl.glVertex3f(*pb)
pgl.glEnd()
return f

View File

@ -0,0 +1,144 @@
from time import perf_counter
import pyglet.gl as pgl
from sympy.plotting.pygletplot.managed_window import ManagedWindow
from sympy.plotting.pygletplot.plot_camera import PlotCamera
from sympy.plotting.pygletplot.plot_controller import PlotController
class PlotWindow(ManagedWindow):
def __init__(self, plot, antialiasing=True, ortho=False,
invert_mouse_zoom=False, linewidth=1.5, caption="SymPy Plot",
**kwargs):
"""
Named Arguments
===============
antialiasing = True
True OR False
ortho = False
True OR False
invert_mouse_zoom = False
True OR False
"""
self.plot = plot
self.camera = None
self._calculating = False
self.antialiasing = antialiasing
self.ortho = ortho
self.invert_mouse_zoom = invert_mouse_zoom
self.linewidth = linewidth
self.title = caption
self.last_caption_update = 0
self.caption_update_interval = 0.2
self.drawing_first_object = True
super().__init__(**kwargs)
def setup(self):
self.camera = PlotCamera(self, ortho=self.ortho)
self.controller = PlotController(self,
invert_mouse_zoom=self.invert_mouse_zoom)
self.push_handlers(self.controller)
pgl.glClearColor(1.0, 1.0, 1.0, 0.0)
pgl.glClearDepth(1.0)
pgl.glDepthFunc(pgl.GL_LESS)
pgl.glEnable(pgl.GL_DEPTH_TEST)
pgl.glEnable(pgl.GL_LINE_SMOOTH)
pgl.glShadeModel(pgl.GL_SMOOTH)
pgl.glLineWidth(self.linewidth)
pgl.glEnable(pgl.GL_BLEND)
pgl.glBlendFunc(pgl.GL_SRC_ALPHA, pgl.GL_ONE_MINUS_SRC_ALPHA)
if self.antialiasing:
pgl.glHint(pgl.GL_LINE_SMOOTH_HINT, pgl.GL_NICEST)
pgl.glHint(pgl.GL_POLYGON_SMOOTH_HINT, pgl.GL_NICEST)
self.camera.setup_projection()
def on_resize(self, w, h):
super().on_resize(w, h)
if self.camera is not None:
self.camera.setup_projection()
def update(self, dt):
self.controller.update(dt)
def draw(self):
self.plot._render_lock.acquire()
self.camera.apply_transformation()
calc_verts_pos, calc_verts_len = 0, 0
calc_cverts_pos, calc_cverts_len = 0, 0
should_update_caption = (perf_counter() - self.last_caption_update >
self.caption_update_interval)
if len(self.plot._functions.values()) == 0:
self.drawing_first_object = True
iterfunctions = iter(self.plot._functions.values())
for r in iterfunctions:
if self.drawing_first_object:
self.camera.set_rot_preset(r.default_rot_preset)
self.drawing_first_object = False
pgl.glPushMatrix()
r._draw()
pgl.glPopMatrix()
# might as well do this while we are
# iterating and have the lock rather
# than locking and iterating twice
# per frame:
if should_update_caption:
try:
if r.calculating_verts:
calc_verts_pos += r.calculating_verts_pos
calc_verts_len += r.calculating_verts_len
if r.calculating_cverts:
calc_cverts_pos += r.calculating_cverts_pos
calc_cverts_len += r.calculating_cverts_len
except ValueError:
pass
for r in self.plot._pobjects:
pgl.glPushMatrix()
r._draw()
pgl.glPopMatrix()
if should_update_caption:
self.update_caption(calc_verts_pos, calc_verts_len,
calc_cverts_pos, calc_cverts_len)
self.last_caption_update = perf_counter()
if self.plot._screenshot:
self.plot._screenshot._execute_saving()
self.plot._render_lock.release()
def update_caption(self, calc_verts_pos, calc_verts_len,
calc_cverts_pos, calc_cverts_len):
caption = self.title
if calc_verts_len or calc_cverts_len:
caption += " (calculating"
if calc_verts_len > 0:
p = (calc_verts_pos / calc_verts_len) * 100
caption += " vertices %i%%" % (p)
if calc_cverts_len > 0:
p = (calc_cverts_pos / calc_cverts_len) * 100
caption += " colors %i%%" % (p)
caption += ")"
if self.caption != caption:
self.set_caption(caption)

View File

@ -0,0 +1,88 @@
from sympy.external.importtools import import_module
disabled = False
# if pyglet.gl fails to import, e.g. opengl is missing, we disable the tests
pyglet_gl = import_module("pyglet.gl", catch=(OSError,))
pyglet_window = import_module("pyglet.window", catch=(OSError,))
if not pyglet_gl or not pyglet_window:
disabled = True
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.trigonometric import (cos, sin)
x, y, z = symbols('x, y, z')
def test_plot_2d():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(x, [x, -5, 5, 4], visible=False)
p.wait_for_calculations()
def test_plot_2d_discontinuous():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(1/x, [x, -1, 1, 2], visible=False)
p.wait_for_calculations()
def test_plot_3d():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(x*y, [x, -5, 5, 5], [y, -5, 5, 5], visible=False)
p.wait_for_calculations()
def test_plot_3d_discontinuous():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(1/x, [x, -3, 3, 6], [y, -1, 1, 1], visible=False)
p.wait_for_calculations()
def test_plot_2d_polar():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(1/x, [x, -1, 1, 4], 'mode=polar', visible=False)
p.wait_for_calculations()
def test_plot_3d_cylinder():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(
1/y, [x, 0, 6.282, 4], [y, -1, 1, 4], 'mode=polar;style=solid',
visible=False)
p.wait_for_calculations()
def test_plot_3d_spherical():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(
1, [x, 0, 6.282, 4], [y, 0, 3.141,
4], 'mode=spherical;style=wireframe',
visible=False)
p.wait_for_calculations()
def test_plot_2d_parametric():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(sin(x), cos(x), [x, 0, 6.282, 4], visible=False)
p.wait_for_calculations()
def test_plot_3d_parametric():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(sin(x), cos(x), x/5.0, [x, 0, 6.282, 4], visible=False)
p.wait_for_calculations()
def _test_plot_log():
from sympy.plotting.pygletplot import PygletPlot
p = PygletPlot(log(x), [x, 0, 6.282, 4], 'mode=polar', visible=False)
p.wait_for_calculations()
def test_plot_integral():
# Make sure it doesn't treat x as an independent variable
from sympy.plotting.pygletplot import PygletPlot
from sympy.integrals.integrals import Integral
p = PygletPlot(Integral(z*x, (x, 1, z), (z, 1, y)), visible=False)
p.wait_for_calculations()

View File

@ -0,0 +1,188 @@
try:
from ctypes import c_float, c_int, c_double
except ImportError:
pass
import pyglet.gl as pgl
from sympy.core import S
def get_model_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv):
"""
Returns the current modelview matrix.
"""
m = (array_type*16)()
glGetMethod(pgl.GL_MODELVIEW_MATRIX, m)
return m
def get_projection_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv):
"""
Returns the current modelview matrix.
"""
m = (array_type*16)()
glGetMethod(pgl.GL_PROJECTION_MATRIX, m)
return m
def get_viewport():
"""
Returns the current viewport.
"""
m = (c_int*4)()
pgl.glGetIntegerv(pgl.GL_VIEWPORT, m)
return m
def get_direction_vectors():
m = get_model_matrix()
return ((m[0], m[4], m[8]),
(m[1], m[5], m[9]),
(m[2], m[6], m[10]))
def get_view_direction_vectors():
m = get_model_matrix()
return ((m[0], m[1], m[2]),
(m[4], m[5], m[6]),
(m[8], m[9], m[10]))
def get_basis_vectors():
return ((1, 0, 0), (0, 1, 0), (0, 0, 1))
def screen_to_model(x, y, z):
m = get_model_matrix(c_double, pgl.glGetDoublev)
p = get_projection_matrix(c_double, pgl.glGetDoublev)
w = get_viewport()
mx, my, mz = c_double(), c_double(), c_double()
pgl.gluUnProject(x, y, z, m, p, w, mx, my, mz)
return float(mx.value), float(my.value), float(mz.value)
def model_to_screen(x, y, z):
m = get_model_matrix(c_double, pgl.glGetDoublev)
p = get_projection_matrix(c_double, pgl.glGetDoublev)
w = get_viewport()
mx, my, mz = c_double(), c_double(), c_double()
pgl.gluProject(x, y, z, m, p, w, mx, my, mz)
return float(mx.value), float(my.value), float(mz.value)
def vec_subs(a, b):
return tuple(a[i] - b[i] for i in range(len(a)))
def billboard_matrix():
"""
Removes rotational components of
current matrix so that primitives
are always drawn facing the viewer.
|1|0|0|x|
|0|1|0|x|
|0|0|1|x| (x means left unchanged)
|x|x|x|x|
"""
m = get_model_matrix()
# XXX: for i in range(11): m[i] = i ?
m[0] = 1
m[1] = 0
m[2] = 0
m[4] = 0
m[5] = 1
m[6] = 0
m[8] = 0
m[9] = 0
m[10] = 1
pgl.glLoadMatrixf(m)
def create_bounds():
return [[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0],
[S.Infinity, S.NegativeInfinity, 0]]
def update_bounds(b, v):
if v is None:
return
for axis in range(3):
b[axis][0] = min([b[axis][0], v[axis]])
b[axis][1] = max([b[axis][1], v[axis]])
def interpolate(a_min, a_max, a_ratio):
return a_min + a_ratio * (a_max - a_min)
def rinterpolate(a_min, a_max, a_value):
a_range = a_max - a_min
if a_max == a_min:
a_range = 1.0
return (a_value - a_min) / float(a_range)
def interpolate_color(color1, color2, ratio):
return tuple(interpolate(color1[i], color2[i], ratio) for i in range(3))
def scale_value(v, v_min, v_len):
return (v - v_min) / v_len
def scale_value_list(flist):
v_min, v_max = min(flist), max(flist)
v_len = v_max - v_min
return [scale_value(f, v_min, v_len) for f in flist]
def strided_range(r_min, r_max, stride, max_steps=50):
o_min, o_max = r_min, r_max
if abs(r_min - r_max) < 0.001:
return []
try:
range(int(r_min - r_max))
except (TypeError, OverflowError):
return []
if r_min > r_max:
raise ValueError("r_min cannot be greater than r_max")
r_min_s = (r_min % stride)
r_max_s = stride - (r_max % stride)
if abs(r_max_s - stride) < 0.001:
r_max_s = 0.0
r_min -= r_min_s
r_max += r_max_s
r_steps = int((r_max - r_min)/stride)
if max_steps and r_steps > max_steps:
return strided_range(o_min, o_max, stride*2)
return [r_min] + [r_min + e*stride for e in range(1, r_steps + 1)] + [r_max]
def parse_option_string(s):
if not isinstance(s, str):
return None
options = {}
for token in s.split(';'):
pieces = token.split('=')
if len(pieces) == 1:
option, value = pieces[0], ""
elif len(pieces) == 2:
option, value = pieces
else:
raise ValueError("Plot option string '%s' is malformed." % (s))
options[option.strip()] = value.strip()
return options
def dot_product(v1, v2):
return sum(v1[i]*v2[i] for i in range(3))
def vec_sub(v1, v2):
return tuple(v1[i] - v2[i] for i in range(3))
def vec_mag(v):
return sum(v[i]**2 for i in range(3))**(0.5)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,77 @@
from sympy.core.symbol import symbols, Symbol
from sympy.functions import Max
from sympy.plotting.experimental_lambdify import experimental_lambdify
from sympy.plotting.intervalmath.interval_arithmetic import \
interval, intervalMembership
# Tests for exception handling in experimental_lambdify
def test_experimental_lambify():
x = Symbol('x')
f = experimental_lambdify([x], Max(x, 5))
# XXX should f be tested? If f(2) is attempted, an
# error is raised because a complex produced during wrapping of the arg
# is being compared with an int.
assert Max(2, 5) == 5
assert Max(5, 7) == 7
x = Symbol('x-3')
f = experimental_lambdify([x], x + 1)
assert f(1) == 2
def test_composite_boolean_region():
x, y = symbols('x y')
r1 = (x - 1)**2 + y**2 < 2
r2 = (x + 1)**2 + y**2 < 2
f = experimental_lambdify((x, y), r1 & r2)
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(True, True)
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
assert f(*a) == intervalMembership(False, True)
f = experimental_lambdify((x, y), r1 | r2)
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(True, True)
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(True, True)
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(True, True)
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
assert f(*a) == intervalMembership(False, True)
f = experimental_lambdify((x, y), r1 & ~r2)
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(True, True)
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
assert f(*a) == intervalMembership(False, True)
f = experimental_lambdify((x, y), ~r1 & r2)
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(True, True)
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
assert f(*a) == intervalMembership(False, True)
f = experimental_lambdify((x, y), ~r1 & ~r2)
a = (interval(-0.1, 0.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(-1.1, -0.9), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(0.9, 1.1), interval(-0.1, 0.1))
assert f(*a) == intervalMembership(False, True)
a = (interval(-0.1, 0.1), interval(1.9, 2.1))
assert f(*a) == intervalMembership(True, True)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,146 @@
from sympy.core.numbers import (I, pi)
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.complexes import re
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.trigonometric import (cos, sin, tan)
from sympy.logic.boolalg import (And, Or)
from sympy.plotting.plot_implicit import plot_implicit
from sympy.plotting.plot import unset_show
from tempfile import NamedTemporaryFile, mkdtemp
from sympy.testing.pytest import skip, warns, XFAIL
from sympy.external import import_module
from sympy.testing.tmpfiles import TmpFileManager
import os
#Set plots not to show
unset_show()
def tmp_file(dir=None, name=''):
return NamedTemporaryFile(
suffix='.png', dir=dir, delete=False).name
def plot_and_save(expr, *args, name='', dir=None, **kwargs):
p = plot_implicit(expr, *args, **kwargs)
p.save(tmp_file(dir=dir, name=name))
# Close the plot to avoid a warning from matplotlib
p._backend.close()
def plot_implicit_tests(name):
temp_dir = mkdtemp()
TmpFileManager.tmp_folder(temp_dir)
x = Symbol('x')
y = Symbol('y')
#implicit plot tests
plot_and_save(Eq(y, cos(x)), (x, -5, 5), (y, -2, 2), name=name, dir=temp_dir)
plot_and_save(Eq(y**2, x**3 - x), (x, -5, 5),
(y, -4, 4), name=name, dir=temp_dir)
plot_and_save(y > 1 / x, (x, -5, 5),
(y, -2, 2), name=name, dir=temp_dir)
plot_and_save(y < 1 / tan(x), (x, -5, 5),
(y, -2, 2), name=name, dir=temp_dir)
plot_and_save(y >= 2 * sin(x) * cos(x), (x, -5, 5),
(y, -2, 2), name=name, dir=temp_dir)
plot_and_save(y <= x**2, (x, -3, 3),
(y, -1, 5), name=name, dir=temp_dir)
#Test all input args for plot_implicit
plot_and_save(Eq(y**2, x**3 - x), dir=temp_dir)
plot_and_save(Eq(y**2, x**3 - x), adaptive=False, dir=temp_dir)
plot_and_save(Eq(y**2, x**3 - x), adaptive=False, n=500, dir=temp_dir)
plot_and_save(y > x, (x, -5, 5), dir=temp_dir)
plot_and_save(And(y > exp(x), y > x + 2), dir=temp_dir)
plot_and_save(Or(y > x, y > -x), dir=temp_dir)
plot_and_save(x**2 - 1, (x, -5, 5), dir=temp_dir)
plot_and_save(x**2 - 1, dir=temp_dir)
plot_and_save(y > x, depth=-5, dir=temp_dir)
plot_and_save(y > x, depth=5, dir=temp_dir)
plot_and_save(y > cos(x), adaptive=False, dir=temp_dir)
plot_and_save(y < cos(x), adaptive=False, dir=temp_dir)
plot_and_save(And(y > cos(x), Or(y > x, Eq(y, x))), dir=temp_dir)
plot_and_save(y - cos(pi / x), dir=temp_dir)
plot_and_save(x**2 - 1, title='An implicit plot', dir=temp_dir)
@XFAIL
def test_no_adaptive_meshing():
matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
if matplotlib:
try:
temp_dir = mkdtemp()
TmpFileManager.tmp_folder(temp_dir)
x = Symbol('x')
y = Symbol('y')
# Test plots which cannot be rendered using the adaptive algorithm
# This works, but it triggers a deprecation warning from sympify(). The
# code needs to be updated to detect if interval math is supported without
# relying on random AttributeErrors.
with warns(UserWarning, match="Adaptive meshing could not be applied"):
plot_and_save(Eq(y, re(cos(x) + I*sin(x))), name='test', dir=temp_dir)
finally:
TmpFileManager.cleanup()
else:
skip("Matplotlib not the default backend")
def test_line_color():
x, y = symbols('x, y')
p = plot_implicit(x**2 + y**2 - 1, line_color="green", show=False)
assert p._series[0].line_color == "green"
p = plot_implicit(x**2 + y**2 - 1, line_color='r', show=False)
assert p._series[0].line_color == "r"
def test_matplotlib():
matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
if matplotlib:
try:
plot_implicit_tests('test')
test_line_color()
finally:
TmpFileManager.cleanup()
else:
skip("Matplotlib not the default backend")
def test_region_and():
matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
if not matplotlib:
skip("Matplotlib not the default backend")
from matplotlib.testing.compare import compare_images
test_directory = os.path.dirname(os.path.abspath(__file__))
try:
temp_dir = mkdtemp()
TmpFileManager.tmp_folder(temp_dir)
x, y = symbols('x y')
r1 = (x - 1)**2 + y**2 < 2
r2 = (x + 1)**2 + y**2 < 2
test_filename = tmp_file(dir=temp_dir, name="test_region_and")
cmp_filename = os.path.join(test_directory, "test_region_and.png")
p = plot_implicit(r1 & r2, x, y)
p.save(test_filename)
compare_images(cmp_filename, test_filename, 0.005)
test_filename = tmp_file(dir=temp_dir, name="test_region_or")
cmp_filename = os.path.join(test_directory, "test_region_or.png")
p = plot_implicit(r1 | r2, x, y)
p.save(test_filename)
compare_images(cmp_filename, test_filename, 0.005)
test_filename = tmp_file(dir=temp_dir, name="test_region_not")
cmp_filename = os.path.join(test_directory, "test_region_not.png")
p = plot_implicit(~r1, x, y)
p.save(test_filename)
compare_images(cmp_filename, test_filename, 0.005)
test_filename = tmp_file(dir=temp_dir, name="test_region_xor")
cmp_filename = os.path.join(test_directory, "test_region_xor.png")
p = plot_implicit(r1 ^ r2, x, y)
p.save(test_filename)
compare_images(cmp_filename, test_filename, 0.005)
finally:
TmpFileManager.cleanup()

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.7 KiB

View File

@ -0,0 +1,34 @@
[remap]
importer="texture"
type="CompressedTexture2D"
uid="uid://ccgrfp1kviy0t"
path="res://.godot/imported/test_region_and.png-c1d9a0695d86ddf4348962c5b17879ff.ctex"
metadata={
"vram_texture": false
}
[deps]
source_file="res://rl/Lib/site-packages/sympy/plotting/tests/test_region_and.png"
dest_files=["res://.godot/imported/test_region_and.png-c1d9a0695d86ddf4348962c5b17879ff.ctex"]
[params]
compress/mode=0
compress/high_quality=false
compress/lossy_quality=0.7
compress/hdr_compression=1
compress/normal_map=0
compress/channel_pack=0
mipmaps/generate=false
mipmaps/limit=-1
roughness/mode=0
roughness/src_normal=""
process/fix_alpha_border=true
process/premult_alpha=false
process/normal_map_invert_y=false
process/hdr_as_srgb=false
process/hdr_clamp_exposure=false
process/size_limit=0
detect_3d/compress_to=1

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.8 KiB

View File

@ -0,0 +1,34 @@
[remap]
importer="texture"
type="CompressedTexture2D"
uid="uid://bwcwpn78v6jsw"
path="res://.godot/imported/test_region_not.png-f3250ebd2c1022cb029e786ee370fd83.ctex"
metadata={
"vram_texture": false
}
[deps]
source_file="res://rl/Lib/site-packages/sympy/plotting/tests/test_region_not.png"
dest_files=["res://.godot/imported/test_region_not.png-f3250ebd2c1022cb029e786ee370fd83.ctex"]
[params]
compress/mode=0
compress/high_quality=false
compress/lossy_quality=0.7
compress/hdr_compression=1
compress/normal_map=0
compress/channel_pack=0
mipmaps/generate=false
mipmaps/limit=-1
roughness/mode=0
roughness/src_normal=""
process/fix_alpha_border=true
process/premult_alpha=false
process/normal_map_invert_y=false
process/hdr_as_srgb=false
process/hdr_clamp_exposure=false
process/size_limit=0
detect_3d/compress_to=1

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

View File

@ -0,0 +1,34 @@
[remap]
importer="texture"
type="CompressedTexture2D"
uid="uid://ptn7si3a6w3e"
path="res://.godot/imported/test_region_or.png-fcd19de9f7f9d11aa222486f596ec33d.ctex"
metadata={
"vram_texture": false
}
[deps]
source_file="res://rl/Lib/site-packages/sympy/plotting/tests/test_region_or.png"
dest_files=["res://.godot/imported/test_region_or.png-fcd19de9f7f9d11aa222486f596ec33d.ctex"]
[params]
compress/mode=0
compress/high_quality=false
compress/lossy_quality=0.7
compress/hdr_compression=1
compress/normal_map=0
compress/channel_pack=0
mipmaps/generate=false
mipmaps/limit=-1
roughness/mode=0
roughness/src_normal=""
process/fix_alpha_border=true
process/premult_alpha=false
process/normal_map_invert_y=false
process/hdr_as_srgb=false
process/hdr_clamp_exposure=false
process/size_limit=0
detect_3d/compress_to=1

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.8 KiB

View File

@ -0,0 +1,34 @@
[remap]
importer="texture"
type="CompressedTexture2D"
uid="uid://dp0upq1njbej6"
path="res://.godot/imported/test_region_xor.png-b5f7c27e2fb8694466c39ee95db567dc.ctex"
metadata={
"vram_texture": false
}
[deps]
source_file="res://rl/Lib/site-packages/sympy/plotting/tests/test_region_xor.png"
dest_files=["res://.godot/imported/test_region_xor.png-b5f7c27e2fb8694466c39ee95db567dc.ctex"]
[params]
compress/mode=0
compress/high_quality=false
compress/lossy_quality=0.7
compress/hdr_compression=1
compress/normal_map=0
compress/channel_pack=0
mipmaps/generate=false
mipmaps/limit=-1
roughness/mode=0
roughness/src_normal=""
process/fix_alpha_border=true
process/premult_alpha=false
process/normal_map_invert_y=false
process/hdr_as_srgb=false
process/hdr_clamp_exposure=false
process/size_limit=0
detect_3d/compress_to=1

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More