I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,116 @@
"""Printing subsystem"""
from .pretty import pager_print, pretty, pretty_print, pprint, pprint_use_unicode, pprint_try_use_unicode
from .latex import latex, print_latex, multiline_latex
from .mathml import mathml, print_mathml
from .python import python, print_python
from .pycode import pycode
from .codeprinter import print_ccode, print_fcode
from .codeprinter import ccode, fcode, cxxcode # noqa:F811
from .smtlib import smtlib_code
from .glsl import glsl_code, print_glsl
from .rcode import rcode, print_rcode
from .jscode import jscode, print_jscode
from .julia import julia_code
from .mathematica import mathematica_code
from .octave import octave_code
from .rust import rust_code
from .gtk import print_gtk
from .preview import preview
from .repr import srepr
from .tree import print_tree
from .str import StrPrinter, sstr, sstrrepr
from .tableform import TableForm
from .dot import dotprint
from .maple import maple_code, print_maple_code
__all__ = [
# sympy.printing.pretty
'pager_print', 'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode',
'pprint_try_use_unicode',
# sympy.printing.latex
'latex', 'print_latex', 'multiline_latex',
# sympy.printing.mathml
'mathml', 'print_mathml',
# sympy.printing.python
'python', 'print_python',
# sympy.printing.pycode
'pycode',
# sympy.printing.codeprinter
'ccode', 'print_ccode', 'cxxcode', 'fcode', 'print_fcode',
# sympy.printing.smtlib
'smtlib_code',
# sympy.printing.glsl
'glsl_code', 'print_glsl',
# sympy.printing.rcode
'rcode', 'print_rcode',
# sympy.printing.jscode
'jscode', 'print_jscode',
# sympy.printing.julia
'julia_code',
# sympy.printing.mathematica
'mathematica_code',
# sympy.printing.octave
'octave_code',
# sympy.printing.rust
'rust_code',
# sympy.printing.gtk
'print_gtk',
# sympy.printing.preview
'preview',
# sympy.printing.repr
'srepr',
# sympy.printing.tree
'print_tree',
# sympy.printing.str
'StrPrinter', 'sstr', 'sstrrepr',
# sympy.printing.tableform
'TableForm',
# sympy.printing.dot
'dotprint',
# sympy.printing.maple
'maple_code', 'print_maple_code',
]

View File

@ -0,0 +1,540 @@
from __future__ import annotations
from typing import Any
from sympy.external import import_module
from sympy.printing.printer import Printer
from sympy.utilities.iterables import is_sequence
import sympy
from functools import partial
aesara = import_module('aesara')
if aesara:
aes = aesara.scalar
aet = aesara.tensor
from aesara.tensor import nlinalg
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.elemwise import DimShuffle
# `true_divide` replaced `true_div` in Aesara 2.8.11 (released 2023) to
# match NumPy
# XXX: Remove this when not needed to support older versions.
true_divide = getattr(aet, 'true_divide', None)
if true_divide is None:
true_divide = aet.true_div
mapping = {
sympy.Add: aet.add,
sympy.Mul: aet.mul,
sympy.Abs: aet.abs,
sympy.sign: aet.sgn,
sympy.ceiling: aet.ceil,
sympy.floor: aet.floor,
sympy.log: aet.log,
sympy.exp: aet.exp,
sympy.sqrt: aet.sqrt,
sympy.cos: aet.cos,
sympy.acos: aet.arccos,
sympy.sin: aet.sin,
sympy.asin: aet.arcsin,
sympy.tan: aet.tan,
sympy.atan: aet.arctan,
sympy.atan2: aet.arctan2,
sympy.cosh: aet.cosh,
sympy.acosh: aet.arccosh,
sympy.sinh: aet.sinh,
sympy.asinh: aet.arcsinh,
sympy.tanh: aet.tanh,
sympy.atanh: aet.arctanh,
sympy.re: aet.real,
sympy.im: aet.imag,
sympy.arg: aet.angle,
sympy.erf: aet.erf,
sympy.gamma: aet.gamma,
sympy.loggamma: aet.gammaln,
sympy.Pow: aet.pow,
sympy.Eq: aet.eq,
sympy.StrictGreaterThan: aet.gt,
sympy.StrictLessThan: aet.lt,
sympy.LessThan: aet.le,
sympy.GreaterThan: aet.ge,
sympy.And: aet.bitwise_and, # bitwise
sympy.Or: aet.bitwise_or, # bitwise
sympy.Not: aet.invert, # bitwise
sympy.Xor: aet.bitwise_xor, # bitwise
sympy.Max: aet.maximum, # Sympy accept >2 inputs, Aesara only 2
sympy.Min: aet.minimum, # Sympy accept >2 inputs, Aesara only 2
sympy.conjugate: aet.conj,
sympy.core.numbers.ImaginaryUnit: lambda:aet.complex(0,1),
# Matrices
sympy.MatAdd: Elemwise(aes.add),
sympy.HadamardProduct: Elemwise(aes.mul),
sympy.Trace: nlinalg.trace,
sympy.Determinant : nlinalg.det,
sympy.Inverse: nlinalg.matrix_inverse,
sympy.Transpose: DimShuffle((False, False), [1, 0]),
}
class AesaraPrinter(Printer):
""" Code printer which creates Aesara symbolic expression graphs.
Parameters
==========
cache : dict
Cache dictionary to use. If None (default) will use
the global cache. To create a printer which does not depend on or alter
global state pass an empty dictionary. Note: the dictionary is not
copied on initialization of the printer and will be updated in-place,
so using the same dict object when creating multiple printers or making
multiple calls to :func:`.aesara_code` or :func:`.aesara_function` means
the cache is shared between all these applications.
Attributes
==========
cache : dict
A cache of Aesara variables which have been created for SymPy
symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
:class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
ensure that all references to a given symbol in an expression (or
multiple expressions) are printed as the same Aesara variable, which is
created only once. Symbols are differentiated only by name and type. The
format of the cache's contents should be considered opaque to the user.
"""
printmethod = "_aesara"
def __init__(self, *args, **kwargs):
self.cache = kwargs.pop('cache', {})
super().__init__(*args, **kwargs)
def _get_key(self, s, name=None, dtype=None, broadcastable=None):
""" Get the cache key for a SymPy object.
Parameters
==========
s : sympy.core.basic.Basic
SymPy object to get key for.
name : str
Name of object, if it does not have a ``name`` attribute.
"""
if name is None:
name = s.name
return (name, type(s), s.args, dtype, broadcastable)
def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
"""
Get the Aesara variable for a SymPy symbol from the cache, or create it
if it does not exist.
"""
# Defaults
if name is None:
name = s.name
if dtype is None:
dtype = 'floatX'
if broadcastable is None:
broadcastable = ()
key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)
if key in self.cache:
return self.cache[key]
value = aet.tensor(name=name, dtype=dtype, shape=broadcastable)
self.cache[key] = value
return value
def _print_Symbol(self, s, **kwargs):
dtype = kwargs.get('dtypes', {}).get(s)
bc = kwargs.get('broadcastables', {}).get(s)
return self._get_or_create(s, dtype=dtype, broadcastable=bc)
def _print_AppliedUndef(self, s, **kwargs):
name = str(type(s)) + '_' + str(s.args[0])
dtype = kwargs.get('dtypes', {}).get(s)
bc = kwargs.get('broadcastables', {}).get(s)
return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)
def _print_Basic(self, expr, **kwargs):
op = mapping[type(expr)]
children = [self._print(arg, **kwargs) for arg in expr.args]
return op(*children)
def _print_Number(self, n, **kwargs):
# Integers already taken care of below, interpret as float
return float(n.evalf())
def _print_MatrixSymbol(self, X, **kwargs):
dtype = kwargs.get('dtypes', {}).get(X)
return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))
def _print_DenseMatrix(self, X, **kwargs):
if not hasattr(aet, 'stacklists'):
raise NotImplementedError(
"Matrix translation not yet supported in this version of Aesara")
return aet.stacklists([
[self._print(arg, **kwargs) for arg in L]
for L in X.tolist()
])
_print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix
def _print_MatMul(self, expr, **kwargs):
children = [self._print(arg, **kwargs) for arg in expr.args]
result = children[0]
for child in children[1:]:
result = aet.dot(result, child)
return result
def _print_MatPow(self, expr, **kwargs):
children = [self._print(arg, **kwargs) for arg in expr.args]
result = 1
if isinstance(children[1], int) and children[1] > 0:
for i in range(children[1]):
result = aet.dot(result, children[0])
else:
raise NotImplementedError('''Only non-negative integer
powers of matrices can be handled by Aesara at the moment''')
return result
def _print_MatrixSlice(self, expr, **kwargs):
parent = self._print(expr.parent, **kwargs)
rowslice = self._print(slice(*expr.rowslice), **kwargs)
colslice = self._print(slice(*expr.colslice), **kwargs)
return parent[rowslice, colslice]
def _print_BlockMatrix(self, expr, **kwargs):
nrows, ncols = expr.blocks.shape
blocks = [[self._print(expr.blocks[r, c], **kwargs)
for c in range(ncols)]
for r in range(nrows)]
return aet.join(0, *[aet.join(1, *row) for row in blocks])
def _print_slice(self, expr, **kwargs):
return slice(*[self._print(i, **kwargs)
if isinstance(i, sympy.Basic) else i
for i in (expr.start, expr.stop, expr.step)])
def _print_Pi(self, expr, **kwargs):
return 3.141592653589793
def _print_Piecewise(self, expr, **kwargs):
import numpy as np
e, cond = expr.args[0].args # First condition and corresponding value
# Print conditional expression and value for first condition
p_cond = self._print(cond, **kwargs)
p_e = self._print(e, **kwargs)
# One condition only
if len(expr.args) == 1:
# Return value if condition else NaN
return aet.switch(p_cond, p_e, np.nan)
# Return value_1 if condition_1 else evaluate remaining conditions
p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
return aet.switch(p_cond, p_e, p_remaining)
def _print_Rational(self, expr, **kwargs):
return true_divide(self._print(expr.p, **kwargs),
self._print(expr.q, **kwargs))
def _print_Integer(self, expr, **kwargs):
return expr.p
def _print_factorial(self, expr, **kwargs):
return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)
def _print_Derivative(self, deriv, **kwargs):
from aesara.gradient import Rop
rv = self._print(deriv.expr, **kwargs)
for var in deriv.variables:
var = self._print(var, **kwargs)
rv = Rop(rv, var, aet.ones_like(var))
return rv
def emptyPrinter(self, expr):
return expr
def doprint(self, expr, dtypes=None, broadcastables=None):
""" Convert a SymPy expression to a Aesara graph variable.
The ``dtypes`` and ``broadcastables`` arguments are used to specify the
data type, dimension, and broadcasting behavior of the Aesara variables
corresponding to the free symbols in ``expr``. Each is a mapping from
SymPy symbols to the value of the corresponding argument to
``aesara.tensor.var.TensorVariable``.
See the corresponding `documentation page`__ for more information on
broadcasting in Aesara.
.. __: https://aesara.readthedocs.io/en/latest/reference/tensor/shapes.html#broadcasting
Parameters
==========
expr : sympy.core.expr.Expr
SymPy expression to print.
dtypes : dict
Mapping from SymPy symbols to Aesara datatypes to use when creating
new Aesara variables for those symbols. Corresponds to the ``dtype``
argument to ``aesara.tensor.var.TensorVariable``. Defaults to ``'floatX'``
for symbols not included in the mapping.
broadcastables : dict
Mapping from SymPy symbols to the value of the ``broadcastable``
argument to ``aesara.tensor.var.TensorVariable`` to use when creating Aesara
variables for those symbols. Defaults to the empty tuple for symbols
not included in the mapping (resulting in a scalar).
Returns
=======
aesara.graph.basic.Variable
A variable corresponding to the expression's value in a Aesara
symbolic expression graph.
"""
if dtypes is None:
dtypes = {}
if broadcastables is None:
broadcastables = {}
return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
global_cache: dict[Any, Any] = {}
def aesara_code(expr, cache=None, **kwargs):
"""
Convert a SymPy expression into a Aesara graph variable.
Parameters
==========
expr : sympy.core.expr.Expr
SymPy expression object to convert.
cache : dict
Cached Aesara variables (see :class:`AesaraPrinter.cache
<AesaraPrinter>`). Defaults to the module-level global cache.
dtypes : dict
Passed to :meth:`.AesaraPrinter.doprint`.
broadcastables : dict
Passed to :meth:`.AesaraPrinter.doprint`.
Returns
=======
aesara.graph.basic.Variable
A variable corresponding to the expression's value in a Aesara symbolic
expression graph.
"""
if not aesara:
raise ImportError("aesara is required for aesara_code")
if cache is None:
cache = global_cache
return AesaraPrinter(cache=cache, settings={}).doprint(expr, **kwargs)
def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
r"""
Get value of ``broadcastables`` argument to :func:`.aesara_code` from
keyword arguments to :func:`.aesara_function`.
Included for backwards compatibility.
Parameters
==========
inputs
Sequence of input symbols.
dim : int
Common number of dimensions for all inputs. Overrides other arguments
if given.
dims : dict
Mapping from input symbols to number of dimensions. Overrides
``broadcastables`` argument if given.
broadcastables : dict
Explicit value of ``broadcastables`` argument to
:meth:`.AesaraPrinter.doprint`. If not None function will return this value unchanged.
Returns
=======
dict
Dictionary mapping elements of ``inputs`` to their "broadcastable"
values (tuple of ``bool``\ s).
"""
if dim is not None:
return dict.fromkeys(inputs, (False,) * dim)
if dims is not None:
maxdim = max(dims.values())
return {
s: (False,) * d + (True,) * (maxdim - d)
for s, d in dims.items()
}
if broadcastables is not None:
return broadcastables
return {}
def aesara_function(inputs, outputs, scalar=False, *,
dim=None, dims=None, broadcastables=None, **kwargs):
"""
Create a Aesara function from SymPy expressions.
The inputs and outputs are converted to Aesara variables using
:func:`.aesara_code` and then passed to ``aesara.function``.
Parameters
==========
inputs
Sequence of symbols which constitute the inputs of the function.
outputs
Sequence of expressions which constitute the outputs(s) of the
function. The free symbols of each expression must be a subset of
``inputs``.
scalar : bool
Convert 0-dimensional arrays in output to scalars. This will return a
Python wrapper function around the Aesara function object.
cache : dict
Cached Aesara variables (see :class:`AesaraPrinter.cache
<AesaraPrinter>`). Defaults to the module-level global cache.
dtypes : dict
Passed to :meth:`.AesaraPrinter.doprint`.
broadcastables : dict
Passed to :meth:`.AesaraPrinter.doprint`.
dims : dict
Alternative to ``broadcastables`` argument. Mapping from elements of
``inputs`` to integers indicating the dimension of their associated
arrays/tensors. Overrides ``broadcastables`` argument if given.
dim : int
Another alternative to the ``broadcastables`` argument. Common number of
dimensions to use for all arrays/tensors.
``aesara_function([x, y], [...], dim=2)`` is equivalent to using
``broadcastables={x: (False, False), y: (False, False)}``.
Returns
=======
callable
A callable object which takes values of ``inputs`` as positional
arguments and returns an output array for each of the expressions
in ``outputs``. If ``outputs`` is a single expression the function will
return a Numpy array, if it is a list of multiple expressions the
function will return a list of arrays. See description of the ``squeeze``
argument above for the behavior when a single output is passed in a list.
The returned object will either be an instance of
``aesara.compile.function.types.Function`` or a Python wrapper
function around one. In both cases, the returned value will have a
``aesara_function`` attribute which points to the return value of
``aesara.function``.
Examples
========
>>> from sympy.abc import x, y, z
>>> from sympy.printing.aesaracode import aesara_function
A simple function with one input and one output:
>>> f1 = aesara_function([x], [x**2 - 1], scalar=True)
>>> f1(3)
8.0
A function with multiple inputs and one output:
>>> f2 = aesara_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
>>> f2(3, 4, 2)
5.0
A function with multiple inputs and multiple outputs:
>>> f3 = aesara_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
>>> f3(2, 3)
[13.0, -5.0]
See also
========
dim_handling
"""
if not aesara:
raise ImportError("Aesara is required for aesara_function")
# Pop off non-aesara keyword args
cache = kwargs.pop('cache', {})
dtypes = kwargs.pop('dtypes', {})
broadcastables = dim_handling(
inputs, dim=dim, dims=dims, broadcastables=broadcastables,
)
# Print inputs/outputs
code = partial(aesara_code, cache=cache, dtypes=dtypes,
broadcastables=broadcastables)
tinputs = list(map(code, inputs))
toutputs = list(map(code, outputs))
#fix constant expressions as variables
toutputs = [output if isinstance(output, aesara.graph.basic.Variable) else aet.as_tensor_variable(output) for output in toutputs]
if len(toutputs) == 1:
toutputs = toutputs[0]
# Compile aesara func
func = aesara.function(tinputs, toutputs, **kwargs)
is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]
# No wrapper required
if not scalar or not any(is_0d):
func.aesara_function = func
return func
# Create wrapper to convert 0-dimensional outputs to scalars
def wrapper(*args):
out = func(*args)
# out can be array(1.0) or [array(1.0), array(2.0)]
if is_sequence(out):
return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
else:
return out[()]
wrapper.__wrapped__ = func
wrapper.__doc__ = func.__doc__
wrapper.aesara_function = func
return wrapper

View File

@ -0,0 +1,750 @@
"""
C code printer
The C89CodePrinter & C99CodePrinter converts single SymPy expressions into
single C expressions, using the functions defined in math.h where possible.
A complete code generator, which uses ccode extensively, can be found in
sympy.utilities.codegen. The codegen module can be used to generate complete
source code files that are compilable without further modifications.
"""
from __future__ import annotations
from typing import Any
from functools import wraps
from itertools import chain
from sympy.core import S
from sympy.core.numbers import equal_valued, Float
from sympy.codegen.ast import (
Assignment, Pointer, Variable, Declaration, Type,
real, complex_, integer, bool_, float32, float64, float80,
complex64, complex128, intc, value_const, pointer_const,
int8, int16, int32, int64, uint8, uint16, uint32, uint64, untyped,
none
)
from sympy.printing.codeprinter import CodePrinter, requires
from sympy.printing.precedence import precedence, PRECEDENCE
from sympy.sets.fancysets import Range
# These are defined in the other file so we can avoid importing sympy.codegen
# from the top-level 'import sympy'. Export them here as well.
from sympy.printing.codeprinter import ccode, print_ccode # noqa:F401
# dictionary mapping SymPy function to (argument_conditions, C_function).
# Used in C89CodePrinter._print_Function(self)
known_functions_C89 = {
"Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
"sin": "sin",
"cos": "cos",
"tan": "tan",
"asin": "asin",
"acos": "acos",
"atan": "atan",
"atan2": "atan2",
"exp": "exp",
"log": "log",
"sinh": "sinh",
"cosh": "cosh",
"tanh": "tanh",
"floor": "floor",
"ceiling": "ceil",
"sqrt": "sqrt", # To enable automatic rewrites
}
known_functions_C99 = dict(known_functions_C89, **{
'exp2': 'exp2',
'expm1': 'expm1',
'log10': 'log10',
'log2': 'log2',
'log1p': 'log1p',
'Cbrt': 'cbrt',
'hypot': 'hypot',
'fma': 'fma',
'loggamma': 'lgamma',
'erfc': 'erfc',
'Max': 'fmax',
'Min': 'fmin',
"asinh": "asinh",
"acosh": "acosh",
"atanh": "atanh",
"erf": "erf",
"gamma": "tgamma",
})
# These are the core reserved words in the C language. Taken from:
# https://en.cppreference.com/w/c/keyword
reserved_words = [
'auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do',
'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', 'int',
'long', 'register', 'return', 'short', 'signed', 'sizeof', 'static',
'struct', 'entry', # never standardized, we'll leave it here anyway
'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile', 'while'
]
reserved_words_c99 = ['inline', 'restrict']
def get_math_macros():
""" Returns a dictionary with math-related macros from math.h/cmath
Note that these macros are not strictly required by the C/C++-standard.
For MSVC they are enabled by defining "_USE_MATH_DEFINES" (preferably
via a compilation flag).
Returns
=======
Dictionary mapping SymPy expressions to strings (macro names)
"""
from sympy.codegen.cfunctions import log2, Sqrt
from sympy.functions.elementary.exponential import log
from sympy.functions.elementary.miscellaneous import sqrt
return {
S.Exp1: 'M_E',
log2(S.Exp1): 'M_LOG2E',
1/log(2): 'M_LOG2E',
log(2): 'M_LN2',
log(10): 'M_LN10',
S.Pi: 'M_PI',
S.Pi/2: 'M_PI_2',
S.Pi/4: 'M_PI_4',
1/S.Pi: 'M_1_PI',
2/S.Pi: 'M_2_PI',
2/sqrt(S.Pi): 'M_2_SQRTPI',
2/Sqrt(S.Pi): 'M_2_SQRTPI',
sqrt(2): 'M_SQRT2',
Sqrt(2): 'M_SQRT2',
1/sqrt(2): 'M_SQRT1_2',
1/Sqrt(2): 'M_SQRT1_2'
}
def _as_macro_if_defined(meth):
""" Decorator for printer methods
When a Printer's method is decorated using this decorator the expressions printed
will first be looked for in the attribute ``math_macros``, and if present it will
print the macro name in ``math_macros`` followed by a type suffix for the type
``real``. e.g. printing ``sympy.pi`` would print ``M_PIl`` if real is mapped to float80.
"""
@wraps(meth)
def _meth_wrapper(self, expr, **kwargs):
if expr in self.math_macros:
return '%s%s' % (self.math_macros[expr], self._get_math_macro_suffix(real))
else:
return meth(self, expr, **kwargs)
return _meth_wrapper
class C89CodePrinter(CodePrinter):
"""A printer to convert Python expressions to strings of C code"""
printmethod = "_ccode"
language = "C"
standard = "C89"
reserved_words = set(reserved_words)
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 17,
'user_functions': {},
'contract': True,
'dereference': set(),
'error_on_reserved': False,
})
type_aliases = {
real: float64,
complex_: complex128,
integer: intc
}
type_mappings: dict[Type, Any] = {
real: 'double',
intc: 'int',
float32: 'float',
float64: 'double',
integer: 'int',
bool_: 'bool',
int8: 'int8_t',
int16: 'int16_t',
int32: 'int32_t',
int64: 'int64_t',
uint8: 'int8_t',
uint16: 'int16_t',
uint32: 'int32_t',
uint64: 'int64_t',
}
type_headers = {
bool_: {'stdbool.h'},
int8: {'stdint.h'},
int16: {'stdint.h'},
int32: {'stdint.h'},
int64: {'stdint.h'},
uint8: {'stdint.h'},
uint16: {'stdint.h'},
uint32: {'stdint.h'},
uint64: {'stdint.h'},
}
# Macros needed to be defined when using a Type
type_macros: dict[Type, tuple[str, ...]] = {}
type_func_suffixes = {
float32: 'f',
float64: '',
float80: 'l'
}
type_literal_suffixes = {
float32: 'F',
float64: '',
float80: 'L'
}
type_math_macro_suffixes = {
float80: 'l'
}
math_macros = None
_ns = '' # namespace, C++ uses 'std::'
# known_functions-dict to copy
_kf: dict[str, Any] = known_functions_C89
def __init__(self, settings=None):
settings = settings or {}
if self.math_macros is None:
self.math_macros = settings.pop('math_macros', get_math_macros())
self.type_aliases = dict(chain(self.type_aliases.items(),
settings.pop('type_aliases', {}).items()))
self.type_mappings = dict(chain(self.type_mappings.items(),
settings.pop('type_mappings', {}).items()))
self.type_headers = dict(chain(self.type_headers.items(),
settings.pop('type_headers', {}).items()))
self.type_macros = dict(chain(self.type_macros.items(),
settings.pop('type_macros', {}).items()))
self.type_func_suffixes = dict(chain(self.type_func_suffixes.items(),
settings.pop('type_func_suffixes', {}).items()))
self.type_literal_suffixes = dict(chain(self.type_literal_suffixes.items(),
settings.pop('type_literal_suffixes', {}).items()))
self.type_math_macro_suffixes = dict(chain(self.type_math_macro_suffixes.items(),
settings.pop('type_math_macro_suffixes', {}).items()))
super().__init__(settings)
self.known_functions = dict(self._kf, **settings.get('user_functions', {}))
self._dereference = set(settings.get('dereference', []))
self.headers = set()
self.libraries = set()
self.macros = set()
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
""" Get code string as a statement - i.e. ending with a semicolon. """
return codestring if codestring.endswith(';') else codestring + ';'
def _get_comment(self, text):
return "/* {} */".format(text)
def _declare_number_const(self, name, value):
type_ = self.type_aliases[real]
var = Variable(name, type=type_, value=value.evalf(type_.decimal_dig), attrs={value_const})
decl = Declaration(var)
return self._get_statement(self._print(decl))
def _format_code(self, lines):
return self.indent_code(lines)
def _traverse_matrix_indices(self, mat):
rows, cols = mat.shape
return ((i, j) for i in range(rows) for j in range(cols))
@_as_macro_if_defined
def _print_Mul(self, expr, **kwargs):
return super()._print_Mul(expr, **kwargs)
@_as_macro_if_defined
def _print_Pow(self, expr):
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
suffix = self._get_func_suffix(real)
if equal_valued(expr.exp, -1):
return '%s/%s' % (self._print_Float(Float(1.0)), self.parenthesize(expr.base, PREC))
elif equal_valued(expr.exp, 0.5):
return '%ssqrt%s(%s)' % (self._ns, suffix, self._print(expr.base))
elif expr.exp == S.One/3 and self.standard != 'C89':
return '%scbrt%s(%s)' % (self._ns, suffix, self._print(expr.base))
else:
return '%spow%s(%s, %s)' % (self._ns, suffix, self._print(expr.base),
self._print(expr.exp))
def _print_Mod(self, expr):
num, den = expr.args
if num.is_integer and den.is_integer:
PREC = precedence(expr)
snum, sden = [self.parenthesize(arg, PREC) for arg in expr.args]
# % is remainder (same sign as numerator), not modulo (same sign as
# denominator), in C. Hence, % only works as modulo if both numbers
# have the same sign
if (num.is_nonnegative and den.is_nonnegative or
num.is_nonpositive and den.is_nonpositive):
return f"{snum} % {sden}"
return f"(({snum} % {sden}) + {sden}) % {sden}"
# Not guaranteed integer
return self._print_math_func(expr, known='fmod')
def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)
suffix = self._get_literal_suffix(real)
return '%d.0%s/%d.0%s' % (p, suffix, q, suffix)
def _print_Indexed(self, expr):
# calculate index for 1d array
offset = getattr(expr.base, 'offset', S.Zero)
strides = getattr(expr.base, 'strides', None)
indices = expr.indices
if strides is None or isinstance(strides, str):
dims = expr.shape
shift = S.One
temp = ()
if strides == 'C' or strides is None:
traversal = reversed(range(expr.rank))
indices = indices[::-1]
elif strides == 'F':
traversal = range(expr.rank)
for i in traversal:
temp += (shift,)
shift *= dims[i]
strides = temp
flat_index = sum(x[0]*x[1] for x in zip(indices, strides)) + offset
return "%s[%s]" % (self._print(expr.base.label),
self._print(flat_index))
def _print_Idx(self, expr):
return self._print(expr.label)
@_as_macro_if_defined
def _print_NumberSymbol(self, expr):
return super()._print_NumberSymbol(expr)
def _print_Infinity(self, expr):
return 'HUGE_VAL'
def _print_NegativeInfinity(self, expr):
return '-HUGE_VAL'
def _print_Piecewise(self, expr):
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
if expr.has(Assignment):
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s) {" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else {")
else:
lines.append("else if (%s) {" % self._print(c))
code0 = self._print(e)
lines.append(code0)
lines.append("}")
return "\n".join(lines)
else:
# The piecewise was used in an expression, need to do inline
# operators. This has the downside that inline operators will
# not work for statements that span multiple lines (Matrix or
# Indexed expressions).
ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c),
self._print(e))
for e, c in expr.args[:-1]]
last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr)
return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)])
def _print_ITE(self, expr):
from sympy.functions import Piecewise
return self._print(expr.rewrite(Piecewise, deep=False))
def _print_MatrixElement(self, expr):
return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"],
strict=True), expr.j + expr.i*expr.parent.shape[1])
def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if expr in self._settings['dereference']:
return '(*{})'.format(name)
else:
return name
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_For(self, expr):
target = self._print(expr.target)
if isinstance(expr.iterable, Range):
start, stop, step = expr.iterable.args
else:
raise NotImplementedError("Only iterable currently supported is Range")
body = self._print(expr.body)
return ('for ({target} = {start}; {target} < {stop}; {target} += '
'{step}) {{\n{body}\n}}').format(target=target, start=start,
stop=stop, step=step, body=body)
def _print_sign(self, func):
return '((({0}) > 0) - (({0}) < 0))'.format(self._print(func.args[0]))
def _print_Max(self, expr):
if "Max" in self.known_functions:
return self._print_Function(expr)
def inner_print_max(args): # The more natural abstraction of creating
if len(args) == 1: # and printing smaller Max objects is slow
return self._print(args[0]) # when there are many arguments.
half = len(args) // 2
return "((%(a)s > %(b)s) ? %(a)s : %(b)s)" % {
'a': inner_print_max(args[:half]),
'b': inner_print_max(args[half:])
}
return inner_print_max(expr.args)
def _print_Min(self, expr):
if "Min" in self.known_functions:
return self._print_Function(expr)
def inner_print_min(args): # The more natural abstraction of creating
if len(args) == 1: # and printing smaller Min objects is slow
return self._print(args[0]) # when there are many arguments.
half = len(args) // 2
return "((%(a)s < %(b)s) ? %(a)s : %(b)s)" % {
'a': inner_print_min(args[:half]),
'b': inner_print_min(args[half:])
}
return inner_print_min(expr.args)
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_token = ('{', '(', '{\n', '(\n')
dec_token = ('}', ')')
code = [line.lstrip(' \t') for line in code]
increase = [int(any(map(line.endswith, inc_token))) for line in code]
decrease = [int(any(map(line.startswith, dec_token))) for line in code]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def _get_func_suffix(self, type_):
return self.type_func_suffixes[self.type_aliases.get(type_, type_)]
def _get_literal_suffix(self, type_):
return self.type_literal_suffixes[self.type_aliases.get(type_, type_)]
def _get_math_macro_suffix(self, type_):
alias = self.type_aliases.get(type_, type_)
dflt = self.type_math_macro_suffixes.get(alias, '')
return self.type_math_macro_suffixes.get(type_, dflt)
def _print_Tuple(self, expr):
return '{'+', '.join(self._print(e) for e in expr)+'}'
_print_List = _print_Tuple
def _print_Type(self, type_):
self.headers.update(self.type_headers.get(type_, set()))
self.macros.update(self.type_macros.get(type_, set()))
return self._print(self.type_mappings.get(type_, type_.name))
def _print_Declaration(self, decl):
from sympy.codegen.cnodes import restrict
var = decl.variable
val = var.value
if var.type == untyped:
raise ValueError("C does not support untyped variables")
if isinstance(var, Pointer):
result = '{vc}{t} *{pc} {r}{s}'.format(
vc='const ' if value_const in var.attrs else '',
t=self._print(var.type),
pc=' const' if pointer_const in var.attrs else '',
r='restrict ' if restrict in var.attrs else '',
s=self._print(var.symbol)
)
elif isinstance(var, Variable):
result = '{vc}{t} {s}'.format(
vc='const ' if value_const in var.attrs else '',
t=self._print(var.type),
s=self._print(var.symbol)
)
else:
raise NotImplementedError("Unknown type of var: %s" % type(var))
if val != None: # Must be "!= None", cannot be "is not None"
result += ' = %s' % self._print(val)
return result
def _print_Float(self, flt):
type_ = self.type_aliases.get(real, real)
self.macros.update(self.type_macros.get(type_, set()))
suffix = self._get_literal_suffix(type_)
num = str(flt.evalf(type_.decimal_dig))
if 'e' not in num and '.' not in num:
num += '.0'
num_parts = num.split('e')
num_parts[0] = num_parts[0].rstrip('0')
if num_parts[0].endswith('.'):
num_parts[0] += '0'
return 'e'.join(num_parts) + suffix
@requires(headers={'stdbool.h'})
def _print_BooleanTrue(self, expr):
return 'true'
@requires(headers={'stdbool.h'})
def _print_BooleanFalse(self, expr):
return 'false'
def _print_Element(self, elem):
if elem.strides == None: # Must be "== None", cannot be "is None"
if elem.offset != None: # Must be "!= None", cannot be "is not None"
raise ValueError("Expected strides when offset is given")
idxs = ']['.join((self._print(arg) for arg in elem.indices))
else:
global_idx = sum(i*s for i, s in zip(elem.indices, elem.strides))
if elem.offset != None: # Must be "!= None", cannot be "is not None"
global_idx += elem.offset
idxs = self._print(global_idx)
return "{symb}[{idxs}]".format(
symb=self._print(elem.symbol),
idxs=idxs
)
def _print_CodeBlock(self, expr):
""" Elements of code blocks printed as statements. """
return '\n'.join([self._get_statement(self._print(i)) for i in expr.args])
def _print_While(self, expr):
return 'while ({condition}) {{\n{body}\n}}'.format(**expr.kwargs(
apply=lambda arg: self._print(arg)))
def _print_Scope(self, expr):
return '{\n%s\n}' % self._print_CodeBlock(expr.body)
@requires(headers={'stdio.h'})
def _print_Print(self, expr):
if expr.file == none:
template = 'printf({fmt}, {pargs})'
else:
template = 'fprintf(%(out)s, {fmt}, {pargs})' % {
'out': self._print(expr.file)
}
return template.format(
fmt="%s\n" if expr.format_string == none else self._print(expr.format_string),
pargs=', '.join((self._print(arg) for arg in expr.print_args))
)
def _print_Stream(self, strm):
return strm.name
def _print_FunctionPrototype(self, expr):
pars = ', '.join((self._print(Declaration(arg)) for arg in expr.parameters))
return "%s %s(%s)" % (
tuple((self._print(arg) for arg in (expr.return_type, expr.name))) + (pars,)
)
def _print_FunctionDefinition(self, expr):
return "%s%s" % (self._print_FunctionPrototype(expr),
self._print_Scope(expr))
def _print_Return(self, expr):
arg, = expr.args
return 'return %s' % self._print(arg)
def _print_CommaOperator(self, expr):
return '(%s)' % ', '.join((self._print(arg) for arg in expr.args))
def _print_Label(self, expr):
if expr.body == none:
return '%s:' % str(expr.name)
if len(expr.body.args) == 1:
return '%s:\n%s' % (str(expr.name), self._print_CodeBlock(expr.body))
return '%s:\n{\n%s\n}' % (str(expr.name), self._print_CodeBlock(expr.body))
def _print_goto(self, expr):
return 'goto %s' % expr.label.name
def _print_PreIncrement(self, expr):
arg, = expr.args
return '++(%s)' % self._print(arg)
def _print_PostIncrement(self, expr):
arg, = expr.args
return '(%s)++' % self._print(arg)
def _print_PreDecrement(self, expr):
arg, = expr.args
return '--(%s)' % self._print(arg)
def _print_PostDecrement(self, expr):
arg, = expr.args
return '(%s)--' % self._print(arg)
def _print_struct(self, expr):
return "%(keyword)s %(name)s {\n%(lines)s}" % {
"keyword": expr.__class__.__name__, "name": expr.name, "lines": ';\n'.join(
[self._print(decl) for decl in expr.declarations] + [''])
}
def _print_BreakToken(self, _):
return 'break'
def _print_ContinueToken(self, _):
return 'continue'
_print_union = _print_struct
class C99CodePrinter(C89CodePrinter):
standard = 'C99'
reserved_words = set(reserved_words + reserved_words_c99)
type_mappings=dict(chain(C89CodePrinter.type_mappings.items(), {
complex64: 'float complex',
complex128: 'double complex',
}.items()))
type_headers = dict(chain(C89CodePrinter.type_headers.items(), {
complex64: {'complex.h'},
complex128: {'complex.h'}
}.items()))
# known_functions-dict to copy
_kf: dict[str, Any] = known_functions_C99
# functions with versions with 'f' and 'l' suffixes:
_prec_funcs = ('fabs fmod remainder remquo fma fmax fmin fdim nan exp exp2'
' expm1 log log10 log2 log1p pow sqrt cbrt hypot sin cos tan'
' asin acos atan atan2 sinh cosh tanh asinh acosh atanh erf'
' erfc tgamma lgamma ceil floor trunc round nearbyint rint'
' frexp ldexp modf scalbn ilogb logb nextafter copysign').split()
def _print_Infinity(self, expr):
return 'INFINITY'
def _print_NegativeInfinity(self, expr):
return '-INFINITY'
def _print_NaN(self, expr):
return 'NAN'
# tgamma was already covered by 'known_functions' dict
@requires(headers={'math.h'}, libraries={'m'})
@_as_macro_if_defined
def _print_math_func(self, expr, nest=False, known=None):
if known is None:
known = self.known_functions[expr.__class__.__name__]
if not isinstance(known, str):
for cb, name in known:
if cb(*expr.args):
known = name
break
else:
raise ValueError("No matching printer")
try:
return known(self, *expr.args)
except TypeError:
suffix = self._get_func_suffix(real) if self._ns + known in self._prec_funcs else ''
if nest:
args = self._print(expr.args[0])
if len(expr.args) > 1:
paren_pile = ''
for curr_arg in expr.args[1:-1]:
paren_pile += ')'
args += ', {ns}{name}{suffix}({next}'.format(
ns=self._ns,
name=known,
suffix=suffix,
next = self._print(curr_arg)
)
args += ', %s%s' % (
self._print(expr.func(expr.args[-1])),
paren_pile
)
else:
args = ', '.join((self._print(arg) for arg in expr.args))
return '{ns}{name}{suffix}({args})'.format(
ns=self._ns,
name=known,
suffix=suffix,
args=args
)
def _print_Max(self, expr):
return self._print_math_func(expr, nest=True)
def _print_Min(self, expr):
return self._print_math_func(expr, nest=True)
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
loopstart = "for (int %(var)s=%(start)s; %(var)s<%(end)s; %(var)s++){" # C99
for i in indices:
# C arrays start at 0 and end at dimension-1
open_lines.append(loopstart % {
'var': self._print(i.label),
'start': self._print(i.lower),
'end': self._print(i.upper + 1)})
close_lines.append("}")
return open_lines, close_lines
for k in ('Abs Sqrt exp exp2 expm1 log log10 log2 log1p Cbrt hypot fma'
' loggamma sin cos tan asin acos atan atan2 sinh cosh tanh asinh acosh '
'atanh erf erfc loggamma gamma ceiling floor').split():
setattr(C99CodePrinter, '_print_%s' % k, C99CodePrinter._print_math_func)
class C11CodePrinter(C99CodePrinter):
@requires(headers={'stdalign.h'})
def _print_alignof(self, expr):
arg, = expr.args
return 'alignof(%s)' % self._print(arg)
c_code_printers = {
'c89': C89CodePrinter,
'c99': C99CodePrinter,
'c11': C11CodePrinter
}

View File

@ -0,0 +1,888 @@
from __future__ import annotations
from typing import Any
from functools import wraps
from sympy.core import Add, Mul, Pow, S, sympify, Float
from sympy.core.basic import Basic
from sympy.core.expr import UnevaluatedExpr
from sympy.core.function import Lambda
from sympy.core.mul import _keep_coeff
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Symbol
from sympy.functions.elementary.complexes import re
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE
class requires:
""" Decorator for registering requirements on print methods. """
def __init__(self, **kwargs):
self._req = kwargs
def __call__(self, method):
def _method_wrapper(self_, *args, **kwargs):
for k, v in self._req.items():
getattr(self_, k).update(v)
return method(self_, *args, **kwargs)
return wraps(method)(_method_wrapper)
class AssignmentError(Exception):
"""
Raised if an assignment variable for a loop is missing.
"""
pass
class PrintMethodNotImplementedError(NotImplementedError):
"""
Raised if a _print_* method is missing in the Printer.
"""
pass
def _convert_python_lists(arg):
if isinstance(arg, list):
from sympy.codegen.abstract_nodes import List
return List(*(_convert_python_lists(e) for e in arg))
elif isinstance(arg, tuple):
return tuple(_convert_python_lists(e) for e in arg)
else:
return arg
class CodePrinter(StrPrinter):
"""
The base class for code-printing subclasses.
"""
_operators = {
'and': '&&',
'or': '||',
'not': '!',
}
_default_settings: dict[str, Any] = {
'order': None,
'full_prec': 'auto',
'error_on_reserved': False,
'reserved_word_suffix': '_',
'human': True,
'inline': False,
'allow_unknown_functions': False,
'strict': None # True or False; None => True if human == True
}
# Functions which are "simple" to rewrite to other functions that
# may be supported
# function_to_rewrite : (function_to_rewrite_to, iterable_with_other_functions_required)
_rewriteable_functions = {
'cot': ('tan', []),
'csc': ('sin', []),
'sec': ('cos', []),
'acot': ('atan', []),
'acsc': ('asin', []),
'asec': ('acos', []),
'coth': ('exp', []),
'csch': ('exp', []),
'sech': ('exp', []),
'acoth': ('log', []),
'acsch': ('log', []),
'asech': ('log', []),
'catalan': ('gamma', []),
'fibonacci': ('sqrt', []),
'lucas': ('sqrt', []),
'beta': ('gamma', []),
'sinc': ('sin', ['Piecewise']),
'Mod': ('floor', []),
'factorial': ('gamma', []),
'factorial2': ('gamma', ['Piecewise']),
'subfactorial': ('uppergamma', []),
'RisingFactorial': ('gamma', ['Piecewise']),
'FallingFactorial': ('gamma', ['Piecewise']),
'binomial': ('gamma', []),
'frac': ('floor', []),
'Max': ('Piecewise', []),
'Min': ('Piecewise', []),
'Heaviside': ('Piecewise', []),
'erf2': ('erf', []),
'erfc': ('erf', []),
'Li': ('li', []),
'Ei': ('li', []),
'dirichlet_eta': ('zeta', []),
'riemann_xi': ('zeta', ['gamma']),
'SingularityFunction': ('Piecewise', []),
}
def __init__(self, settings=None):
super().__init__(settings=settings)
if self._settings.get('strict', True) == None:
# for backwards compatibility, human=False need not to throw:
self._settings['strict'] = self._settings.get('human', True) == True
if not hasattr(self, 'reserved_words'):
self.reserved_words = set()
def _handle_UnevaluatedExpr(self, expr):
return expr.replace(re, lambda arg: arg if isinstance(
arg, UnevaluatedExpr) and arg.args[0].is_real else re(arg))
def doprint(self, expr, assign_to=None):
"""
Print the expression as code.
Parameters
----------
expr : Expression
The expression to be printed.
assign_to : Symbol, string, MatrixSymbol, list of strings or Symbols (optional)
If provided, the printed code will set the expression to a variable or multiple variables
with the name or names given in ``assign_to``.
"""
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.codegen.ast import CodeBlock, Assignment
def _handle_assign_to(expr, assign_to):
if assign_to is None:
return sympify(expr)
if isinstance(assign_to, (list, tuple)):
if len(expr) != len(assign_to):
raise ValueError('Failed to assign an expression of length {} to {} variables'.format(len(expr), len(assign_to)))
return CodeBlock(*[_handle_assign_to(lhs, rhs) for lhs, rhs in zip(expr, assign_to)])
if isinstance(assign_to, str):
if expr.is_Matrix:
assign_to = MatrixSymbol(assign_to, *expr.shape)
else:
assign_to = Symbol(assign_to)
elif not isinstance(assign_to, Basic):
raise TypeError("{} cannot assign to object of type {}".format(
type(self).__name__, type(assign_to)))
return Assignment(assign_to, expr)
expr = _convert_python_lists(expr)
expr = _handle_assign_to(expr, assign_to)
# Remove re(...) nodes due to UnevaluatedExpr.is_real always is None:
expr = self._handle_UnevaluatedExpr(expr)
# keep a set of expressions that are not strictly translatable to Code
# and number constants that must be declared and initialized
self._not_supported = set()
self._number_symbols = set()
lines = self._print(expr).splitlines()
# format the output
if self._settings["human"]:
frontlines = []
if self._not_supported:
frontlines.append(self._get_comment(
"Not supported in {}:".format(self.language)))
for expr in sorted(self._not_supported, key=str):
frontlines.append(self._get_comment(type(expr).__name__))
for name, value in sorted(self._number_symbols, key=str):
frontlines.append(self._declare_number_const(name, value))
lines = frontlines + lines
lines = self._format_code(lines)
result = "\n".join(lines)
else:
lines = self._format_code(lines)
num_syms = {(k, self._print(v)) for k, v in self._number_symbols}
result = (num_syms, self._not_supported, "\n".join(lines))
self._not_supported = set()
self._number_symbols = set()
return result
def _doprint_loops(self, expr, assign_to=None):
# Here we print an expression that contains Indexed objects, they
# correspond to arrays in the generated code. The low-level implementation
# involves looping over array elements and possibly storing results in temporary
# variables or accumulate it in the assign_to object.
if self._settings.get('contract', True):
from sympy.tensor import get_contraction_structure
# Setup loops over non-dummy indices -- all terms need these
indices = self._get_expression_indices(expr, assign_to)
# Setup loops over dummy indices -- each term needs separate treatment
dummies = get_contraction_structure(expr)
else:
indices = []
dummies = {None: (expr,)}
openloop, closeloop = self._get_loop_opening_ending(indices)
# terms with no summations first
if None in dummies:
text = StrPrinter.doprint(self, Add(*dummies[None]))
else:
# If all terms have summations we must initialize array to Zero
text = StrPrinter.doprint(self, 0)
# skip redundant assignments (where lhs == rhs)
lhs_printed = self._print(assign_to)
lines = []
if text != lhs_printed:
lines.extend(openloop)
if assign_to is not None:
text = self._get_statement("%s = %s" % (lhs_printed, text))
lines.append(text)
lines.extend(closeloop)
# then terms with summations
for d in dummies:
if isinstance(d, tuple):
indices = self._sort_optimized(d, expr)
openloop_d, closeloop_d = self._get_loop_opening_ending(
indices)
for term in dummies[d]:
if term in dummies and not ([list(f.keys()) for f in dummies[term]]
== [[None] for f in dummies[term]]):
# If one factor in the term has it's own internal
# contractions, those must be computed first.
# (temporary variables?)
raise NotImplementedError(
"FIXME: no support for contractions in factor yet")
else:
# We need the lhs expression as an accumulator for
# the loops, i.e
#
# for (int d=0; d < dim; d++){
# lhs[] = lhs[] + term[][d]
# } ^.................. the accumulator
#
# We check if the expression already contains the
# lhs, and raise an exception if it does, as that
# syntax is currently undefined. FIXME: What would be
# a good interpretation?
if assign_to is None:
raise AssignmentError(
"need assignment variable for loops")
if term.has(assign_to):
raise ValueError("FIXME: lhs present in rhs,\
this is undefined in CodePrinter")
lines.extend(openloop)
lines.extend(openloop_d)
text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
self, assign_to + term))
lines.append(self._get_statement(text))
lines.extend(closeloop_d)
lines.extend(closeloop)
return "\n".join(lines)
def _get_expression_indices(self, expr, assign_to):
from sympy.tensor import get_indices
rinds, junk = get_indices(expr)
linds, junk = get_indices(assign_to)
# support broadcast of scalar
if linds and not rinds:
rinds = linds
if rinds != linds:
raise ValueError("lhs indices must match non-dummy"
" rhs indices in %s" % expr)
return self._sort_optimized(rinds, assign_to)
def _sort_optimized(self, indices, expr):
from sympy.tensor.indexed import Indexed
if not indices:
return []
# determine optimized loop order by giving a score to each index
# the index with the highest score are put in the innermost loop.
score_table = {}
for i in indices:
score_table[i] = 0
arrays = expr.atoms(Indexed)
for arr in arrays:
for p, ind in enumerate(arr.indices):
try:
score_table[ind] += self._rate_index_position(p)
except KeyError:
pass
return sorted(indices, key=lambda x: score_table[x])
def _rate_index_position(self, p):
"""function to calculate score based on position among indices
This method is used to sort loops in an optimized order, see
CodePrinter._sort_optimized()
"""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _get_statement(self, codestring):
"""Formats a codestring with the proper line ending."""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _get_comment(self, text):
"""Formats a text string as a comment."""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _declare_number_const(self, name, value):
"""Declare a numeric constant at the top of a function"""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _format_code(self, lines):
"""Take in a list of lines of code, and format them accordingly.
This may include indenting, wrapping long lines, etc..."""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _get_loop_opening_ending(self, indices):
"""Returns a tuple (open_lines, close_lines) containing lists
of codelines"""
raise NotImplementedError("This function must be implemented by "
"subclass of CodePrinter.")
def _print_Dummy(self, expr):
if expr.name.startswith('Dummy_'):
return '_' + expr.name
else:
return '%s_%d' % (expr.name, expr.dummy_index)
def _print_CodeBlock(self, expr):
return '\n'.join([self._print(i) for i in expr.args])
def _print_String(self, string):
return str(string)
def _print_QuotedString(self, arg):
return '"%s"' % arg.text
def _print_Comment(self, string):
return self._get_comment(str(string))
def _print_Assignment(self, expr):
from sympy.codegen.ast import Assignment
from sympy.functions.elementary.piecewise import Piecewise
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.tensor.indexed import IndexedBase
lhs = expr.lhs
rhs = expr.rhs
# We special case assignments that take multiple lines
if isinstance(expr.rhs, Piecewise):
# Here we modify Piecewise so each expression is now
# an Assignment, and then continue on the print.
expressions = []
conditions = []
for (e, c) in rhs.args:
expressions.append(Assignment(lhs, e))
conditions.append(c)
temp = Piecewise(*zip(expressions, conditions))
return self._print(temp)
elif isinstance(lhs, MatrixSymbol):
# Here we form an Assignment for each element in the array,
# printing each one.
lines = []
for (i, j) in self._traverse_matrix_indices(lhs):
temp = Assignment(lhs[i, j], rhs[i, j])
code0 = self._print(temp)
lines.append(code0)
return "\n".join(lines)
elif self._settings.get("contract", False) and (lhs.has(IndexedBase) or
rhs.has(IndexedBase)):
# Here we check if there is looping to be done, and if so
# print the required loops.
return self._doprint_loops(rhs, lhs)
else:
lhs_code = self._print(lhs)
rhs_code = self._print(rhs)
return self._get_statement("%s = %s" % (lhs_code, rhs_code))
def _print_AugmentedAssignment(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
return self._get_statement("{} {} {}".format(
*(self._print(arg) for arg in [lhs_code, expr.op, rhs_code])))
def _print_FunctionCall(self, expr):
return '%s(%s)' % (
expr.name,
', '.join((self._print(arg) for arg in expr.function_args)))
def _print_Variable(self, expr):
return self._print(expr.symbol)
def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if name in self.reserved_words:
if self._settings['error_on_reserved']:
msg = ('This expression includes the symbol "{}" which is a '
'reserved keyword in this language.')
raise ValueError(msg.format(name))
return name + self._settings['reserved_word_suffix']
else:
return name
def _can_print(self, name):
""" Check if function ``name`` is either a known function or has its own
printing method. Used to check if rewriting is possible."""
return name in self.known_functions or getattr(self, '_print_{}'.format(name), False)
def _print_Function(self, expr):
if expr.func.__name__ in self.known_functions:
cond_func = self.known_functions[expr.func.__name__]
if isinstance(cond_func, str):
return "%s(%s)" % (cond_func, self.stringify(expr.args, ", "))
else:
for cond, func in cond_func:
if cond(*expr.args):
break
if func is not None:
try:
return func(*[self.parenthesize(item, 0) for item in expr.args])
except TypeError:
return "%s(%s)" % (func, self.stringify(expr.args, ", "))
elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
# inlined function
return self._print(expr._imp_(*expr.args))
elif expr.func.__name__ in self._rewriteable_functions:
# Simple rewrite to supported function possible
target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
return '(' + self._print(expr.rewrite(target_f)) + ')'
if expr.is_Function and self._settings.get('allow_unknown_functions', False):
return '%s(%s)' % (self._print(expr.func), ', '.join(map(self._print, expr.args)))
else:
return self._print_not_supported(expr)
_print_Expr = _print_Function
# Don't inherit the str-printer method for Heaviside to the code printers
_print_Heaviside = None
def _print_NumberSymbol(self, expr):
if self._settings.get("inline", False):
return self._print(Float(expr.evalf(self._settings["precision"])))
else:
# A Number symbol that is not implemented here or with _printmethod
# is registered and evaluated
self._number_symbols.add((expr,
Float(expr.evalf(self._settings["precision"]))))
return str(expr)
def _print_Catalan(self, expr):
return self._print_NumberSymbol(expr)
def _print_EulerGamma(self, expr):
return self._print_NumberSymbol(expr)
def _print_GoldenRatio(self, expr):
return self._print_NumberSymbol(expr)
def _print_TribonacciConstant(self, expr):
return self._print_NumberSymbol(expr)
def _print_Exp1(self, expr):
return self._print_NumberSymbol(expr)
def _print_Pi(self, expr):
return self._print_NumberSymbol(expr)
def _print_And(self, expr):
PREC = precedence(expr)
return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
for a in sorted(expr.args, key=default_sort_key))
def _print_Or(self, expr):
PREC = precedence(expr)
return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
for a in sorted(expr.args, key=default_sort_key))
def _print_Xor(self, expr):
if self._operators.get('xor') is None:
return self._print(expr.to_nnf())
PREC = precedence(expr)
return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
for a in expr.args)
def _print_Equivalent(self, expr):
if self._operators.get('equivalent') is None:
return self._print(expr.to_nnf())
PREC = precedence(expr)
return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
for a in expr.args)
def _print_Not(self, expr):
PREC = precedence(expr)
return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
def _print_BooleanFunction(self, expr):
return self._print(expr.to_nnf())
def _print_Mul(self, expr):
prec = precedence(expr)
c, e = expr.as_coeff_Mul()
if c < 0:
expr = _keep_coeff(-c, e)
sign = "-"
else:
sign = ""
a = [] # items in the numerator
b = [] # items that are in the denominator (if any)
pow_paren = [] # Will collect all pow with more than one base element and exp = -1
if self.order not in ('old', 'none'):
args = expr.as_ordered_factors()
else:
# use make_args in case expr was something like -x -> x
args = Mul.make_args(expr)
# Gather args for numerator/denominator
for item in args:
if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
if item.exp != -1:
b.append(Pow(item.base, -item.exp, evaluate=False))
else:
if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160
pow_paren.append(item)
b.append(Pow(item.base, -item.exp))
else:
a.append(item)
a = a or [S.One]
if len(a) == 1 and sign == "-":
# Unary minus does not have a SymPy class, and hence there's no
# precedence weight associated with it, Python's unary minus has
# an operator precedence between multiplication and exponentiation,
# so we use this to compute a weight.
a_str = [self.parenthesize(a[0], 0.5*(PRECEDENCE["Pow"]+PRECEDENCE["Mul"]))]
else:
a_str = [self.parenthesize(x, prec) for x in a]
b_str = [self.parenthesize(x, prec) for x in b]
# To parenthesize Pow with exp = -1 and having more than one Symbol
for item in pow_paren:
if item.base in b:
b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]
if not b:
return sign + '*'.join(a_str)
elif len(b) == 1:
return sign + '*'.join(a_str) + "/" + b_str[0]
else:
return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)
def _print_not_supported(self, expr):
if self._settings.get('strict', False):
raise PrintMethodNotImplementedError("Unsupported by %s: %s" % (str(type(self)), str(type(expr))) + \
"\nSet the printer option 'strict' to False in order to generate partially printed code.")
try:
self._not_supported.add(expr)
except TypeError:
# not hashable
pass
return self.emptyPrinter(expr)
# The following can not be simply translated into C or Fortran
_print_Basic = _print_not_supported
_print_ComplexInfinity = _print_not_supported
_print_Derivative = _print_not_supported
_print_ExprCondPair = _print_not_supported
_print_GeometryEntity = _print_not_supported
_print_Infinity = _print_not_supported
_print_Integral = _print_not_supported
_print_Interval = _print_not_supported
_print_AccumulationBounds = _print_not_supported
_print_Limit = _print_not_supported
_print_MatrixBase = _print_not_supported
_print_DeferredVector = _print_not_supported
_print_NaN = _print_not_supported
_print_NegativeInfinity = _print_not_supported
_print_Order = _print_not_supported
_print_RootOf = _print_not_supported
_print_RootsOf = _print_not_supported
_print_RootSum = _print_not_supported
_print_Uniform = _print_not_supported
_print_Unit = _print_not_supported
_print_Wild = _print_not_supported
_print_WildFunction = _print_not_supported
_print_Relational = _print_not_supported
# Code printer functions. These are included in this file so that they can be
# imported in the top-level __init__.py without importing the sympy.codegen
# module.
def ccode(expr, assign_to=None, standard='c99', **settings):
"""Converts an expr to a string of c code
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
line-wrapping, or for expressions that generate multi-line statements.
standard : str, optional
String specifying the standard. If your compiler supports a more modern
standard you may set this to 'c99' to allow the printer to use more math
functions. [default='c89'].
precision : integer, optional
The precision for numbers such as pi [default=17].
user_functions : dict, optional
A dictionary where the keys are string representations of either
``FunctionClass`` or ``UndefinedFunction`` instances and the values
are their desired C string representations. Alternatively, the
dictionary value can be a list of tuples i.e. [(argument_test,
cfunction_string)] or [(argument_test, cfunction_formater)]. See below
for examples.
dereference : iterable, optional
An iterable of symbols that should be dereferenced in the printed code
expression. These would be values passed by address to the function.
For example, if ``dereference=[a]``, the resulting code would print
``(*a)`` instead of ``a``.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
Examples
========
>>> from sympy import ccode, symbols, Rational, sin, ceiling, Abs, Function
>>> x, tau = symbols("x, tau")
>>> expr = (2*tau)**Rational(7, 2)
>>> ccode(expr)
'8*M_SQRT2*pow(tau, 7.0/2.0)'
>>> ccode(expr, math_macros={})
'8*sqrt(2)*pow(tau, 7.0/2.0)'
>>> ccode(sin(x), assign_to="s")
's = sin(x);'
>>> from sympy.codegen.ast import real, float80
>>> ccode(expr, type_aliases={real: float80})
'8*M_SQRT2l*powl(tau, 7.0L/2.0L)'
Simple custom printing can be defined for certain types by passing a
dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
Alternatively, the dictionary value can be a list of tuples i.e.
[(argument_test, cfunction_string)].
>>> custom_functions = {
... "ceiling": "CEIL",
... "Abs": [(lambda x: not x.is_integer, "fabs"),
... (lambda x: x.is_integer, "ABS")],
... "func": "f"
... }
>>> func = Function('func')
>>> ccode(func(Abs(x) + ceiling(x)), standard='C89', user_functions=custom_functions)
'f(fabs(x) + CEIL(x))'
or if the C-function takes a subset of the original arguments:
>>> ccode(2**x + 3**x, standard='C99', user_functions={'Pow': [
... (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e),
... (lambda b, e: b != 2, 'pow')]})
'exp2(x) + pow(3, x)'
``Piecewise`` expressions are converted into conditionals. If an
``assign_to`` variable is provided an if statement is created, otherwise
the ternary operator is used. Note that if the ``Piecewise`` lacks a
default term, represented by ``(expr, True)`` then an error will be thrown.
This is to prevent generating an expression that may not evaluate to
anything.
>>> from sympy import Piecewise
>>> expr = Piecewise((x + 1, x > 0), (x, True))
>>> print(ccode(expr, tau, standard='C89'))
if (x > 0) {
tau = x + 1;
}
else {
tau = x;
}
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> ccode(e.rhs, assign_to=e.lhs, contract=False, standard='C89')
'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
must be provided to ``assign_to``. Note that any expression that can be
generated normally can also exist inside a Matrix:
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
>>> A = MatrixSymbol('A', 3, 1)
>>> print(ccode(mat, A, standard='C89'))
A[0] = pow(x, 2);
if (x > 0) {
A[1] = x + 1;
}
else {
A[1] = x;
}
A[2] = sin(x);
"""
from sympy.printing.c import c_code_printers
return c_code_printers[standard.lower()](settings).doprint(expr, assign_to)
def print_ccode(expr, **settings):
"""Prints C representation of the given expression."""
print(ccode(expr, **settings))
def fcode(expr, assign_to=None, **settings):
"""Converts an expr to a string of fortran code
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
line-wrapping, or for expressions that generate multi-line statements.
precision : integer, optional
DEPRECATED. Use type_mappings instead. The precision for numbers such
as pi [default=17].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, cfunction_string)]. See below
for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
source_format : optional
The source format can be either 'fixed' or 'free'. [default='fixed']
standard : integer, optional
The Fortran standard to be followed. This is specified as an integer.
Acceptable standards are 66, 77, 90, 95, 2003, and 2008. Default is 77.
Note that currently the only distinction internally is between
standards before 95, and those 95 and after. This may change later as
more features are added.
name_mangling : bool, optional
If True, then the variables that would become identical in
case-insensitive Fortran are mangled by appending different number
of ``_`` at the end. If False, SymPy Will not interfere with naming of
variables. [default=True]
Examples
========
>>> from sympy import fcode, symbols, Rational, sin, ceiling, floor
>>> x, tau = symbols("x, tau")
>>> fcode((2*tau)**Rational(7, 2))
' 8*sqrt(2.0d0)*tau**(7.0d0/2.0d0)'
>>> fcode(sin(x), assign_to="s")
' s = sin(x)'
Custom printing can be defined for certain types by passing a dictionary of
"type" : "function" to the ``user_functions`` kwarg. Alternatively, the
dictionary value can be a list of tuples i.e. [(argument_test,
cfunction_string)].
>>> custom_functions = {
... "ceiling": "CEIL",
... "floor": [(lambda x: not x.is_integer, "FLOOR1"),
... (lambda x: x.is_integer, "FLOOR2")]
... }
>>> fcode(floor(x) + ceiling(x), user_functions=custom_functions)
' CEIL(x) + FLOOR1(x)'
``Piecewise`` expressions are converted into conditionals. If an
``assign_to`` variable is provided an if statement is created, otherwise
the ternary operator is used. Note that if the ``Piecewise`` lacks a
default term, represented by ``(expr, True)`` then an error will be thrown.
This is to prevent generating an expression that may not evaluate to
anything.
>>> from sympy import Piecewise
>>> expr = Piecewise((x + 1, x > 0), (x, True))
>>> print(fcode(expr, tau))
if (x > 0) then
tau = x + 1
else
tau = x
end if
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> fcode(e.rhs, assign_to=e.lhs, contract=False)
' Dy(i) = (y(i + 1) - y(i))/(t(i + 1) - t(i))'
Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
must be provided to ``assign_to``. Note that any expression that can be
generated normally can also exist inside a Matrix:
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
>>> A = MatrixSymbol('A', 3, 1)
>>> print(fcode(mat, A))
A(1, 1) = x**2
if (x > 0) then
A(2, 1) = x + 1
else
A(2, 1) = x
end if
A(3, 1) = sin(x)
"""
from sympy.printing.fortran import FCodePrinter
return FCodePrinter(settings).doprint(expr, assign_to)
def print_fcode(expr, **settings):
"""Prints the Fortran representation of the given expression.
See fcode for the meaning of the optional arguments.
"""
print(fcode(expr, **settings))
def cxxcode(expr, assign_to=None, standard='c++11', **settings):
""" C++ equivalent of :func:`~.ccode`. """
from sympy.printing.cxx import cxx_code_printers
return cxx_code_printers[standard.lower()](settings).doprint(expr, assign_to)

View File

@ -0,0 +1,88 @@
"""
A few practical conventions common to all printers.
"""
import re
from collections.abc import Iterable
from sympy.core.function import Derivative
_name_with_digits_p = re.compile(r'^([^\W\d_]+)(\d+)$', re.U)
def split_super_sub(text):
"""Split a symbol name into a name, superscripts and subscripts
The first part of the symbol name is considered to be its actual
'name', followed by super- and subscripts. Each superscript is
preceded with a "^" character or by "__". Each subscript is preceded
by a "_" character. The three return values are the actual name, a
list with superscripts and a list with subscripts.
Examples
========
>>> from sympy.printing.conventions import split_super_sub
>>> split_super_sub('a_x^1')
('a', ['1'], ['x'])
>>> split_super_sub('var_sub1__sup_sub2')
('var', ['sup'], ['sub1', 'sub2'])
"""
if not text:
return text, [], []
pos = 0
name = None
supers = []
subs = []
while pos < len(text):
start = pos + 1
if text[pos:pos + 2] == "__":
start += 1
pos_hat = text.find("^", start)
if pos_hat < 0:
pos_hat = len(text)
pos_usc = text.find("_", start)
if pos_usc < 0:
pos_usc = len(text)
pos_next = min(pos_hat, pos_usc)
part = text[pos:pos_next]
pos = pos_next
if name is None:
name = part
elif part.startswith("^"):
supers.append(part[1:])
elif part.startswith("__"):
supers.append(part[2:])
elif part.startswith("_"):
subs.append(part[1:])
else:
raise RuntimeError("This should never happen.")
# Make a little exception when a name ends with digits, i.e. treat them
# as a subscript too.
m = _name_with_digits_p.match(name)
if m:
name, sub = m.groups()
subs.insert(0, sub)
return name, supers, subs
def requires_partial(expr):
"""Return whether a partial derivative symbol is required for printing
This requires checking how many free variables there are,
filtering out the ones that are integers. Some expressions do not have
free variables. In that case, check its variable list explicitly to
get the context of the expression.
"""
if isinstance(expr, Derivative):
return requires_partial(expr.expr)
if not isinstance(expr.free_symbols, Iterable):
return len(set(expr.variables)) > 1
return sum(not s.is_integer for s in expr.free_symbols) > 1

View File

@ -0,0 +1,181 @@
"""
C++ code printer
"""
from itertools import chain
from sympy.codegen.ast import Type, none
from .codeprinter import requires
from .c import C89CodePrinter, C99CodePrinter
# These are defined in the other file so we can avoid importing sympy.codegen
# from the top-level 'import sympy'. Export them here as well.
from sympy.printing.codeprinter import cxxcode # noqa:F401
# from https://en.cppreference.com/w/cpp/keyword
reserved = {
'C++98': [
'and', 'and_eq', 'asm', 'auto', 'bitand', 'bitor', 'bool', 'break',
'case', 'catch,', 'char', 'class', 'compl', 'const', 'const_cast',
'continue', 'default', 'delete', 'do', 'double', 'dynamic_cast',
'else', 'enum', 'explicit', 'export', 'extern', 'false', 'float',
'for', 'friend', 'goto', 'if', 'inline', 'int', 'long', 'mutable',
'namespace', 'new', 'not', 'not_eq', 'operator', 'or', 'or_eq',
'private', 'protected', 'public', 'register', 'reinterpret_cast',
'return', 'short', 'signed', 'sizeof', 'static', 'static_cast',
'struct', 'switch', 'template', 'this', 'throw', 'true', 'try',
'typedef', 'typeid', 'typename', 'union', 'unsigned', 'using',
'virtual', 'void', 'volatile', 'wchar_t', 'while', 'xor', 'xor_eq'
]
}
reserved['C++11'] = reserved['C++98'][:] + [
'alignas', 'alignof', 'char16_t', 'char32_t', 'constexpr', 'decltype',
'noexcept', 'nullptr', 'static_assert', 'thread_local'
]
reserved['C++17'] = reserved['C++11'][:]
reserved['C++17'].remove('register')
# TM TS: atomic_cancel, atomic_commit, atomic_noexcept, synchronized
# concepts TS: concept, requires
# module TS: import, module
_math_functions = {
'C++98': {
'Mod': 'fmod',
'ceiling': 'ceil',
},
'C++11': {
'gamma': 'tgamma',
},
'C++17': {
'beta': 'beta',
'Ei': 'expint',
'zeta': 'riemann_zeta',
}
}
# from https://en.cppreference.com/w/cpp/header/cmath
for k in ('Abs', 'exp', 'log', 'log10', 'sqrt', 'sin', 'cos', 'tan', # 'Pow'
'asin', 'acos', 'atan', 'atan2', 'sinh', 'cosh', 'tanh', 'floor'):
_math_functions['C++98'][k] = k.lower()
for k in ('asinh', 'acosh', 'atanh', 'erf', 'erfc'):
_math_functions['C++11'][k] = k.lower()
def _attach_print_method(cls, sympy_name, func_name):
meth_name = '_print_%s' % sympy_name
if hasattr(cls, meth_name):
raise ValueError("Edit method (or subclass) instead of overwriting.")
def _print_method(self, expr):
return '{}{}({})'.format(self._ns, func_name, ', '.join(map(self._print, expr.args)))
_print_method.__doc__ = "Prints code for %s" % k
setattr(cls, meth_name, _print_method)
def _attach_print_methods(cls, cont):
for sympy_name, cxx_name in cont[cls.standard].items():
_attach_print_method(cls, sympy_name, cxx_name)
class _CXXCodePrinterBase:
printmethod = "_cxxcode"
language = 'C++'
_ns = 'std::' # namespace
def __init__(self, settings=None):
super().__init__(settings or {})
@requires(headers={'algorithm'})
def _print_Max(self, expr):
from sympy.functions.elementary.miscellaneous import Max
if len(expr.args) == 1:
return self._print(expr.args[0])
return "%smax(%s, %s)" % (self._ns, self._print(expr.args[0]),
self._print(Max(*expr.args[1:])))
@requires(headers={'algorithm'})
def _print_Min(self, expr):
from sympy.functions.elementary.miscellaneous import Min
if len(expr.args) == 1:
return self._print(expr.args[0])
return "%smin(%s, %s)" % (self._ns, self._print(expr.args[0]),
self._print(Min(*expr.args[1:])))
def _print_using(self, expr):
if expr.alias == none:
return 'using %s' % expr.type
else:
raise ValueError("C++98 does not support type aliases")
def _print_Raise(self, rs):
arg, = rs.args
return 'throw %s' % self._print(arg)
@requires(headers={'stdexcept'})
def _print_RuntimeError_(self, re):
message, = re.args
return "%sruntime_error(%s)" % (self._ns, self._print(message))
class CXX98CodePrinter(_CXXCodePrinterBase, C89CodePrinter):
standard = 'C++98'
reserved_words = set(reserved['C++98'])
# _attach_print_methods(CXX98CodePrinter, _math_functions)
class CXX11CodePrinter(_CXXCodePrinterBase, C99CodePrinter):
standard = 'C++11'
reserved_words = set(reserved['C++11'])
type_mappings = dict(chain(
CXX98CodePrinter.type_mappings.items(),
{
Type('int8'): ('int8_t', {'cstdint'}),
Type('int16'): ('int16_t', {'cstdint'}),
Type('int32'): ('int32_t', {'cstdint'}),
Type('int64'): ('int64_t', {'cstdint'}),
Type('uint8'): ('uint8_t', {'cstdint'}),
Type('uint16'): ('uint16_t', {'cstdint'}),
Type('uint32'): ('uint32_t', {'cstdint'}),
Type('uint64'): ('uint64_t', {'cstdint'}),
Type('complex64'): ('std::complex<float>', {'complex'}),
Type('complex128'): ('std::complex<double>', {'complex'}),
Type('bool'): ('bool', None),
}.items()
))
def _print_using(self, expr):
if expr.alias == none:
return super()._print_using(expr)
else:
return 'using %(alias)s = %(type)s' % expr.kwargs(apply=self._print)
# _attach_print_methods(CXX11CodePrinter, _math_functions)
class CXX17CodePrinter(_CXXCodePrinterBase, C99CodePrinter):
standard = 'C++17'
reserved_words = set(reserved['C++17'])
_kf = dict(C99CodePrinter._kf, **_math_functions['C++17'])
def _print_beta(self, expr):
return self._print_math_func(expr)
def _print_Ei(self, expr):
return self._print_math_func(expr)
def _print_zeta(self, expr):
return self._print_math_func(expr)
# _attach_print_methods(CXX17CodePrinter, _math_functions)
cxx_code_printers = {
'c++98': CXX98CodePrinter,
'c++11': CXX11CodePrinter,
'c++17': CXX17CodePrinter
}

View File

@ -0,0 +1,5 @@
from sympy.core._print_helpers import Printable
# alias for compatibility
Printable.__module__ = __name__
DefaultPrinting = Printable

View File

@ -0,0 +1,294 @@
from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.symbol import Symbol
from sympy.core.numbers import Integer, Rational, Float
from sympy.printing.repr import srepr
__all__ = ['dotprint']
default_styles = (
(Basic, {'color': 'blue', 'shape': 'ellipse'}),
(Expr, {'color': 'black'})
)
slotClasses = (Symbol, Integer, Rational, Float)
def purestr(x, with_args=False):
"""A string that follows ```obj = type(obj)(*obj.args)``` exactly.
Parameters
==========
with_args : boolean, optional
If ``True``, there will be a second argument for the return
value, which is a tuple containing ``purestr`` applied to each
of the subnodes.
If ``False``, there will not be a second argument for the
return.
Default is ``False``
Examples
========
>>> from sympy import Float, Symbol, MatrixSymbol
>>> from sympy import Integer # noqa: F401
>>> from sympy.core.symbol import Str # noqa: F401
>>> from sympy.printing.dot import purestr
Applying ``purestr`` for basic symbolic object:
>>> code = purestr(Symbol('x'))
>>> code
"Symbol('x')"
>>> eval(code) == Symbol('x')
True
For basic numeric object:
>>> purestr(Float(2))
"Float('2.0', precision=53)"
For matrix symbol:
>>> code = purestr(MatrixSymbol('x', 2, 2))
>>> code
"MatrixSymbol(Str('x'), Integer(2), Integer(2))"
>>> eval(code) == MatrixSymbol('x', 2, 2)
True
With ``with_args=True``:
>>> purestr(Float(2), with_args=True)
("Float('2.0', precision=53)", ())
>>> purestr(MatrixSymbol('x', 2, 2), with_args=True)
("MatrixSymbol(Str('x'), Integer(2), Integer(2))",
("Str('x')", 'Integer(2)', 'Integer(2)'))
"""
sargs = ()
if not isinstance(x, Basic):
rv = str(x)
elif not x.args:
rv = srepr(x)
else:
args = x.args
sargs = tuple(map(purestr, args))
rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs))
if with_args:
rv = rv, sargs
return rv
def styleof(expr, styles=default_styles):
""" Merge style dictionaries in order
Examples
========
>>> from sympy import Symbol, Basic, Expr, S
>>> from sympy.printing.dot import styleof
>>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
... (Expr, {'color': 'black'})]
>>> styleof(Basic(S(1)), styles)
{'color': 'blue', 'shape': 'ellipse'}
>>> x = Symbol('x')
>>> styleof(x + 1, styles) # this is an Expr
{'color': 'black', 'shape': 'ellipse'}
"""
style = {}
for typ, sty in styles:
if isinstance(expr, typ):
style.update(sty)
return style
def attrprint(d, delimiter=', '):
""" Print a dictionary of attributes
Examples
========
>>> from sympy.printing.dot import attrprint
>>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
"color"="blue", "shape"="ellipse"
"""
return delimiter.join('"%s"="%s"'%item for item in sorted(d.items()))
def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
""" String defining a node
Examples
========
>>> from sympy.printing.dot import dotnode
>>> from sympy.abc import x
>>> print(dotnode(x))
"Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"];
"""
style = styleof(expr, styles)
if isinstance(expr, Basic) and not expr.is_Atom:
label = str(expr.__class__.__name__)
else:
label = labelfunc(expr)
style['label'] = label
expr_str = purestr(expr)
if repeat:
expr_str += '_%s' % str(pos)
return '"%s" [%s];' % (expr_str, attrprint(style))
def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
""" List of strings for all expr->expr.arg pairs
See the docstring of dotprint for explanations of the options.
Examples
========
>>> from sympy.printing.dot import dotedges
>>> from sympy.abc import x
>>> for e in dotedges(x+2):
... print(e)
"Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
"Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
"""
if atom(expr):
return []
else:
expr_str, arg_strs = purestr(expr, with_args=True)
if repeat:
expr_str += '_%s' % str(pos)
arg_strs = ['%s_%s' % (a, str(pos + (i,)))
for i, a in enumerate(arg_strs)]
return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs]
template = \
"""digraph{
# Graph style
%(graphstyle)s
#########
# Nodes #
#########
%(nodes)s
#########
# Edges #
#########
%(edges)s
}"""
_graphstyle = {'rankdir': 'TD', 'ordering': 'out'}
def dotprint(expr,
styles=default_styles, atom=lambda x: not isinstance(x, Basic),
maxdepth=None, repeat=True, labelfunc=str, **kwargs):
"""DOT description of a SymPy expression tree
Parameters
==========
styles : list of lists composed of (Class, mapping), optional
Styles for different classes.
The default is
.. code-block:: python
(
(Basic, {'color': 'blue', 'shape': 'ellipse'}),
(Expr, {'color': 'black'})
)
atom : function, optional
Function used to determine if an arg is an atom.
A good choice is ``lambda x: not x.args``.
The default is ``lambda x: not isinstance(x, Basic)``.
maxdepth : integer, optional
The maximum depth.
The default is ``None``, meaning no limit.
repeat : boolean, optional
Whether to use different nodes for common subexpressions.
The default is ``True``.
For example, for ``x + x*y`` with ``repeat=True``, it will have
two nodes for ``x``; with ``repeat=False``, it will have one
node.
.. warning::
Even if a node appears twice in the same object like ``x`` in
``Pow(x, x)``, it will still only appear once.
Hence, with ``repeat=False``, the number of arrows out of an
object might not equal the number of args it has.
labelfunc : function, optional
A function to create a label for a given leaf node.
The default is ``str``.
Another good option is ``srepr``.
For example with ``str``, the leaf nodes of ``x + 1`` are labeled,
``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')``
and ``Integer(1)``.
**kwargs : optional
Additional keyword arguments are included as styles for the graph.
Examples
========
>>> from sympy import dotprint
>>> from sympy.abc import x
>>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
digraph{
<BLANKLINE>
# Graph style
"ordering"="out"
"rankdir"="TD"
<BLANKLINE>
#########
# Nodes #
#########
<BLANKLINE>
"Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
"Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
"Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
<BLANKLINE>
#########
# Edges #
#########
<BLANKLINE>
"Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
"Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
}
"""
# repeat works by adding a signature tuple to the end of each node for its
# position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
# Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].
graphstyle = _graphstyle.copy()
graphstyle.update(kwargs)
nodes = []
edges = []
def traverse(e, depth, pos=()):
nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
if maxdepth and depth >= maxdepth:
return
edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
[traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)]
traverse(expr, 0)
return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'),
'nodes': '\n'.join(nodes),
'edges': '\n'.join(edges)}

View File

@ -0,0 +1,782 @@
"""
Fortran code printer
The FCodePrinter converts single SymPy expressions into single Fortran
expressions, using the functions defined in the Fortran 77 standard where
possible. Some useful pointers to Fortran can be found on wikipedia:
https://en.wikipedia.org/wiki/Fortran
Most of the code below is based on the "Professional Programmer\'s Guide to
Fortran77" by Clive G. Page:
https://www.star.le.ac.uk/~cgp/prof77.html
Fortran is a case-insensitive language. This might cause trouble because
SymPy is case sensitive. So, fcode adds underscores to variable names when
it is necessary to make them different for Fortran.
"""
from __future__ import annotations
from typing import Any
from collections import defaultdict
from itertools import chain
import string
from sympy.codegen.ast import (
Assignment, Declaration, Pointer, value_const,
float32, float64, float80, complex64, complex128, int8, int16, int32,
int64, intc, real, integer, bool_, complex_, none, stderr, stdout
)
from sympy.codegen.fnodes import (
allocatable, isign, dsign, cmplx, merge, literal_dp, elemental, pure,
intent_in, intent_out, intent_inout
)
from sympy.core import S, Add, N, Float, Symbol
from sympy.core.function import Function
from sympy.core.numbers import equal_valued
from sympy.core.relational import Eq
from sympy.sets import Range
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence, PRECEDENCE
from sympy.printing.printer import printer_context
# These are defined in the other file so we can avoid importing sympy.codegen
# from the top-level 'import sympy'. Export them here as well.
from sympy.printing.codeprinter import fcode, print_fcode # noqa:F401
known_functions = {
"sin": "sin",
"cos": "cos",
"tan": "tan",
"asin": "asin",
"acos": "acos",
"atan": "atan",
"atan2": "atan2",
"sinh": "sinh",
"cosh": "cosh",
"tanh": "tanh",
"log": "log",
"exp": "exp",
"erf": "erf",
"Abs": "abs",
"conjugate": "conjg",
"Max": "max",
"Min": "min",
}
class FCodePrinter(CodePrinter):
"""A printer to convert SymPy expressions to strings of Fortran code"""
printmethod = "_fcode"
language = "Fortran"
type_aliases = {
integer: int32,
real: float64,
complex_: complex128,
}
type_mappings = {
intc: 'integer(c_int)',
float32: 'real*4', # real(kind(0.e0))
float64: 'real*8', # real(kind(0.d0))
float80: 'real*10', # real(kind(????))
complex64: 'complex*8',
complex128: 'complex*16',
int8: 'integer*1',
int16: 'integer*2',
int32: 'integer*4',
int64: 'integer*8',
bool_: 'logical'
}
type_modules = {
intc: {'iso_c_binding': 'c_int'}
}
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 17,
'user_functions': {},
'source_format': 'fixed',
'contract': True,
'standard': 77,
'name_mangling': True,
})
_operators = {
'and': '.and.',
'or': '.or.',
'xor': '.neqv.',
'equivalent': '.eqv.',
'not': '.not. ',
}
_relationals = {
'!=': '/=',
}
def __init__(self, settings=None):
if not settings:
settings = {}
self.mangled_symbols = {} # Dict showing mapping of all words
self.used_name = []
self.type_aliases = dict(chain(self.type_aliases.items(),
settings.pop('type_aliases', {}).items()))
self.type_mappings = dict(chain(self.type_mappings.items(),
settings.pop('type_mappings', {}).items()))
super().__init__(settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
# leading columns depend on fixed or free format
standards = {66, 77, 90, 95, 2003, 2008}
if self._settings['standard'] not in standards:
raise ValueError("Unknown Fortran standard: %s" % self._settings[
'standard'])
self.module_uses = defaultdict(set) # e.g.: use iso_c_binding, only: c_int
@property
def _lead(self):
if self._settings['source_format'] == 'fixed':
return {'code': " ", 'cont': " @ ", 'comment': "C "}
elif self._settings['source_format'] == 'free':
return {'code': "", 'cont': " ", 'comment': "! "}
else:
raise ValueError("Unknown source format: %s" % self._settings['source_format'])
def _print_Symbol(self, expr):
if self._settings['name_mangling'] == True:
if expr not in self.mangled_symbols:
name = expr.name
while name.lower() in self.used_name:
name += '_'
self.used_name.append(name.lower())
if name == expr.name:
self.mangled_symbols[expr] = expr
else:
self.mangled_symbols[expr] = Symbol(name)
expr = expr.xreplace(self.mangled_symbols)
name = super()._print_Symbol(expr)
return name
def _rate_index_position(self, p):
return -p*5
def _get_statement(self, codestring):
return codestring
def _get_comment(self, text):
return "! {}".format(text)
def _declare_number_const(self, name, value):
return "parameter ({} = {})".format(name, self._print(value))
def _print_NumberSymbol(self, expr):
# A Number symbol that is not implemented here or with _printmethod
# is registered and evaluated
self._number_symbols.add((expr, Float(expr.evalf(self._settings['precision']))))
return str(expr)
def _format_code(self, lines):
return self._wrap_fortran(self.indent_code(lines))
def _traverse_matrix_indices(self, mat):
rows, cols = mat.shape
return ((i, j) for j in range(cols) for i in range(rows))
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
for i in indices:
# fortran arrays start at 1 and end at dimension
var, start, stop = map(self._print,
[i.label, i.lower + 1, i.upper + 1])
open_lines.append("do %s = %s, %s" % (var, start, stop))
close_lines.append("end do")
return open_lines, close_lines
def _print_sign(self, expr):
from sympy.functions.elementary.complexes import Abs
arg, = expr.args
if arg.is_integer:
new_expr = merge(0, isign(1, arg), Eq(arg, 0))
elif (arg.is_complex or arg.is_infinite):
new_expr = merge(cmplx(literal_dp(0), literal_dp(0)), arg/Abs(arg), Eq(Abs(arg), literal_dp(0)))
else:
new_expr = merge(literal_dp(0), dsign(literal_dp(1), arg), Eq(arg, literal_dp(0)))
return self._print(new_expr)
def _print_Piecewise(self, expr):
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
if expr.has(Assignment):
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s) then" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else")
else:
lines.append("else if (%s) then" % self._print(c))
lines.append(self._print(e))
lines.append("end if")
return "\n".join(lines)
elif self._settings["standard"] >= 95:
# Only supported in F95 and newer:
# The piecewise was used in an expression, need to do inline
# operators. This has the downside that inline operators will
# not work for statements that span multiple lines (Matrix or
# Indexed expressions).
pattern = "merge({T}, {F}, {COND})"
code = self._print(expr.args[-1].expr)
terms = list(expr.args[:-1])
while terms:
e, c = terms.pop()
expr = self._print(e)
cond = self._print(c)
code = pattern.format(T=expr, F=code, COND=cond)
return code
else:
# `merge` is not supported prior to F95
raise NotImplementedError("Using Piecewise as an expression using "
"inline operators is not supported in "
"standards earlier than Fortran95.")
def _print_MatrixElement(self, expr):
return "{}({}, {})".format(self.parenthesize(expr.parent,
PRECEDENCE["Atom"], strict=True), expr.i + 1, expr.j + 1)
def _print_Add(self, expr):
# purpose: print complex numbers nicely in Fortran.
# collect the purely real and purely imaginary parts:
pure_real = []
pure_imaginary = []
mixed = []
for arg in expr.args:
if arg.is_number and arg.is_real:
pure_real.append(arg)
elif arg.is_number and arg.is_imaginary:
pure_imaginary.append(arg)
else:
mixed.append(arg)
if pure_imaginary:
if mixed:
PREC = precedence(expr)
term = Add(*mixed)
t = self._print(term)
if t.startswith('-'):
sign = "-"
t = t[1:]
else:
sign = "+"
if precedence(term) < PREC:
t = "(%s)" % t
return "cmplx(%s,%s) %s %s" % (
self._print(Add(*pure_real)),
self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
sign, t,
)
else:
return "cmplx(%s,%s)" % (
self._print(Add(*pure_real)),
self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
)
else:
return CodePrinter._print_Add(self, expr)
def _print_Function(self, expr):
# All constant function args are evaluated as floats
prec = self._settings['precision']
args = [N(a, prec) for a in expr.args]
eval_expr = expr.func(*args)
if not isinstance(eval_expr, Function):
return self._print(eval_expr)
else:
return CodePrinter._print_Function(self, expr.func(*args))
def _print_Mod(self, expr):
# NOTE : Fortran has the functions mod() and modulo(). modulo() behaves
# the same wrt to the sign of the arguments as Python and SymPy's
# modulus computations (% and Mod()) but is not available in Fortran 66
# or Fortran 77, thus we raise an error.
if self._settings['standard'] in [66, 77]:
msg = ("Python % operator and SymPy's Mod() function are not "
"supported by Fortran 66 or 77 standards.")
raise NotImplementedError(msg)
else:
x, y = expr.args
return " modulo({}, {})".format(self._print(x), self._print(y))
def _print_ImaginaryUnit(self, expr):
# purpose: print complex numbers nicely in Fortran.
return "cmplx(0,1)"
def _print_int(self, expr):
return str(expr)
def _print_Mul(self, expr):
# purpose: print complex numbers nicely in Fortran.
if expr.is_number and expr.is_imaginary:
return "cmplx(0,%s)" % (
self._print(-S.ImaginaryUnit*expr)
)
else:
return CodePrinter._print_Mul(self, expr)
def _print_Pow(self, expr):
PREC = precedence(expr)
if equal_valued(expr.exp, -1):
return '%s/%s' % (
self._print(literal_dp(1)),
self.parenthesize(expr.base, PREC)
)
elif equal_valued(expr.exp, 0.5):
if expr.base.is_integer:
# Fortran intrinsic sqrt() does not accept integer argument
if expr.base.is_Number:
return 'sqrt(%s.0d0)' % self._print(expr.base)
else:
return 'sqrt(dble(%s))' % self._print(expr.base)
else:
return 'sqrt(%s)' % self._print(expr.base)
else:
return CodePrinter._print_Pow(self, expr)
def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)
return "%d.0d0/%d.0d0" % (p, q)
def _print_Float(self, expr):
printed = CodePrinter._print_Float(self, expr)
e = printed.find('e')
if e > -1:
return "%sd%s" % (printed[:e], printed[e + 1:])
return "%sd0" % printed
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
op = op if op not in self._relationals else self._relationals[op]
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_Indexed(self, expr):
inds = [ self._print(i) for i in expr.indices ]
return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds))
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_AugmentedAssignment(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
return self._get_statement("{0} = {0} {1} {2}".format(
self._print(lhs_code), self._print(expr.binop), self._print(rhs_code)))
def _print_sum_(self, sm):
params = self._print(sm.array)
if sm.dim != None: # Must use '!= None', cannot use 'is not None'
params += ', ' + self._print(sm.dim)
if sm.mask != None: # Must use '!= None', cannot use 'is not None'
params += ', mask=' + self._print(sm.mask)
return '%s(%s)' % (sm.__class__.__name__.rstrip('_'), params)
def _print_product_(self, prod):
return self._print_sum_(prod)
def _print_Do(self, do):
excl = ['concurrent']
if do.step == 1:
excl.append('step')
step = ''
else:
step = ', {step}'
return (
'do {concurrent}{counter} = {first}, {last}'+step+'\n'
'{body}\n'
'end do\n'
).format(
concurrent='concurrent ' if do.concurrent else '',
**do.kwargs(apply=lambda arg: self._print(arg), exclude=excl)
)
def _print_ImpliedDoLoop(self, idl):
step = '' if idl.step == 1 else ', {step}'
return ('({expr}, {counter} = {first}, {last}'+step+')').format(
**idl.kwargs(apply=lambda arg: self._print(arg))
)
def _print_For(self, expr):
target = self._print(expr.target)
if isinstance(expr.iterable, Range):
start, stop, step = expr.iterable.args
else:
raise NotImplementedError("Only iterable currently supported is Range")
body = self._print(expr.body)
return ('do {target} = {start}, {stop}, {step}\n'
'{body}\n'
'end do').format(target=target, start=start, stop=stop - 1,
step=step, body=body)
def _print_Type(self, type_):
type_ = self.type_aliases.get(type_, type_)
type_str = self.type_mappings.get(type_, type_.name)
module_uses = self.type_modules.get(type_)
if module_uses:
for k, v in module_uses:
self.module_uses[k].add(v)
return type_str
def _print_Element(self, elem):
return '{symbol}({idxs})'.format(
symbol=self._print(elem.symbol),
idxs=', '.join((self._print(arg) for arg in elem.indices))
)
def _print_Extent(self, ext):
return str(ext)
def _print_Declaration(self, expr):
var = expr.variable
val = var.value
dim = var.attr_params('dimension')
intents = [intent in var.attrs for intent in (intent_in, intent_out, intent_inout)]
if intents.count(True) == 0:
intent = ''
elif intents.count(True) == 1:
intent = ', intent(%s)' % ['in', 'out', 'inout'][intents.index(True)]
else:
raise ValueError("Multiple intents specified for %s" % self)
if isinstance(var, Pointer):
raise NotImplementedError("Pointers are not available by default in Fortran.")
if self._settings["standard"] >= 90:
result = '{t}{vc}{dim}{intent}{alloc} :: {s}'.format(
t=self._print(var.type),
vc=', parameter' if value_const in var.attrs else '',
dim=', dimension(%s)' % ', '.join((self._print(arg) for arg in dim)) if dim else '',
intent=intent,
alloc=', allocatable' if allocatable in var.attrs else '',
s=self._print(var.symbol)
)
if val != None: # Must be "!= None", cannot be "is not None"
result += ' = %s' % self._print(val)
else:
if value_const in var.attrs or val:
raise NotImplementedError("F77 init./parameter statem. req. multiple lines.")
result = ' '.join((self._print(arg) for arg in [var.type, var.symbol]))
return result
def _print_Infinity(self, expr):
return '(huge(%s) + 1)' % self._print(literal_dp(0))
def _print_While(self, expr):
return 'do while ({condition})\n{body}\nend do'.format(**expr.kwargs(
apply=lambda arg: self._print(arg)))
def _print_BooleanTrue(self, expr):
return '.true.'
def _print_BooleanFalse(self, expr):
return '.false.'
def _pad_leading_columns(self, lines):
result = []
for line in lines:
if line.startswith('!'):
result.append(self._lead['comment'] + line[1:].lstrip())
else:
result.append(self._lead['code'] + line)
return result
def _wrap_fortran(self, lines):
"""Wrap long Fortran lines
Argument:
lines -- a list of lines (without \\n character)
A comment line is split at white space. Code lines are split with a more
complex rule to give nice results.
"""
# routine to find split point in a code line
my_alnum = set("_+-." + string.digits + string.ascii_letters)
my_white = set(" \t()")
def split_pos_code(line, endpos):
if len(line) <= endpos:
return len(line)
pos = endpos
split = lambda pos: \
(line[pos] in my_alnum and line[pos - 1] not in my_alnum) or \
(line[pos] not in my_alnum and line[pos - 1] in my_alnum) or \
(line[pos] in my_white and line[pos - 1] not in my_white) or \
(line[pos] not in my_white and line[pos - 1] in my_white)
while not split(pos):
pos -= 1
if pos == 0:
return endpos
return pos
# split line by line and add the split lines to result
result = []
if self._settings['source_format'] == 'free':
trailing = ' &'
else:
trailing = ''
for line in lines:
if line.startswith(self._lead['comment']):
# comment line
if len(line) > 72:
pos = line.rfind(" ", 6, 72)
if pos == -1:
pos = 72
hunk = line[:pos]
line = line[pos:].lstrip()
result.append(hunk)
while line:
pos = line.rfind(" ", 0, 66)
if pos == -1 or len(line) < 66:
pos = 66
hunk = line[:pos]
line = line[pos:].lstrip()
result.append("%s%s" % (self._lead['comment'], hunk))
else:
result.append(line)
elif line.startswith(self._lead['code']):
# code line
pos = split_pos_code(line, 72)
hunk = line[:pos].rstrip()
line = line[pos:].lstrip()
if line:
hunk += trailing
result.append(hunk)
while line:
pos = split_pos_code(line, 65)
hunk = line[:pos].rstrip()
line = line[pos:].lstrip()
if line:
hunk += trailing
result.append("%s%s" % (self._lead['cont'], hunk))
else:
result.append(line)
return result
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
free = self._settings['source_format'] == 'free'
code = [ line.lstrip(' \t') for line in code ]
inc_keyword = ('do ', 'if(', 'if ', 'do\n', 'else', 'program', 'interface')
dec_keyword = ('end do', 'enddo', 'end if', 'endif', 'else', 'end program', 'end interface')
increase = [ int(any(map(line.startswith, inc_keyword)))
for line in code ]
decrease = [ int(any(map(line.startswith, dec_keyword)))
for line in code ]
continuation = [ int(any(map(line.endswith, ['&', '&\n'])))
for line in code ]
level = 0
cont_padding = 0
tabwidth = 3
new_code = []
for i, line in enumerate(code):
if line in ('', '\n'):
new_code.append(line)
continue
level -= decrease[i]
if free:
padding = " "*(level*tabwidth + cont_padding)
else:
padding = " "*level*tabwidth
line = "%s%s" % (padding, line)
if not free:
line = self._pad_leading_columns([line])[0]
new_code.append(line)
if continuation[i]:
cont_padding = 2*tabwidth
else:
cont_padding = 0
level += increase[i]
if not free:
return self._wrap_fortran(new_code)
return new_code
def _print_GoTo(self, goto):
if goto.expr: # computed goto
return "go to ({labels}), {expr}".format(
labels=', '.join((self._print(arg) for arg in goto.labels)),
expr=self._print(goto.expr)
)
else:
lbl, = goto.labels
return "go to %s" % self._print(lbl)
def _print_Program(self, prog):
return (
"program {name}\n"
"{body}\n"
"end program\n"
).format(**prog.kwargs(apply=lambda arg: self._print(arg)))
def _print_Module(self, mod):
return (
"module {name}\n"
"{declarations}\n"
"\ncontains\n\n"
"{definitions}\n"
"end module\n"
).format(**mod.kwargs(apply=lambda arg: self._print(arg)))
def _print_Stream(self, strm):
if strm.name == 'stdout' and self._settings["standard"] >= 2003:
self.module_uses['iso_c_binding'].add('stdint=>input_unit')
return 'input_unit'
elif strm.name == 'stderr' and self._settings["standard"] >= 2003:
self.module_uses['iso_c_binding'].add('stdint=>error_unit')
return 'error_unit'
else:
if strm.name == 'stdout':
return '*'
else:
return strm.name
def _print_Print(self, ps):
if ps.format_string == none: # Must be '!= None', cannot be 'is not None'
template = "print {fmt}, {iolist}"
fmt = '*'
else:
template = 'write(%(out)s, fmt="{fmt}", advance="no"), {iolist}' % {
'out': {stderr: '0', stdout: '6'}.get(ps.file, '*')
}
fmt = self._print(ps.format_string)
return template.format(fmt=fmt, iolist=', '.join(
(self._print(arg) for arg in ps.print_args)))
def _print_Return(self, rs):
arg, = rs.args
return "{result_name} = {arg}".format(
result_name=self._context.get('result_name', 'sympy_result'),
arg=self._print(arg)
)
def _print_FortranReturn(self, frs):
arg, = frs.args
if arg:
return 'return %s' % self._print(arg)
else:
return 'return'
def _head(self, entity, fp, **kwargs):
bind_C_params = fp.attr_params('bind_C')
if bind_C_params is None:
bind = ''
else:
bind = ' bind(C, name="%s")' % bind_C_params[0] if bind_C_params else ' bind(C)'
result_name = self._settings.get('result_name', None)
return (
"{entity}{name}({arg_names}){result}{bind}\n"
"{arg_declarations}"
).format(
entity=entity,
name=self._print(fp.name),
arg_names=', '.join([self._print(arg.symbol) for arg in fp.parameters]),
result=(' result(%s)' % result_name) if result_name else '',
bind=bind,
arg_declarations='\n'.join((self._print(Declaration(arg)) for arg in fp.parameters))
)
def _print_FunctionPrototype(self, fp):
entity = "{} function ".format(self._print(fp.return_type))
return (
"interface\n"
"{function_head}\n"
"end function\n"
"end interface"
).format(function_head=self._head(entity, fp))
def _print_FunctionDefinition(self, fd):
if elemental in fd.attrs:
prefix = 'elemental '
elif pure in fd.attrs:
prefix = 'pure '
else:
prefix = ''
entity = "{} function ".format(self._print(fd.return_type))
with printer_context(self, result_name=fd.name):
return (
"{prefix}{function_head}\n"
"{body}\n"
"end function\n"
).format(
prefix=prefix,
function_head=self._head(entity, fd),
body=self._print(fd.body)
)
def _print_Subroutine(self, sub):
return (
'{subroutine_head}\n'
'{body}\n'
'end subroutine\n'
).format(
subroutine_head=self._head('subroutine ', sub),
body=self._print(sub.body)
)
def _print_SubroutineCall(self, scall):
return 'call {name}({args})'.format(
name=self._print(scall.name),
args=', '.join((self._print(arg) for arg in scall.subroutine_args))
)
def _print_use_rename(self, rnm):
return "%s => %s" % tuple((self._print(arg) for arg in rnm.args))
def _print_use(self, use):
result = 'use %s' % self._print(use.namespace)
if use.rename != None: # Must be '!= None', cannot be 'is not None'
result += ', ' + ', '.join([self._print(rnm) for rnm in use.rename])
if use.only != None: # Must be '!= None', cannot be 'is not None'
result += ', only: ' + ', '.join([self._print(nly) for nly in use.only])
return result
def _print_BreakToken(self, _):
return 'exit'
def _print_ContinueToken(self, _):
return 'cycle'
def _print_ArrayConstructor(self, ac):
fmtstr = "[%s]" if self._settings["standard"] >= 2003 else '(/%s/)'
return fmtstr % ', '.join((self._print(arg) for arg in ac.elements))
def _print_ArrayElement(self, elem):
return '{symbol}({idxs})'.format(
symbol=self._print(elem.name),
idxs=', '.join((self._print(arg) for arg in elem.indices))
)

View File

@ -0,0 +1,551 @@
from __future__ import annotations
from sympy.core import Basic, S
from sympy.core.function import Lambda
from sympy.core.numbers import equal_valued
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence
from functools import reduce
known_functions = {
'Abs': 'abs',
'sin': 'sin',
'cos': 'cos',
'tan': 'tan',
'acos': 'acos',
'asin': 'asin',
'atan': 'atan',
'atan2': 'atan',
'ceiling': 'ceil',
'floor': 'floor',
'sign': 'sign',
'exp': 'exp',
'log': 'log',
'add': 'add',
'sub': 'sub',
'mul': 'mul',
'pow': 'pow'
}
class GLSLPrinter(CodePrinter):
"""
Rudimentary, generic GLSL printing tools.
Additional settings:
'use_operators': Boolean (should the printer use operators for +,-,*, or functions?)
"""
_not_supported: set[Basic] = set()
printmethod = "_glsl"
language = "GLSL"
_default_settings = dict(CodePrinter._default_settings, **{
'use_operators': True,
'zero': 0,
'mat_nested': False,
'mat_separator': ',\n',
'mat_transpose': False,
'array_type': 'float',
'glsl_types': True,
'precision': 9,
'user_functions': {},
'contract': True,
})
def __init__(self, settings={}):
CodePrinter.__init__(self, settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
return "%s;" % codestring
def _get_comment(self, text):
return "// {}".format(text)
def _declare_number_const(self, name, value):
return "float {} = {};".format(name, value)
def _format_code(self, lines):
return self.indent_code(lines)
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_token = ('{', '(', '{\n', '(\n')
dec_token = ('}', ')')
code = [line.lstrip(' \t') for line in code]
increase = [int(any(map(line.endswith, inc_token))) for line in code]
decrease = [int(any(map(line.startswith, dec_token))) for line in code]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def _print_MatrixBase(self, mat):
mat_separator = self._settings['mat_separator']
mat_transpose = self._settings['mat_transpose']
column_vector = (mat.rows == 1) if mat_transpose else (mat.cols == 1)
A = mat.transpose() if mat_transpose != column_vector else mat
glsl_types = self._settings['glsl_types']
array_type = self._settings['array_type']
array_size = A.cols*A.rows
array_constructor = "{}[{}]".format(array_type, array_size)
if A.cols == 1:
return self._print(A[0]);
if A.rows <= 4 and A.cols <= 4 and glsl_types:
if A.rows == 1:
return "vec{}{}".format(
A.cols, A.table(self,rowstart='(',rowend=')')
)
elif A.rows == A.cols:
return "mat{}({})".format(
A.rows, A.table(self,rowsep=', ',
rowstart='',rowend='')
)
else:
return "mat{}x{}({})".format(
A.cols, A.rows,
A.table(self,rowsep=', ',
rowstart='',rowend='')
)
elif S.One in A.shape:
return "{}({})".format(
array_constructor,
A.table(self,rowsep=mat_separator,rowstart='',rowend='')
)
elif not self._settings['mat_nested']:
return "{}(\n{}\n) /* a {}x{} matrix */".format(
array_constructor,
A.table(self,rowsep=mat_separator,rowstart='',rowend=''),
A.rows, A.cols
)
elif self._settings['mat_nested']:
return "{}[{}][{}](\n{}\n)".format(
array_type, A.rows, A.cols,
A.table(self,rowsep=mat_separator,rowstart='float[](',rowend=')')
)
def _print_SparseRepMatrix(self, mat):
# do not allow sparse matrices to be made dense
return self._print_not_supported(mat)
def _traverse_matrix_indices(self, mat):
mat_transpose = self._settings['mat_transpose']
if mat_transpose:
rows,cols = mat.shape
else:
cols,rows = mat.shape
return ((i, j) for i in range(cols) for j in range(rows))
def _print_MatrixElement(self, expr):
# print('begin _print_MatrixElement')
nest = self._settings['mat_nested'];
glsl_types = self._settings['glsl_types'];
mat_transpose = self._settings['mat_transpose'];
if mat_transpose:
cols,rows = expr.parent.shape
i,j = expr.j,expr.i
else:
rows,cols = expr.parent.shape
i,j = expr.i,expr.j
pnt = self._print(expr.parent)
if glsl_types and ((rows <= 4 and cols <=4) or nest):
return "{}[{}][{}]".format(pnt, i, j)
else:
return "{}[{}]".format(pnt, i + j*rows)
def _print_list(self, expr):
l = ', '.join(self._print(item) for item in expr)
glsl_types = self._settings['glsl_types']
array_type = self._settings['array_type']
array_size = len(expr)
array_constructor = '{}[{}]'.format(array_type, array_size)
if array_size <= 4 and glsl_types:
return 'vec{}({})'.format(array_size, l)
else:
return '{}({})'.format(array_constructor, l)
_print_tuple = _print_list
_print_Tuple = _print_list
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
loopstart = "for (int %(varble)s=%(start)s; %(varble)s<%(end)s; %(varble)s++){"
for i in indices:
# GLSL arrays start at 0 and end at dimension-1
open_lines.append(loopstart % {
'varble': self._print(i.label),
'start': self._print(i.lower),
'end': self._print(i.upper + 1)})
close_lines.append("}")
return open_lines, close_lines
def _print_Function_with_args(self, func, func_args):
if func in self.known_functions:
cond_func = self.known_functions[func]
func = None
if isinstance(cond_func, str):
func = cond_func
else:
for cond, func in cond_func:
if cond(func_args):
break
if func is not None:
try:
return func(*[self.parenthesize(item, 0) for item in func_args])
except TypeError:
return '{}({})'.format(func, self.stringify(func_args, ", "))
elif isinstance(func, Lambda):
# inlined function
return self._print(func(*func_args))
else:
return self._print_not_supported(func)
def _print_Piecewise(self, expr):
from sympy.codegen.ast import Assignment
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
if expr.has(Assignment):
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s) {" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else {")
else:
lines.append("else if (%s) {" % self._print(c))
code0 = self._print(e)
lines.append(code0)
lines.append("}")
return "\n".join(lines)
else:
# The piecewise was used in an expression, need to do inline
# operators. This has the downside that inline operators will
# not work for statements that span multiple lines (Matrix or
# Indexed expressions).
ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c),
self._print(e))
for e, c in expr.args[:-1]]
last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr)
return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)])
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_Indexed(self, expr):
# calculate index for 1d array
dims = expr.shape
elem = S.Zero
offset = S.One
for i in reversed(range(expr.rank)):
elem += expr.indices[i]*offset
offset *= dims[i]
return "{}[{}]".format(
self._print(expr.base.label),
self._print(elem)
)
def _print_Pow(self, expr):
PREC = precedence(expr)
if equal_valued(expr.exp, -1):
return '1.0/%s' % (self.parenthesize(expr.base, PREC))
elif equal_valued(expr.exp, 0.5):
return 'sqrt(%s)' % self._print(expr.base)
else:
try:
e = self._print(float(expr.exp))
except TypeError:
e = self._print(expr.exp)
return self._print_Function_with_args('pow', (
self._print(expr.base),
e
))
def _print_int(self, expr):
return str(float(expr))
def _print_Rational(self, expr):
return "{}.0/{}.0".format(expr.p, expr.q)
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_Add(self, expr, order=None):
if self._settings['use_operators']:
return CodePrinter._print_Add(self, expr, order=order)
terms = expr.as_ordered_terms()
def partition(p,l):
return reduce(lambda x, y: (x[0]+[y], x[1]) if p(y) else (x[0], x[1]+[y]), l, ([], []))
def add(a,b):
return self._print_Function_with_args('add', (a, b))
# return self.known_functions['add']+'(%s, %s)' % (a,b)
neg, pos = partition(lambda arg: arg.could_extract_minus_sign(), terms)
if pos:
s = pos = reduce(lambda a,b: add(a,b), (self._print(t) for t in pos))
else:
s = pos = self._print(self._settings['zero'])
if neg:
# sum the absolute values of the negative terms
neg = reduce(lambda a,b: add(a,b), (self._print(-n) for n in neg))
# then subtract them from the positive terms
s = self._print_Function_with_args('sub', (pos,neg))
# s = self.known_functions['sub']+'(%s, %s)' % (pos,neg)
return s
def _print_Mul(self, expr, **kwargs):
if self._settings['use_operators']:
return CodePrinter._print_Mul(self, expr, **kwargs)
terms = expr.as_ordered_factors()
def mul(a,b):
# return self.known_functions['mul']+'(%s, %s)' % (a,b)
return self._print_Function_with_args('mul', (a,b))
s = reduce(lambda a,b: mul(a,b), (self._print(t) for t in terms))
return s
def glsl_code(expr,assign_to=None,**settings):
"""Converts an expr to a string of GLSL code
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used for naming the variable or variables
to which the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol`` or ``Indexed`` type object. In cases where ``expr``
would be printed as an array, a list of string or ``Symbol`` objects
can also be passed.
This is helpful in case of line-wrapping, or for expressions that
generate multi-line statements. It can also be used to spread an array-like
expression into multiple assignments.
use_operators: bool, optional
If set to False, then *,/,+,- operators will be replaced with functions
mul, add, and sub, which must be implemented by the user, e.g. for
implementing non-standard rings or emulated quad/octal precision.
[default=True]
glsl_types: bool, optional
Set this argument to ``False`` in order to avoid using the ``vec`` and ``mat``
types. The printer will instead use arrays (or nested arrays).
[default=True]
mat_nested: bool, optional
GLSL version 4.3 and above support nested arrays (arrays of arrays). Set this to ``True``
to render matrices as nested arrays.
[default=False]
mat_separator: str, optional
By default, matrices are rendered with newlines using this separator,
making them easier to read, but less compact. By removing the newline
this option can be used to make them more vertically compact.
[default=',\n']
mat_transpose: bool, optional
GLSL's matrix multiplication implementation assumes column-major indexing.
By default, this printer ignores that convention. Setting this option to
``True`` transposes all matrix output.
[default=False]
array_type: str, optional
The GLSL array constructor type.
[default='float']
precision : integer, optional
The precision for numbers such as pi [default=15].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, js_function_string)]. See
below for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
Examples
========
>>> from sympy import glsl_code, symbols, Rational, sin, ceiling, Abs
>>> x, tau = symbols("x, tau")
>>> glsl_code((2*tau)**Rational(7, 2))
'8*sqrt(2)*pow(tau, 3.5)'
>>> glsl_code(sin(x), assign_to="float y")
'float y = sin(x);'
Various GLSL types are supported:
>>> from sympy import Matrix, glsl_code
>>> glsl_code(Matrix([1,2,3]))
'vec3(1, 2, 3)'
>>> glsl_code(Matrix([[1, 2],[3, 4]]))
'mat2(1, 2, 3, 4)'
Pass ``mat_transpose = True`` to switch to column-major indexing:
>>> glsl_code(Matrix([[1, 2],[3, 4]]), mat_transpose = True)
'mat2(1, 3, 2, 4)'
By default, larger matrices get collapsed into float arrays:
>>> print(glsl_code( Matrix([[1,2,3,4,5],[6,7,8,9,10]]) ))
float[10](
1, 2, 3, 4, 5,
6, 7, 8, 9, 10
) /* a 2x5 matrix */
The type of array constructor used to print GLSL arrays can be controlled
via the ``array_type`` parameter:
>>> glsl_code(Matrix([1,2,3,4,5]), array_type='int')
'int[5](1, 2, 3, 4, 5)'
Passing a list of strings or ``symbols`` to the ``assign_to`` parameter will yield
a multi-line assignment for each item in an array-like expression:
>>> x_struct_members = symbols('x.a x.b x.c x.d')
>>> print(glsl_code(Matrix([1,2,3,4]), assign_to=x_struct_members))
x.a = 1;
x.b = 2;
x.c = 3;
x.d = 4;
This could be useful in cases where it's desirable to modify members of a
GLSL ``Struct``. It could also be used to spread items from an array-like
expression into various miscellaneous assignments:
>>> misc_assignments = ('x[0]', 'x[1]', 'float y', 'float z')
>>> print(glsl_code(Matrix([1,2,3,4]), assign_to=misc_assignments))
x[0] = 1;
x[1] = 2;
float y = 3;
float z = 4;
Passing ``mat_nested = True`` instead prints out nested float arrays, which are
supported in GLSL 4.3 and above.
>>> mat = Matrix([
... [ 0, 1, 2],
... [ 3, 4, 5],
... [ 6, 7, 8],
... [ 9, 10, 11],
... [12, 13, 14]])
>>> print(glsl_code( mat, mat_nested = True ))
float[5][3](
float[]( 0, 1, 2),
float[]( 3, 4, 5),
float[]( 6, 7, 8),
float[]( 9, 10, 11),
float[](12, 13, 14)
)
Custom printing can be defined for certain types by passing a dictionary of
"type" : "function" to the ``user_functions`` kwarg. Alternatively, the
dictionary value can be a list of tuples i.e. [(argument_test,
js_function_string)].
>>> custom_functions = {
... "ceiling": "CEIL",
... "Abs": [(lambda x: not x.is_integer, "fabs"),
... (lambda x: x.is_integer, "ABS")]
... }
>>> glsl_code(Abs(x) + ceiling(x), user_functions=custom_functions)
'fabs(x) + CEIL(x)'
If further control is needed, addition, subtraction, multiplication and
division operators can be replaced with ``add``, ``sub``, and ``mul``
functions. This is done by passing ``use_operators = False``:
>>> x,y,z = symbols('x,y,z')
>>> glsl_code(x*(y+z), use_operators = False)
'mul(x, add(y, z))'
>>> glsl_code(x*(y+z*(x-y)**z), use_operators = False)
'mul(x, add(y, mul(z, pow(sub(x, y), z))))'
``Piecewise`` expressions are converted into conditionals. If an
``assign_to`` variable is provided an if statement is created, otherwise
the ternary operator is used. Note that if the ``Piecewise`` lacks a
default term, represented by ``(expr, True)`` then an error will be thrown.
This is to prevent generating an expression that may not evaluate to
anything.
>>> from sympy import Piecewise
>>> expr = Piecewise((x + 1, x > 0), (x, True))
>>> print(glsl_code(expr, tau))
if (x > 0) {
tau = x + 1;
}
else {
tau = x;
}
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> glsl_code(e.rhs, assign_to=e.lhs, contract=False)
'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
>>> A = MatrixSymbol('A', 3, 1)
>>> print(glsl_code(mat, A))
A[0][0] = pow(x, 2.0);
if (x > 0) {
A[1][0] = x + 1;
}
else {
A[1][0] = x;
}
A[2][0] = sin(x);
"""
return GLSLPrinter(settings).doprint(expr,assign_to)
def print_glsl(expr, **settings):
"""Prints the GLSL representation of the given expression.
See GLSLPrinter init function for settings.
"""
print(glsl_code(expr, **settings))

View File

@ -0,0 +1,16 @@
from sympy.printing.mathml import mathml
from sympy.utilities.mathml import c2p
import tempfile
import subprocess
def print_gtk(x, start_viewer=True):
"""Print to Gtkmathview, a gtk widget capable of rendering MathML.
Needs libgtkmathview-bin"""
with tempfile.NamedTemporaryFile('w') as file:
file.write(c2p(mathml(x), simple=True))
file.flush()
if start_viewer:
subprocess.check_call(('mathmlviewer', file.name))

View File

@ -0,0 +1,335 @@
"""
Javascript code printer
The JavascriptCodePrinter converts single SymPy expressions into single
Javascript expressions, using the functions defined in the Javascript
Math object where possible.
"""
from __future__ import annotations
from typing import Any
from sympy.core import S
from sympy.core.numbers import equal_valued
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence, PRECEDENCE
# dictionary mapping SymPy function to (argument_conditions, Javascript_function).
# Used in JavascriptCodePrinter._print_Function(self)
known_functions = {
'Abs': 'Math.abs',
'acos': 'Math.acos',
'acosh': 'Math.acosh',
'asin': 'Math.asin',
'asinh': 'Math.asinh',
'atan': 'Math.atan',
'atan2': 'Math.atan2',
'atanh': 'Math.atanh',
'ceiling': 'Math.ceil',
'cos': 'Math.cos',
'cosh': 'Math.cosh',
'exp': 'Math.exp',
'floor': 'Math.floor',
'log': 'Math.log',
'Max': 'Math.max',
'Min': 'Math.min',
'sign': 'Math.sign',
'sin': 'Math.sin',
'sinh': 'Math.sinh',
'tan': 'Math.tan',
'tanh': 'Math.tanh',
}
class JavascriptCodePrinter(CodePrinter):
""""A Printer to convert Python expressions to strings of JavaScript code
"""
printmethod = '_javascript'
language = 'JavaScript'
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 17,
'user_functions': {},
'contract': True,
})
def __init__(self, settings={}):
CodePrinter.__init__(self, settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
return "%s;" % codestring
def _get_comment(self, text):
return "// {}".format(text)
def _declare_number_const(self, name, value):
return "var {} = {};".format(name, value.evalf(self._settings['precision']))
def _format_code(self, lines):
return self.indent_code(lines)
def _traverse_matrix_indices(self, mat):
rows, cols = mat.shape
return ((i, j) for i in range(rows) for j in range(cols))
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
loopstart = "for (var %(varble)s=%(start)s; %(varble)s<%(end)s; %(varble)s++){"
for i in indices:
# Javascript arrays start at 0 and end at dimension-1
open_lines.append(loopstart % {
'varble': self._print(i.label),
'start': self._print(i.lower),
'end': self._print(i.upper + 1)})
close_lines.append("}")
return open_lines, close_lines
def _print_Pow(self, expr):
PREC = precedence(expr)
if equal_valued(expr.exp, -1):
return '1/%s' % (self.parenthesize(expr.base, PREC))
elif equal_valued(expr.exp, 0.5):
return 'Math.sqrt(%s)' % self._print(expr.base)
elif expr.exp == S.One/3:
return 'Math.cbrt(%s)' % self._print(expr.base)
else:
return 'Math.pow(%s, %s)' % (self._print(expr.base),
self._print(expr.exp))
def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)
return '%d/%d' % (p, q)
def _print_Mod(self, expr):
num, den = expr.args
PREC = precedence(expr)
snum, sden = [self.parenthesize(arg, PREC) for arg in expr.args]
# % is remainder (same sign as numerator), not modulo (same sign as
# denominator), in js. Hence, % only works as modulo if both numbers
# have the same sign
if (num.is_nonnegative and den.is_nonnegative or
num.is_nonpositive and den.is_nonpositive):
return f"{snum} % {sden}"
return f"(({snum} % {sden}) + {sden}) % {sden}"
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_Indexed(self, expr):
# calculate index for 1d array
dims = expr.shape
elem = S.Zero
offset = S.One
for i in reversed(range(expr.rank)):
elem += expr.indices[i]*offset
offset *= dims[i]
return "%s[%s]" % (self._print(expr.base.label), self._print(elem))
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_Exp1(self, expr):
return "Math.E"
def _print_Pi(self, expr):
return 'Math.PI'
def _print_Infinity(self, expr):
return 'Number.POSITIVE_INFINITY'
def _print_NegativeInfinity(self, expr):
return 'Number.NEGATIVE_INFINITY'
def _print_Piecewise(self, expr):
from sympy.codegen.ast import Assignment
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
if expr.has(Assignment):
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s) {" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else {")
else:
lines.append("else if (%s) {" % self._print(c))
code0 = self._print(e)
lines.append(code0)
lines.append("}")
return "\n".join(lines)
else:
# The piecewise was used in an expression, need to do inline
# operators. This has the downside that inline operators will
# not work for statements that span multiple lines (Matrix or
# Indexed expressions).
ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c), self._print(e))
for e, c in expr.args[:-1]]
last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr)
return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)])
def _print_MatrixElement(self, expr):
return "{}[{}]".format(self.parenthesize(expr.parent,
PRECEDENCE["Atom"], strict=True),
expr.j + expr.i*expr.parent.shape[1])
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_token = ('{', '(', '{\n', '(\n')
dec_token = ('}', ')')
code = [ line.lstrip(' \t') for line in code ]
increase = [ int(any(map(line.endswith, inc_token))) for line in code ]
decrease = [ int(any(map(line.startswith, dec_token)))
for line in code ]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def jscode(expr, assign_to=None, **settings):
"""Converts an expr to a string of javascript code
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
line-wrapping, or for expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=15].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, js_function_string)]. See
below for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
Examples
========
>>> from sympy import jscode, symbols, Rational, sin, ceiling, Abs
>>> x, tau = symbols("x, tau")
>>> jscode((2*tau)**Rational(7, 2))
'8*Math.sqrt(2)*Math.pow(tau, 7/2)'
>>> jscode(sin(x), assign_to="s")
's = Math.sin(x);'
Custom printing can be defined for certain types by passing a dictionary of
"type" : "function" to the ``user_functions`` kwarg. Alternatively, the
dictionary value can be a list of tuples i.e. [(argument_test,
js_function_string)].
>>> custom_functions = {
... "ceiling": "CEIL",
... "Abs": [(lambda x: not x.is_integer, "fabs"),
... (lambda x: x.is_integer, "ABS")]
... }
>>> jscode(Abs(x) + ceiling(x), user_functions=custom_functions)
'fabs(x) + CEIL(x)'
``Piecewise`` expressions are converted into conditionals. If an
``assign_to`` variable is provided an if statement is created, otherwise
the ternary operator is used. Note that if the ``Piecewise`` lacks a
default term, represented by ``(expr, True)`` then an error will be thrown.
This is to prevent generating an expression that may not evaluate to
anything.
>>> from sympy import Piecewise
>>> expr = Piecewise((x + 1, x > 0), (x, True))
>>> print(jscode(expr, tau))
if (x > 0) {
tau = x + 1;
}
else {
tau = x;
}
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> jscode(e.rhs, assign_to=e.lhs, contract=False)
'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
must be provided to ``assign_to``. Note that any expression that can be
generated normally can also exist inside a Matrix:
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
>>> A = MatrixSymbol('A', 3, 1)
>>> print(jscode(mat, A))
A[0] = Math.pow(x, 2);
if (x > 0) {
A[1] = x + 1;
}
else {
A[1] = x;
}
A[2] = Math.sin(x);
"""
return JavascriptCodePrinter(settings).doprint(expr, assign_to)
def print_jscode(expr, **settings):
"""Prints the Javascript representation of the given expression.
See jscode for the meaning of the optional arguments.
"""
print(jscode(expr, **settings))

View File

@ -0,0 +1,654 @@
"""
Julia code printer
The `JuliaCodePrinter` converts SymPy expressions into Julia expressions.
A complete code generator, which uses `julia_code` extensively, can be found
in `sympy.utilities.codegen`. The `codegen` module can be used to generate
complete source code files.
"""
from __future__ import annotations
from typing import Any
from sympy.core import Mul, Pow, S, Rational
from sympy.core.mul import _keep_coeff
from sympy.core.numbers import equal_valued
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence, PRECEDENCE
from re import search
# List of known functions. First, those that have the same name in
# SymPy and Julia. This is almost certainly incomplete!
known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc",
"asin", "acos", "atan", "acot", "asec", "acsc",
"sinh", "cosh", "tanh", "coth", "sech", "csch",
"asinh", "acosh", "atanh", "acoth", "asech", "acsch",
"sinc", "atan2", "sign", "floor", "log", "exp",
"cbrt", "sqrt", "erf", "erfc", "erfi",
"factorial", "gamma", "digamma", "trigamma",
"polygamma", "beta",
"airyai", "airyaiprime", "airybi", "airybiprime",
"besselj", "bessely", "besseli", "besselk",
"erfinv", "erfcinv"]
# These functions have different names ("SymPy": "Julia"), more
# generally a mapping to (argument_conditions, julia_function).
known_fcns_src2 = {
"Abs": "abs",
"ceiling": "ceil",
"conjugate": "conj",
"hankel1": "hankelh1",
"hankel2": "hankelh2",
"im": "imag",
"re": "real"
}
class JuliaCodePrinter(CodePrinter):
"""
A printer to convert expressions to strings of Julia code.
"""
printmethod = "_julia"
language = "Julia"
_operators = {
'and': '&&',
'or': '||',
'not': '!',
}
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 17,
'user_functions': {},
'contract': True,
'inline': True,
})
# Note: contract is for expressing tensors as loops (if True), or just
# assignment (if False). FIXME: this should be looked a more carefully
# for Julia.
def __init__(self, settings={}):
super().__init__(settings)
self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1))
self.known_functions.update(dict(known_fcns_src2))
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
return "%s" % codestring
def _get_comment(self, text):
return "# {}".format(text)
def _declare_number_const(self, name, value):
return "const {} = {}".format(name, value)
def _format_code(self, lines):
return self.indent_code(lines)
def _traverse_matrix_indices(self, mat):
# Julia uses Fortran order (column-major)
rows, cols = mat.shape
return ((i, j) for j in range(cols) for i in range(rows))
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
for i in indices:
# Julia arrays start at 1 and end at dimension
var, start, stop = map(self._print,
[i.label, i.lower + 1, i.upper + 1])
open_lines.append("for %s = %s:%s" % (var, start, stop))
close_lines.append("end")
return open_lines, close_lines
def _print_Mul(self, expr):
# print complex numbers nicely in Julia
if (expr.is_number and expr.is_imaginary and
expr.as_coeff_Mul()[0].is_integer):
return "%sim" % self._print(-S.ImaginaryUnit*expr)
# cribbed from str.py
prec = precedence(expr)
c, e = expr.as_coeff_Mul()
if c < 0:
expr = _keep_coeff(-c, e)
sign = "-"
else:
sign = ""
a = [] # items in the numerator
b = [] # items that are in the denominator (if any)
pow_paren = [] # Will collect all pow with more than one base element and exp = -1
if self.order not in ('old', 'none'):
args = expr.as_ordered_factors()
else:
# use make_args in case expr was something like -x -> x
args = Mul.make_args(expr)
# Gather args for numerator/denominator
for item in args:
if (item.is_commutative and item.is_Pow and item.exp.is_Rational
and item.exp.is_negative):
if item.exp != -1:
b.append(Pow(item.base, -item.exp, evaluate=False))
else:
if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160
pow_paren.append(item)
b.append(Pow(item.base, -item.exp))
elif item.is_Rational and item is not S.Infinity and item.p == 1:
# Save the Rational type in julia Unless the numerator is 1.
# For example:
# julia_code(Rational(3, 7)*x) --> (3 // 7) * x
# julia_code(x/3) --> x / 3 but not x * (1 // 3)
b.append(Rational(item.q))
else:
a.append(item)
a = a or [S.One]
a_str = [self.parenthesize(x, prec) for x in a]
b_str = [self.parenthesize(x, prec) for x in b]
# To parenthesize Pow with exp = -1 and having more than one Symbol
for item in pow_paren:
if item.base in b:
b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]
# from here it differs from str.py to deal with "*" and ".*"
def multjoin(a, a_str):
# here we probably are assuming the constants will come first
r = a_str[0]
for i in range(1, len(a)):
mulsym = '*' if a[i-1].is_number else '.*'
r = "%s %s %s" % (r, mulsym, a_str[i])
return r
if not b:
return sign + multjoin(a, a_str)
elif len(b) == 1:
divsym = '/' if b[0].is_number else './'
return "%s %s %s" % (sign+multjoin(a, a_str), divsym, b_str[0])
else:
divsym = '/' if all(bi.is_number for bi in b) else './'
return "%s %s (%s)" % (sign + multjoin(a, a_str), divsym, multjoin(b, b_str))
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_Pow(self, expr):
powsymbol = '^' if all(x.is_number for x in expr.args) else '.^'
PREC = precedence(expr)
if equal_valued(expr.exp, 0.5):
return "sqrt(%s)" % self._print(expr.base)
if expr.is_commutative:
if equal_valued(expr.exp, -0.5):
sym = '/' if expr.base.is_number else './'
return "1 %s sqrt(%s)" % (sym, self._print(expr.base))
if equal_valued(expr.exp, -1):
sym = '/' if expr.base.is_number else './'
return "1 %s %s" % (sym, self.parenthesize(expr.base, PREC))
return '%s %s %s' % (self.parenthesize(expr.base, PREC), powsymbol,
self.parenthesize(expr.exp, PREC))
def _print_MatPow(self, expr):
PREC = precedence(expr)
return '%s ^ %s' % (self.parenthesize(expr.base, PREC),
self.parenthesize(expr.exp, PREC))
def _print_Pi(self, expr):
if self._settings["inline"]:
return "pi"
else:
return super()._print_NumberSymbol(expr)
def _print_ImaginaryUnit(self, expr):
return "im"
def _print_Exp1(self, expr):
if self._settings["inline"]:
return "e"
else:
return super()._print_NumberSymbol(expr)
def _print_EulerGamma(self, expr):
if self._settings["inline"]:
return "eulergamma"
else:
return super()._print_NumberSymbol(expr)
def _print_Catalan(self, expr):
if self._settings["inline"]:
return "catalan"
else:
return super()._print_NumberSymbol(expr)
def _print_GoldenRatio(self, expr):
if self._settings["inline"]:
return "golden"
else:
return super()._print_NumberSymbol(expr)
def _print_Assignment(self, expr):
from sympy.codegen.ast import Assignment
from sympy.functions.elementary.piecewise import Piecewise
from sympy.tensor.indexed import IndexedBase
# Copied from codeprinter, but remove special MatrixSymbol treatment
lhs = expr.lhs
rhs = expr.rhs
# We special case assignments that take multiple lines
if not self._settings["inline"] and isinstance(expr.rhs, Piecewise):
# Here we modify Piecewise so each expression is now
# an Assignment, and then continue on the print.
expressions = []
conditions = []
for (e, c) in rhs.args:
expressions.append(Assignment(lhs, e))
conditions.append(c)
temp = Piecewise(*zip(expressions, conditions))
return self._print(temp)
if self._settings["contract"] and (lhs.has(IndexedBase) or
rhs.has(IndexedBase)):
# Here we check if there is looping to be done, and if so
# print the required loops.
return self._doprint_loops(rhs, lhs)
else:
lhs_code = self._print(lhs)
rhs_code = self._print(rhs)
return self._get_statement("%s = %s" % (lhs_code, rhs_code))
def _print_Infinity(self, expr):
return 'Inf'
def _print_NegativeInfinity(self, expr):
return '-Inf'
def _print_NaN(self, expr):
return 'NaN'
def _print_list(self, expr):
return 'Any[' + ', '.join(self._print(a) for a in expr) + ']'
def _print_tuple(self, expr):
if len(expr) == 1:
return "(%s,)" % self._print(expr[0])
else:
return "(%s)" % self.stringify(expr, ", ")
_print_Tuple = _print_tuple
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
def _print_bool(self, expr):
return str(expr).lower()
# Could generate quadrature code for definite Integrals?
#_print_Integral = _print_not_supported
def _print_MatrixBase(self, A):
# Handle zero dimensions:
if S.Zero in A.shape:
return 'zeros(%s, %s)' % (A.rows, A.cols)
elif (A.rows, A.cols) == (1, 1):
return "[%s]" % A[0, 0]
elif A.rows == 1:
return "[%s]" % A.table(self, rowstart='', rowend='', colsep=' ')
elif A.cols == 1:
# note .table would unnecessarily equispace the rows
return "[%s]" % ", ".join([self._print(a) for a in A])
return "[%s]" % A.table(self, rowstart='', rowend='',
rowsep=';\n', colsep=' ')
def _print_SparseRepMatrix(self, A):
from sympy.matrices import Matrix
L = A.col_list();
# make row vectors of the indices and entries
I = Matrix([k[0] + 1 for k in L])
J = Matrix([k[1] + 1 for k in L])
AIJ = Matrix([k[2] for k in L])
return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J),
self._print(AIJ), A.rows, A.cols)
def _print_MatrixElement(self, expr):
return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \
+ '[%s,%s]' % (expr.i + 1, expr.j + 1)
def _print_MatrixSlice(self, expr):
def strslice(x, lim):
l = x[0] + 1
h = x[1]
step = x[2]
lstr = self._print(l)
hstr = 'end' if h == lim else self._print(h)
if step == 1:
if l == 1 and h == lim:
return ':'
if l == h:
return lstr
else:
return lstr + ':' + hstr
else:
return ':'.join((lstr, self._print(step), hstr))
return (self._print(expr.parent) + '[' +
strslice(expr.rowslice, expr.parent.shape[0]) + ',' +
strslice(expr.colslice, expr.parent.shape[1]) + ']')
def _print_Indexed(self, expr):
inds = [ self._print(i) for i in expr.indices ]
return "%s[%s]" % (self._print(expr.base.label), ",".join(inds))
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_Identity(self, expr):
return "eye(%s)" % self._print(expr.shape[0])
def _print_HadamardProduct(self, expr):
return ' .* '.join([self.parenthesize(arg, precedence(expr))
for arg in expr.args])
def _print_HadamardPower(self, expr):
PREC = precedence(expr)
return '.**'.join([
self.parenthesize(expr.base, PREC),
self.parenthesize(expr.exp, PREC)
])
def _print_Rational(self, expr):
if expr.q == 1:
return str(expr.p)
return "%s // %s" % (expr.p, expr.q)
# Note: as of 2022, Julia doesn't have spherical Bessel functions
def _print_jn(self, expr):
from sympy.functions import sqrt, besselj
x = expr.argument
expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x)
return self._print(expr2)
def _print_yn(self, expr):
from sympy.functions import sqrt, bessely
x = expr.argument
expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x)
return self._print(expr2)
def _print_Piecewise(self, expr):
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
if self._settings["inline"]:
# Express each (cond, expr) pair in a nested Horner form:
# (condition) .* (expr) + (not cond) .* (<others>)
# Expressions that result in multiple statements won't work here.
ecpairs = ["({}) ? ({}) :".format
(self._print(c), self._print(e))
for e, c in expr.args[:-1]]
elast = " (%s)" % self._print(expr.args[-1].expr)
pw = "\n".join(ecpairs) + elast
# Note: current need these outer brackets for 2*pw. Would be
# nicer to teach parenthesize() to do this for us when needed!
return "(" + pw + ")"
else:
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s)" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else")
else:
lines.append("elseif (%s)" % self._print(c))
code0 = self._print(e)
lines.append(code0)
if i == len(expr.args) - 1:
lines.append("end")
return "\n".join(lines)
def _print_MatMul(self, expr):
c, m = expr.as_coeff_mmul()
sign = ""
if c.is_number:
re, im = c.as_real_imag()
if im.is_zero and re.is_negative:
expr = _keep_coeff(-c, m)
sign = "-"
elif re.is_zero and im.is_negative:
expr = _keep_coeff(-c, m)
sign = "-"
return sign + ' * '.join(
(self.parenthesize(arg, precedence(expr)) for arg in expr.args)
)
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
# code mostly copied from ccode
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ')
dec_regex = ('^end$', '^elseif ', '^else$')
# pre-strip left-space from the code
code = [ line.lstrip(' \t') for line in code ]
increase = [ int(any(search(re, line) for re in inc_regex))
for line in code ]
decrease = [ int(any(search(re, line) for re in dec_regex))
for line in code ]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def julia_code(expr, assign_to=None, **settings):
r"""Converts `expr` to a string of Julia code.
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This can be helpful for
expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=16].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, cfunction_string)]. See
below for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
inline: bool, optional
If True, we try to create single-statement code instead of multiple
statements. [default=True].
Examples
========
>>> from sympy import julia_code, symbols, sin, pi
>>> x = symbols('x')
>>> julia_code(sin(x).series(x).removeO())
'x .^ 5 / 120 - x .^ 3 / 6 + x'
>>> from sympy import Rational, ceiling
>>> x, y, tau = symbols("x, y, tau")
>>> julia_code((2*tau)**Rational(7, 2))
'8 * sqrt(2) * tau .^ (7 // 2)'
Note that element-wise (Hadamard) operations are used by default between
symbols. This is because its possible in Julia to write "vectorized"
code. It is harmless if the values are scalars.
>>> julia_code(sin(pi*x*y), assign_to="s")
's = sin(pi * x .* y)'
If you need a matrix product "*" or matrix power "^", you can specify the
symbol as a ``MatrixSymbol``.
>>> from sympy import Symbol, MatrixSymbol
>>> n = Symbol('n', integer=True, positive=True)
>>> A = MatrixSymbol('A', n, n)
>>> julia_code(3*pi*A**3)
'(3 * pi) * A ^ 3'
This class uses several rules to decide which symbol to use a product.
Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*".
A HadamardProduct can be used to specify componentwise multiplication ".*"
of two MatrixSymbols. There is currently there is no easy way to specify
scalar symbols, so sometimes the code might have some minor cosmetic
issues. For example, suppose x and y are scalars and A is a Matrix, then
while a human programmer might write "(x^2*y)*A^3", we generate:
>>> julia_code(x**2*y*A**3)
'(x .^ 2 .* y) * A ^ 3'
Matrices are supported using Julia inline notation. When using
``assign_to`` with matrices, the name can be specified either as a string
or as a ``MatrixSymbol``. The dimensions must align in the latter case.
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([[x**2, sin(x), ceiling(x)]])
>>> julia_code(mat, assign_to='A')
'A = [x .^ 2 sin(x) ceil(x)]'
``Piecewise`` expressions are implemented with logical masking by default.
Alternatively, you can pass "inline=False" to use if-else conditionals.
Note that if the ``Piecewise`` lacks a default term, represented by
``(expr, True)`` then an error will be thrown. This is to prevent
generating an expression that may not evaluate to anything.
>>> from sympy import Piecewise
>>> pw = Piecewise((x + 1, x > 0), (x, True))
>>> julia_code(pw, assign_to=tau)
'tau = ((x > 0) ? (x + 1) : (x))'
Note that any expression that can be generated normally can also exist
inside a Matrix:
>>> mat = Matrix([[x**2, pw, sin(x)]])
>>> julia_code(mat, assign_to='A')
'A = [x .^ 2 ((x > 0) ? (x + 1) : (x)) sin(x)]'
Custom printing can be defined for certain types by passing a dictionary of
"type" : "function" to the ``user_functions`` kwarg. Alternatively, the
dictionary value can be a list of tuples i.e., [(argument_test,
cfunction_string)]. This can be used to call a custom Julia function.
>>> from sympy import Function
>>> f = Function('f')
>>> g = Function('g')
>>> custom_functions = {
... "f": "existing_julia_fcn",
... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"),
... (lambda x: not x.is_Matrix, "my_fcn")]
... }
>>> mat = Matrix([[1, x]])
>>> julia_code(f(x) + g(x) + g(mat), user_functions=custom_functions)
'existing_julia_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])'
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> julia_code(e.rhs, assign_to=e.lhs, contract=False)
'Dy[i] = (y[i + 1] - y[i]) ./ (t[i + 1] - t[i])'
"""
return JuliaCodePrinter(settings).doprint(expr, assign_to)
def print_julia_code(expr, **settings):
"""Prints the Julia representation of the given expression.
See `julia_code` for the meaning of the optional arguments.
"""
print(julia_code(expr, **settings))

View File

@ -0,0 +1,251 @@
from .pycode import (
PythonCodePrinter,
MpmathPrinter,
)
from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility
from sympy.core.sorting import default_sort_key
__all__ = [
'PythonCodePrinter',
'MpmathPrinter', # MpmathPrinter is published for backward compatibility
'NumPyPrinter',
'LambdaPrinter',
'NumPyPrinter',
'IntervalPrinter',
'lambdarepr',
]
class LambdaPrinter(PythonCodePrinter):
"""
This printer converts expressions into strings that can be used by
lambdify.
"""
printmethod = "_lambdacode"
def _print_And(self, expr):
result = ['(']
for arg in sorted(expr.args, key=default_sort_key):
result.extend(['(', self._print(arg), ')'])
result.append(' and ')
result = result[:-1]
result.append(')')
return ''.join(result)
def _print_Or(self, expr):
result = ['(']
for arg in sorted(expr.args, key=default_sort_key):
result.extend(['(', self._print(arg), ')'])
result.append(' or ')
result = result[:-1]
result.append(')')
return ''.join(result)
def _print_Not(self, expr):
result = ['(', 'not (', self._print(expr.args[0]), '))']
return ''.join(result)
def _print_BooleanTrue(self, expr):
return "True"
def _print_BooleanFalse(self, expr):
return "False"
def _print_ITE(self, expr):
result = [
'((', self._print(expr.args[1]),
') if (', self._print(expr.args[0]),
') else (', self._print(expr.args[2]), '))'
]
return ''.join(result)
def _print_NumberSymbol(self, expr):
return str(expr)
def _print_Pow(self, expr, **kwargs):
# XXX Temporary workaround. Should Python math printer be
# isolated from PythonCodePrinter?
return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs)
# numexpr works by altering the string passed to numexpr.evaluate
# rather than by populating a namespace. Thus a special printer...
class NumExprPrinter(LambdaPrinter):
# key, value pairs correspond to SymPy name and numexpr name
# functions not appearing in this dict will raise a TypeError
printmethod = "_numexprcode"
_numexpr_functions = {
'sin' : 'sin',
'cos' : 'cos',
'tan' : 'tan',
'asin': 'arcsin',
'acos': 'arccos',
'atan': 'arctan',
'atan2' : 'arctan2',
'sinh' : 'sinh',
'cosh' : 'cosh',
'tanh' : 'tanh',
'asinh': 'arcsinh',
'acosh': 'arccosh',
'atanh': 'arctanh',
'ln' : 'log',
'log': 'log',
'exp': 'exp',
'sqrt' : 'sqrt',
'Abs' : 'abs',
'conjugate' : 'conj',
'im' : 'imag',
're' : 'real',
'where' : 'where',
'complex' : 'complex',
'contains' : 'contains',
}
module = 'numexpr'
def _print_ImaginaryUnit(self, expr):
return '1j'
def _print_seq(self, seq, delimiter=', '):
# simplified _print_seq taken from pretty.py
s = [self._print(item) for item in seq]
if s:
return delimiter.join(s)
else:
return ""
def _print_Function(self, e):
func_name = e.func.__name__
nstr = self._numexpr_functions.get(func_name, None)
if nstr is None:
# check for implemented_function
if hasattr(e, '_imp_'):
return "(%s)" % self._print(e._imp_(*e.args))
else:
raise TypeError("numexpr does not support function '%s'" %
func_name)
return "%s(%s)" % (nstr, self._print_seq(e.args))
def _print_Piecewise(self, expr):
"Piecewise function printer"
exprs = [self._print(arg.expr) for arg in expr.args]
conds = [self._print(arg.cond) for arg in expr.args]
# If [default_value, True] is a (expr, cond) sequence in a Piecewise object
# it will behave the same as passing the 'default' kwarg to select()
# *as long as* it is the last element in expr.args.
# If this is not the case, it may be triggered prematurely.
ans = []
parenthesis_count = 0
is_last_cond_True = False
for cond, expr in zip(conds, exprs):
if cond == 'True':
ans.append(expr)
is_last_cond_True = True
break
else:
ans.append('where(%s, %s, ' % (cond, expr))
parenthesis_count += 1
if not is_last_cond_True:
# See https://github.com/pydata/numexpr/issues/298
#
# simplest way to put a nan but raises
# 'RuntimeWarning: invalid value encountered in log'
#
# There are other ways to do this such as
#
# >>> import numexpr as ne
# >>> nan = float('nan')
# >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan})
# array([-1., nan, nan])
#
# That needs to be handled in the lambdified function though rather
# than here in the printer.
ans.append('log(-1)')
return ''.join(ans) + ')' * parenthesis_count
def _print_ITE(self, expr):
from sympy.functions.elementary.piecewise import Piecewise
return self._print(expr.rewrite(Piecewise))
def blacklisted(self, expr):
raise TypeError("numexpr cannot be used with %s" %
expr.__class__.__name__)
# blacklist all Matrix printing
_print_SparseRepMatrix = \
_print_MutableSparseMatrix = \
_print_ImmutableSparseMatrix = \
_print_Matrix = \
_print_DenseMatrix = \
_print_MutableDenseMatrix = \
_print_ImmutableMatrix = \
_print_ImmutableDenseMatrix = \
blacklisted
# blacklist some Python expressions
_print_list = \
_print_tuple = \
_print_Tuple = \
_print_dict = \
_print_Dict = \
blacklisted
def _print_NumExprEvaluate(self, expr):
evaluate = self._module_format(self.module +".evaluate")
return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr))
def doprint(self, expr):
from sympy.codegen.ast import CodegenAST
from sympy.codegen.pynodes import NumExprEvaluate
if not isinstance(expr, CodegenAST):
expr = NumExprEvaluate(expr)
return super().doprint(expr)
def _print_Return(self, expr):
from sympy.codegen.pynodes import NumExprEvaluate
r, = expr.args
if not isinstance(r, NumExprEvaluate):
expr = expr.func(NumExprEvaluate(r))
return super()._print_Return(expr)
def _print_Assignment(self, expr):
from sympy.codegen.pynodes import NumExprEvaluate
lhs, rhs, *args = expr.args
if not isinstance(rhs, NumExprEvaluate):
expr = expr.func(lhs, NumExprEvaluate(rhs), *args)
return super()._print_Assignment(expr)
def _print_CodeBlock(self, expr):
from sympy.codegen.ast import CodegenAST
from sympy.codegen.pynodes import NumExprEvaluate
args = [ arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg) for arg in expr.args ]
return super()._print_CodeBlock(self, expr.func(*args))
class IntervalPrinter(MpmathPrinter, LambdaPrinter):
"""Use ``lambda`` printer but print numbers as ``mpi`` intervals. """
def _print_Integer(self, expr):
return "mpi('%s')" % super(PythonCodePrinter, self)._print_Integer(expr)
def _print_Rational(self, expr):
return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr)
def _print_Half(self, expr):
return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr)
def _print_Pow(self, expr):
return super(MpmathPrinter, self)._print_Pow(expr, rational=True)
for k in NumExprPrinter._numexpr_functions:
setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function)
def lambdarepr(expr, **settings):
"""
Returns a string usable for lambdifying.
"""
return LambdaPrinter(settings).doprint(expr)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,491 @@
'''
Use llvmlite to create executable functions from SymPy expressions
This module requires llvmlite (https://github.com/numba/llvmlite).
'''
import ctypes
from sympy.external import import_module
from sympy.printing.printer import Printer
from sympy.core.singleton import S
from sympy.tensor.indexed import IndexedBase
from sympy.utilities.decorator import doctest_depends_on
llvmlite = import_module('llvmlite')
if llvmlite:
ll = import_module('llvmlite.ir').ir
llvm = import_module('llvmlite.binding').binding
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
__doctest_requires__ = {('llvm_callable'): ['llvmlite']}
class LLVMJitPrinter(Printer):
'''Convert expressions to LLVM IR'''
def __init__(self, module, builder, fn, *args, **kwargs):
self.func_arg_map = kwargs.pop("func_arg_map", {})
if not llvmlite:
raise ImportError("llvmlite is required for LLVMJITPrinter")
super().__init__(*args, **kwargs)
self.fp_type = ll.DoubleType()
self.module = module
self.builder = builder
self.fn = fn
self.ext_fn = {} # keep track of wrappers to external functions
self.tmp_var = {}
def _add_tmp_var(self, name, value):
self.tmp_var[name] = value
def _print_Number(self, n):
return ll.Constant(self.fp_type, float(n))
def _print_Integer(self, expr):
return ll.Constant(self.fp_type, float(expr.p))
def _print_Symbol(self, s):
val = self.tmp_var.get(s)
if not val:
# look up parameter with name s
val = self.func_arg_map.get(s)
if not val:
raise LookupError("Symbol not found: %s" % s)
return val
def _print_Pow(self, expr):
base0 = self._print(expr.base)
if expr.exp == S.NegativeOne:
return self.builder.fdiv(ll.Constant(self.fp_type, 1.0), base0)
if expr.exp == S.Half:
fn = self.ext_fn.get("sqrt")
if not fn:
fn_type = ll.FunctionType(self.fp_type, [self.fp_type])
fn = ll.Function(self.module, fn_type, "sqrt")
self.ext_fn["sqrt"] = fn
return self.builder.call(fn, [base0], "sqrt")
if expr.exp == 2:
return self.builder.fmul(base0, base0)
exp0 = self._print(expr.exp)
fn = self.ext_fn.get("pow")
if not fn:
fn_type = ll.FunctionType(self.fp_type, [self.fp_type, self.fp_type])
fn = ll.Function(self.module, fn_type, "pow")
self.ext_fn["pow"] = fn
return self.builder.call(fn, [base0, exp0], "pow")
def _print_Mul(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
for node in nodes[1:]:
e = self.builder.fmul(e, node)
return e
def _print_Add(self, expr):
nodes = [self._print(a) for a in expr.args]
e = nodes[0]
for node in nodes[1:]:
e = self.builder.fadd(e, node)
return e
# TODO - assumes all called functions take one double precision argument.
# Should have a list of math library functions to validate this.
def _print_Function(self, expr):
name = expr.func.__name__
e0 = self._print(expr.args[0])
fn = self.ext_fn.get(name)
if not fn:
fn_type = ll.FunctionType(self.fp_type, [self.fp_type])
fn = ll.Function(self.module, fn_type, name)
self.ext_fn[name] = fn
return self.builder.call(fn, [e0], name)
def emptyPrinter(self, expr):
raise TypeError("Unsupported type for LLVM JIT conversion: %s"
% type(expr))
# Used when parameters are passed by array. Often used in callbacks to
# handle a variable number of parameters.
class LLVMJitCallbackPrinter(LLVMJitPrinter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _print_Indexed(self, expr):
array, idx = self.func_arg_map[expr.base]
offset = int(expr.indices[0].evalf())
array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), offset)])
fp_array_ptr = self.builder.bitcast(array_ptr, ll.PointerType(self.fp_type))
value = self.builder.load(fp_array_ptr)
return value
def _print_Symbol(self, s):
val = self.tmp_var.get(s)
if val:
return val
array, idx = self.func_arg_map.get(s, [None, 0])
if not array:
raise LookupError("Symbol not found: %s" % s)
array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), idx)])
fp_array_ptr = self.builder.bitcast(array_ptr,
ll.PointerType(self.fp_type))
value = self.builder.load(fp_array_ptr)
return value
# ensure lifetime of the execution engine persists (else call to compiled
# function will seg fault)
exe_engines = []
# ensure names for generated functions are unique
link_names = set()
current_link_suffix = 0
class LLVMJitCode:
def __init__(self, signature):
self.signature = signature
self.fp_type = ll.DoubleType()
self.module = ll.Module('mod1')
self.fn = None
self.llvm_arg_types = []
self.llvm_ret_type = self.fp_type
self.param_dict = {} # map symbol name to LLVM function argument
self.link_name = ''
def _from_ctype(self, ctype):
if ctype == ctypes.c_int:
return ll.IntType(32)
if ctype == ctypes.c_double:
return self.fp_type
if ctype == ctypes.POINTER(ctypes.c_double):
return ll.PointerType(self.fp_type)
if ctype == ctypes.c_void_p:
return ll.PointerType(ll.IntType(32))
if ctype == ctypes.py_object:
return ll.PointerType(ll.IntType(32))
print("Unhandled ctype = %s" % str(ctype))
def _create_args(self, func_args):
"""Create types for function arguments"""
self.llvm_ret_type = self._from_ctype(self.signature.ret_type)
self.llvm_arg_types = \
[self._from_ctype(a) for a in self.signature.arg_ctypes]
def _create_function_base(self):
"""Create function with name and type signature"""
global link_names, current_link_suffix
default_link_name = 'jit_func'
current_link_suffix += 1
self.link_name = default_link_name + str(current_link_suffix)
link_names.add(self.link_name)
fn_type = ll.FunctionType(self.llvm_ret_type, self.llvm_arg_types)
self.fn = ll.Function(self.module, fn_type, name=self.link_name)
def _create_param_dict(self, func_args):
"""Mapping of symbolic values to function arguments"""
for i, a in enumerate(func_args):
self.fn.args[i].name = str(a)
self.param_dict[a] = self.fn.args[i]
def _create_function(self, expr):
"""Create function body and return LLVM IR"""
bb_entry = self.fn.append_basic_block('entry')
builder = ll.IRBuilder(bb_entry)
lj = LLVMJitPrinter(self.module, builder, self.fn,
func_arg_map=self.param_dict)
ret = self._convert_expr(lj, expr)
lj.builder.ret(self._wrap_return(lj, ret))
strmod = str(self.module)
return strmod
def _wrap_return(self, lj, vals):
# Return a single double if there is one return value,
# else return a tuple of doubles.
# Don't wrap return value in this case
if self.signature.ret_type == ctypes.c_double:
return vals[0]
# Use this instead of a real PyObject*
void_ptr = ll.PointerType(ll.IntType(32))
# Create a wrapped double: PyObject* PyFloat_FromDouble(double v)
wrap_type = ll.FunctionType(void_ptr, [self.fp_type])
wrap_fn = ll.Function(lj.module, wrap_type, "PyFloat_FromDouble")
wrapped_vals = [lj.builder.call(wrap_fn, [v]) for v in vals]
if len(vals) == 1:
final_val = wrapped_vals[0]
else:
# Create a tuple: PyObject* PyTuple_Pack(Py_ssize_t n, ...)
# This should be Py_ssize_t
tuple_arg_types = [ll.IntType(32)]
tuple_arg_types.extend([void_ptr]*len(vals))
tuple_type = ll.FunctionType(void_ptr, tuple_arg_types)
tuple_fn = ll.Function(lj.module, tuple_type, "PyTuple_Pack")
tuple_args = [ll.Constant(ll.IntType(32), len(wrapped_vals))]
tuple_args.extend(wrapped_vals)
final_val = lj.builder.call(tuple_fn, tuple_args)
return final_val
def _convert_expr(self, lj, expr):
try:
# Match CSE return data structure.
if len(expr) == 2:
tmp_exprs = expr[0]
final_exprs = expr[1]
if len(final_exprs) != 1 and self.signature.ret_type == ctypes.c_double:
raise NotImplementedError("Return of multiple expressions not supported for this callback")
for name, e in tmp_exprs:
val = lj._print(e)
lj._add_tmp_var(name, val)
except TypeError:
final_exprs = [expr]
vals = [lj._print(e) for e in final_exprs]
return vals
def _compile_function(self, strmod):
global exe_engines
llmod = llvm.parse_assembly(strmod)
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 2
pass_manager = llvm.create_module_pass_manager()
pmb.populate(pass_manager)
pass_manager.run(llmod)
target_machine = \
llvm.Target.from_default_triple().create_target_machine()
exe_eng = llvm.create_mcjit_compiler(llmod, target_machine)
exe_eng.finalize_object()
exe_engines.append(exe_eng)
if False:
print("Assembly")
print(target_machine.emit_assembly(llmod))
fptr = exe_eng.get_function_address(self.link_name)
return fptr
class LLVMJitCodeCallback(LLVMJitCode):
def __init__(self, signature):
super().__init__(signature)
def _create_param_dict(self, func_args):
for i, a in enumerate(func_args):
if isinstance(a, IndexedBase):
self.param_dict[a] = (self.fn.args[i], i)
self.fn.args[i].name = str(a)
else:
self.param_dict[a] = (self.fn.args[self.signature.input_arg],
i)
def _create_function(self, expr):
"""Create function body and return LLVM IR"""
bb_entry = self.fn.append_basic_block('entry')
builder = ll.IRBuilder(bb_entry)
lj = LLVMJitCallbackPrinter(self.module, builder, self.fn,
func_arg_map=self.param_dict)
ret = self._convert_expr(lj, expr)
if self.signature.ret_arg:
output_fp_ptr = builder.bitcast(self.fn.args[self.signature.ret_arg],
ll.PointerType(self.fp_type))
for i, val in enumerate(ret):
index = ll.Constant(ll.IntType(32), i)
output_array_ptr = builder.gep(output_fp_ptr, [index])
builder.store(val, output_array_ptr)
builder.ret(ll.Constant(ll.IntType(32), 0)) # return success
else:
lj.builder.ret(self._wrap_return(lj, ret))
strmod = str(self.module)
return strmod
class CodeSignature:
def __init__(self, ret_type):
self.ret_type = ret_type
self.arg_ctypes = []
# Input argument array element index
self.input_arg = 0
# For the case output value is referenced through a parameter rather
# than the return value
self.ret_arg = None
def _llvm_jit_code(args, expr, signature, callback_type):
"""Create a native code function from a SymPy expression"""
if callback_type is None:
jit = LLVMJitCode(signature)
else:
jit = LLVMJitCodeCallback(signature)
jit._create_args(args)
jit._create_function_base()
jit._create_param_dict(args)
strmod = jit._create_function(expr)
if False:
print("LLVM IR")
print(strmod)
fptr = jit._compile_function(strmod)
return fptr
@doctest_depends_on(modules=('llvmlite', 'scipy'))
def llvm_callable(args, expr, callback_type=None):
'''Compile function from a SymPy expression
Expressions are evaluated using double precision arithmetic.
Some single argument math functions (exp, sin, cos, etc.) are supported
in expressions.
Parameters
==========
args : List of Symbol
Arguments to the generated function. Usually the free symbols in
the expression. Currently each one is assumed to convert to
a double precision scalar.
expr : Expr, or (Replacements, Expr) as returned from 'cse'
Expression to compile.
callback_type : string
Create function with signature appropriate to use as a callback.
Currently supported:
'scipy.integrate'
'scipy.integrate.test'
'cubature'
Returns
=======
Compiled function that can evaluate the expression.
Examples
========
>>> import sympy.printing.llvmjitcode as jit
>>> from sympy.abc import a
>>> e = a*a + a + 1
>>> e1 = jit.llvm_callable([a], e)
>>> e.subs(a, 1.1) # Evaluate via substitution
3.31000000000000
>>> e1(1.1) # Evaluate using JIT-compiled code
3.3100000000000005
Callbacks for integration functions can be JIT compiled.
>>> import sympy.printing.llvmjitcode as jit
>>> from sympy.abc import a
>>> from sympy import integrate
>>> from scipy.integrate import quad
>>> e = a*a
>>> e1 = jit.llvm_callable([a], e, callback_type='scipy.integrate')
>>> integrate(e, (a, 0.0, 2.0))
2.66666666666667
>>> quad(e1, 0.0, 2.0)[0]
2.66666666666667
The 'cubature' callback is for the Python wrapper around the
cubature package ( https://github.com/saullocastro/cubature )
and ( http://ab-initio.mit.edu/wiki/index.php/Cubature )
There are two signatures for the SciPy integration callbacks.
The first ('scipy.integrate') is the function to be passed to the
integration routine, and will pass the signature checks.
The second ('scipy.integrate.test') is only useful for directly calling
the function using ctypes variables. It will not pass the signature checks
for scipy.integrate.
The return value from the cse module can also be compiled. This
can improve the performance of the compiled function. If multiple
expressions are given to cse, the compiled function returns a tuple.
The 'cubature' callback handles multiple expressions (set `fdim`
to match in the integration call.)
>>> import sympy.printing.llvmjitcode as jit
>>> from sympy import cse
>>> from sympy.abc import x,y
>>> e1 = x*x + y*y
>>> e2 = 4*(x*x + y*y) + 8.0
>>> after_cse = cse([e1,e2])
>>> after_cse
([(x0, x**2), (x1, y**2)], [x0 + x1, 4*x0 + 4*x1 + 8.0])
>>> j1 = jit.llvm_callable([x,y], after_cse)
>>> j1(1.0, 2.0)
(5.0, 28.0)
'''
if not llvmlite:
raise ImportError("llvmlite is required for llvmjitcode")
signature = CodeSignature(ctypes.py_object)
arg_ctypes = []
if callback_type is None:
for _ in args:
arg_ctype = ctypes.c_double
arg_ctypes.append(arg_ctype)
elif callback_type in ('scipy.integrate', 'scipy.integrate.test'):
signature.ret_type = ctypes.c_double
arg_ctypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
arg_ctypes_formal = [ctypes.c_int, ctypes.c_double]
signature.input_arg = 1
elif callback_type == 'cubature':
arg_ctypes = [ctypes.c_int,
ctypes.POINTER(ctypes.c_double),
ctypes.c_void_p,
ctypes.c_int,
ctypes.POINTER(ctypes.c_double)
]
signature.ret_type = ctypes.c_int
signature.input_arg = 1
signature.ret_arg = 4
else:
raise ValueError("Unknown callback type: %s" % callback_type)
signature.arg_ctypes = arg_ctypes
fptr = _llvm_jit_code(args, expr, signature, callback_type)
if callback_type and callback_type == 'scipy.integrate':
arg_ctypes = arg_ctypes_formal
# PYFUNCTYPE holds the GIL which is needed to prevent a segfault when
# calling PyFloat_FromDouble on Python 3.10. Probably it is better to use
# ctypes.c_double when returning a float rather than using ctypes.py_object
# and returning a PyFloat from inside the jitted function (i.e. let ctypes
# handle the conversion from double to PyFloat).
if signature.ret_type == ctypes.py_object:
FUNCTYPE = ctypes.PYFUNCTYPE
else:
FUNCTYPE = ctypes.CFUNCTYPE
cfunc = FUNCTYPE(signature.ret_type, *arg_ctypes)(fptr)
return cfunc

View File

@ -0,0 +1,314 @@
"""
Maple code printer
The MapleCodePrinter converts single SymPy expressions into single
Maple expressions, using the functions defined in the Maple objects where possible.
FIXME: This module is still under actively developed. Some functions may be not completed.
"""
from sympy.core import S
from sympy.core.numbers import Integer, IntegerConstant, equal_valued
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence, PRECEDENCE
import sympy
_known_func_same_name = (
'sin', 'cos', 'tan', 'sec', 'csc', 'cot', 'sinh', 'cosh', 'tanh', 'sech',
'csch', 'coth', 'exp', 'floor', 'factorial', 'bernoulli', 'euler',
'fibonacci', 'gcd', 'lcm', 'conjugate', 'Ci', 'Chi', 'Ei', 'Li', 'Si', 'Shi',
'erf', 'erfc', 'harmonic', 'LambertW',
'sqrt', # For automatic rewrites
)
known_functions = {
# SymPy -> Maple
'Abs': 'abs',
'log': 'ln',
'asin': 'arcsin',
'acos': 'arccos',
'atan': 'arctan',
'asec': 'arcsec',
'acsc': 'arccsc',
'acot': 'arccot',
'asinh': 'arcsinh',
'acosh': 'arccosh',
'atanh': 'arctanh',
'asech': 'arcsech',
'acsch': 'arccsch',
'acoth': 'arccoth',
'ceiling': 'ceil',
'Max' : 'max',
'Min' : 'min',
'factorial2': 'doublefactorial',
'RisingFactorial': 'pochhammer',
'besseli': 'BesselI',
'besselj': 'BesselJ',
'besselk': 'BesselK',
'bessely': 'BesselY',
'hankelh1': 'HankelH1',
'hankelh2': 'HankelH2',
'airyai': 'AiryAi',
'airybi': 'AiryBi',
'appellf1': 'AppellF1',
'fresnelc': 'FresnelC',
'fresnels': 'FresnelS',
'lerchphi' : 'LerchPhi',
}
for _func in _known_func_same_name:
known_functions[_func] = _func
number_symbols = {
# SymPy -> Maple
S.Pi: 'Pi',
S.Exp1: 'exp(1)',
S.Catalan: 'Catalan',
S.EulerGamma: 'gamma',
S.GoldenRatio: '(1/2 + (1/2)*sqrt(5))'
}
spec_relational_ops = {
# SymPy -> Maple
'==': '=',
'!=': '<>'
}
not_supported_symbol = [
S.ComplexInfinity
]
class MapleCodePrinter(CodePrinter):
"""
Printer which converts a SymPy expression into a maple code.
"""
printmethod = "_maple"
language = "maple"
_operators = {
'and': 'and',
'or': 'or',
'not': 'not ',
}
_default_settings = dict(CodePrinter._default_settings, **{
'inline': True,
'allow_unknown_functions': True,
})
def __init__(self, settings=None):
if settings is None:
settings = {}
super().__init__(settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
def _get_statement(self, codestring):
return "%s;" % codestring
def _get_comment(self, text):
return "# {}".format(text)
def _declare_number_const(self, name, value):
return "{} := {};".format(name,
value.evalf(self._settings['precision']))
def _format_code(self, lines):
return lines
def _print_tuple(self, expr):
return self._print(list(expr))
def _print_Tuple(self, expr):
return self._print(list(expr))
def _print_Assignment(self, expr):
lhs = self._print(expr.lhs)
rhs = self._print(expr.rhs)
return "{lhs} := {rhs}".format(lhs=lhs, rhs=rhs)
def _print_Pow(self, expr, **kwargs):
PREC = precedence(expr)
if equal_valued(expr.exp, -1):
return '1/%s' % (self.parenthesize(expr.base, PREC))
elif equal_valued(expr.exp, 0.5):
return 'sqrt(%s)' % self._print(expr.base)
elif equal_valued(expr.exp, -0.5):
return '1/sqrt(%s)' % self._print(expr.base)
else:
return '{base}^{exp}'.format(
base=self.parenthesize(expr.base, PREC),
exp=self.parenthesize(expr.exp, PREC))
def _print_Piecewise(self, expr):
if (expr.args[-1].cond is not True) and (expr.args[-1].cond != S.BooleanTrue):
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
_coup_list = [
("{c}, {e}".format(c=self._print(c),
e=self._print(e)) if c is not True and c is not S.BooleanTrue else "{e}".format(
e=self._print(e)))
for e, c in expr.args]
_inbrace = ', '.join(_coup_list)
return 'piecewise({_inbrace})'.format(_inbrace=_inbrace)
def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)
return "{p}/{q}".format(p=str(p), q=str(q))
def _print_Relational(self, expr):
PREC=precedence(expr)
lhs_code = self.parenthesize(expr.lhs, PREC)
rhs_code = self.parenthesize(expr.rhs, PREC)
op = expr.rel_op
if op in spec_relational_ops:
op = spec_relational_ops[op]
return "{lhs} {rel_op} {rhs}".format(lhs=lhs_code, rel_op=op, rhs=rhs_code)
def _print_NumberSymbol(self, expr):
return number_symbols[expr]
def _print_NegativeInfinity(self, expr):
return '-infinity'
def _print_Infinity(self, expr):
return 'infinity'
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
def _print_bool(self, expr):
return 'true' if expr else 'false'
def _print_NaN(self, expr):
return 'undefined'
def _get_matrix(self, expr, sparse=False):
if S.Zero in expr.shape:
_strM = 'Matrix([], storage = {storage})'.format(
storage='sparse' if sparse else 'rectangular')
else:
_strM = 'Matrix({list}, storage = {storage})'.format(
list=self._print(expr.tolist()),
storage='sparse' if sparse else 'rectangular')
return _strM
def _print_MatrixElement(self, expr):
return "{parent}[{i_maple}, {j_maple}]".format(
parent=self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True),
i_maple=self._print(expr.i + 1),
j_maple=self._print(expr.j + 1))
def _print_MatrixBase(self, expr):
return self._get_matrix(expr, sparse=False)
def _print_SparseRepMatrix(self, expr):
return self._get_matrix(expr, sparse=True)
def _print_Identity(self, expr):
if isinstance(expr.rows, (Integer, IntegerConstant)):
return self._print(sympy.SparseMatrix(expr))
else:
return "Matrix({var_size}, shape = identity)".format(var_size=self._print(expr.rows))
def _print_MatMul(self, expr):
PREC=precedence(expr)
_fact_list = list(expr.args)
_const = None
if not isinstance(_fact_list[0], (sympy.MatrixBase, sympy.MatrixExpr,
sympy.MatrixSlice, sympy.MatrixSymbol)):
_const, _fact_list = _fact_list[0], _fact_list[1:]
if _const is None or _const == 1:
return '.'.join(self.parenthesize(_m, PREC) for _m in _fact_list)
else:
return '{c}*{m}'.format(c=_const, m='.'.join(self.parenthesize(_m, PREC) for _m in _fact_list))
def _print_MatPow(self, expr):
# This function requires LinearAlgebra Function in Maple
return 'MatrixPower({A}, {n})'.format(A=self._print(expr.base), n=self._print(expr.exp))
def _print_HadamardProduct(self, expr):
PREC = precedence(expr)
_fact_list = list(expr.args)
return '*'.join(self.parenthesize(_m, PREC) for _m in _fact_list)
def _print_Derivative(self, expr):
_f, (_var, _order) = expr.args
if _order != 1:
_second_arg = '{var}${order}'.format(var=self._print(_var),
order=self._print(_order))
else:
_second_arg = '{var}'.format(var=self._print(_var))
return 'diff({func_expr}, {sec_arg})'.format(func_expr=self._print(_f), sec_arg=_second_arg)
def maple_code(expr, assign_to=None, **settings):
r"""Converts ``expr`` to a string of Maple code.
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This can be helpful for
expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=16].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, cfunction_string)]. See
below for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
inline: bool, optional
If True, we try to create single-statement code instead of multiple
statements. [default=True].
"""
return MapleCodePrinter(settings).doprint(expr, assign_to)
def print_maple_code(expr, **settings):
"""Prints the Maple representation of the given expression.
See :func:`maple_code` for the meaning of the optional arguments.
Examples
========
>>> from sympy import print_maple_code, symbols
>>> x, y = symbols('x y')
>>> print_maple_code(x, assign_to=y)
y := x
"""
print(maple_code(expr, **settings))

View File

@ -0,0 +1,350 @@
"""
Mathematica code printer
"""
from __future__ import annotations
from typing import Any
from sympy.core import Basic, Expr, Float
from sympy.core.sorting import default_sort_key
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence
# Used in MCodePrinter._print_Function(self)
known_functions = {
"exp": [(lambda x: True, "Exp")],
"log": [(lambda x: True, "Log")],
"sin": [(lambda x: True, "Sin")],
"cos": [(lambda x: True, "Cos")],
"tan": [(lambda x: True, "Tan")],
"cot": [(lambda x: True, "Cot")],
"sec": [(lambda x: True, "Sec")],
"csc": [(lambda x: True, "Csc")],
"asin": [(lambda x: True, "ArcSin")],
"acos": [(lambda x: True, "ArcCos")],
"atan": [(lambda x: True, "ArcTan")],
"acot": [(lambda x: True, "ArcCot")],
"asec": [(lambda x: True, "ArcSec")],
"acsc": [(lambda x: True, "ArcCsc")],
"atan2": [(lambda *x: True, "ArcTan")],
"sinh": [(lambda x: True, "Sinh")],
"cosh": [(lambda x: True, "Cosh")],
"tanh": [(lambda x: True, "Tanh")],
"coth": [(lambda x: True, "Coth")],
"sech": [(lambda x: True, "Sech")],
"csch": [(lambda x: True, "Csch")],
"asinh": [(lambda x: True, "ArcSinh")],
"acosh": [(lambda x: True, "ArcCosh")],
"atanh": [(lambda x: True, "ArcTanh")],
"acoth": [(lambda x: True, "ArcCoth")],
"asech": [(lambda x: True, "ArcSech")],
"acsch": [(lambda x: True, "ArcCsch")],
"sinc": [(lambda x: True, "Sinc")],
"conjugate": [(lambda x: True, "Conjugate")],
"Max": [(lambda *x: True, "Max")],
"Min": [(lambda *x: True, "Min")],
"erf": [(lambda x: True, "Erf")],
"erf2": [(lambda *x: True, "Erf")],
"erfc": [(lambda x: True, "Erfc")],
"erfi": [(lambda x: True, "Erfi")],
"erfinv": [(lambda x: True, "InverseErf")],
"erfcinv": [(lambda x: True, "InverseErfc")],
"erf2inv": [(lambda *x: True, "InverseErf")],
"expint": [(lambda *x: True, "ExpIntegralE")],
"Ei": [(lambda x: True, "ExpIntegralEi")],
"fresnelc": [(lambda x: True, "FresnelC")],
"fresnels": [(lambda x: True, "FresnelS")],
"gamma": [(lambda x: True, "Gamma")],
"uppergamma": [(lambda *x: True, "Gamma")],
"polygamma": [(lambda *x: True, "PolyGamma")],
"loggamma": [(lambda x: True, "LogGamma")],
"beta": [(lambda *x: True, "Beta")],
"Ci": [(lambda x: True, "CosIntegral")],
"Si": [(lambda x: True, "SinIntegral")],
"Chi": [(lambda x: True, "CoshIntegral")],
"Shi": [(lambda x: True, "SinhIntegral")],
"li": [(lambda x: True, "LogIntegral")],
"factorial": [(lambda x: True, "Factorial")],
"factorial2": [(lambda x: True, "Factorial2")],
"subfactorial": [(lambda x: True, "Subfactorial")],
"catalan": [(lambda x: True, "CatalanNumber")],
"harmonic": [(lambda *x: True, "HarmonicNumber")],
"lucas": [(lambda x: True, "LucasL")],
"RisingFactorial": [(lambda *x: True, "Pochhammer")],
"FallingFactorial": [(lambda *x: True, "FactorialPower")],
"laguerre": [(lambda *x: True, "LaguerreL")],
"assoc_laguerre": [(lambda *x: True, "LaguerreL")],
"hermite": [(lambda *x: True, "HermiteH")],
"jacobi": [(lambda *x: True, "JacobiP")],
"gegenbauer": [(lambda *x: True, "GegenbauerC")],
"chebyshevt": [(lambda *x: True, "ChebyshevT")],
"chebyshevu": [(lambda *x: True, "ChebyshevU")],
"legendre": [(lambda *x: True, "LegendreP")],
"assoc_legendre": [(lambda *x: True, "LegendreP")],
"mathieuc": [(lambda *x: True, "MathieuC")],
"mathieus": [(lambda *x: True, "MathieuS")],
"mathieucprime": [(lambda *x: True, "MathieuCPrime")],
"mathieusprime": [(lambda *x: True, "MathieuSPrime")],
"stieltjes": [(lambda x: True, "StieltjesGamma")],
"elliptic_e": [(lambda *x: True, "EllipticE")],
"elliptic_f": [(lambda *x: True, "EllipticE")],
"elliptic_k": [(lambda x: True, "EllipticK")],
"elliptic_pi": [(lambda *x: True, "EllipticPi")],
"zeta": [(lambda *x: True, "Zeta")],
"dirichlet_eta": [(lambda x: True, "DirichletEta")],
"riemann_xi": [(lambda x: True, "RiemannXi")],
"besseli": [(lambda *x: True, "BesselI")],
"besselj": [(lambda *x: True, "BesselJ")],
"besselk": [(lambda *x: True, "BesselK")],
"bessely": [(lambda *x: True, "BesselY")],
"hankel1": [(lambda *x: True, "HankelH1")],
"hankel2": [(lambda *x: True, "HankelH2")],
"airyai": [(lambda x: True, "AiryAi")],
"airybi": [(lambda x: True, "AiryBi")],
"airyaiprime": [(lambda x: True, "AiryAiPrime")],
"airybiprime": [(lambda x: True, "AiryBiPrime")],
"polylog": [(lambda *x: True, "PolyLog")],
"lerchphi": [(lambda *x: True, "LerchPhi")],
"gcd": [(lambda *x: True, "GCD")],
"lcm": [(lambda *x: True, "LCM")],
"jn": [(lambda *x: True, "SphericalBesselJ")],
"yn": [(lambda *x: True, "SphericalBesselY")],
"hyper": [(lambda *x: True, "HypergeometricPFQ")],
"meijerg": [(lambda *x: True, "MeijerG")],
"appellf1": [(lambda *x: True, "AppellF1")],
"DiracDelta": [(lambda x: True, "DiracDelta")],
"Heaviside": [(lambda x: True, "HeavisideTheta")],
"KroneckerDelta": [(lambda *x: True, "KroneckerDelta")],
"sqrt": [(lambda x: True, "Sqrt")], # For automatic rewrites
}
class MCodePrinter(CodePrinter):
"""A printer to convert Python expressions to
strings of the Wolfram's Mathematica code
"""
printmethod = "_mcode"
language = "Wolfram Language"
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 15,
'user_functions': {},
})
_number_symbols: set[tuple[Expr, Float]] = set()
_not_supported: set[Basic] = set()
def __init__(self, settings={}):
"""Register function mappings supplied by user"""
CodePrinter.__init__(self, settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {}).copy()
for k, v in userfuncs.items():
if not isinstance(v, list):
userfuncs[k] = [(lambda *x: True, v)]
self.known_functions.update(userfuncs)
def _format_code(self, lines):
return lines
def _print_Pow(self, expr):
PREC = precedence(expr)
return '%s^%s' % (self.parenthesize(expr.base, PREC),
self.parenthesize(expr.exp, PREC))
def _print_Mul(self, expr):
PREC = precedence(expr)
c, nc = expr.args_cnc()
res = super()._print_Mul(expr.func(*c))
if nc:
res += '*'
res += '**'.join(self.parenthesize(a, PREC) for a in nc)
return res
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
# Primitive numbers
def _print_Zero(self, expr):
return '0'
def _print_One(self, expr):
return '1'
def _print_NegativeOne(self, expr):
return '-1'
def _print_Half(self, expr):
return '1/2'
def _print_ImaginaryUnit(self, expr):
return 'I'
# Infinity and invalid numbers
def _print_Infinity(self, expr):
return 'Infinity'
def _print_NegativeInfinity(self, expr):
return '-Infinity'
def _print_ComplexInfinity(self, expr):
return 'ComplexInfinity'
def _print_NaN(self, expr):
return 'Indeterminate'
# Mathematical constants
def _print_Exp1(self, expr):
return 'E'
def _print_Pi(self, expr):
return 'Pi'
def _print_GoldenRatio(self, expr):
return 'GoldenRatio'
def _print_TribonacciConstant(self, expr):
expanded = expr.expand(func=True)
PREC = precedence(expr)
return self.parenthesize(expanded, PREC)
def _print_EulerGamma(self, expr):
return 'EulerGamma'
def _print_Catalan(self, expr):
return 'Catalan'
def _print_list(self, expr):
return '{' + ', '.join(self.doprint(a) for a in expr) + '}'
_print_tuple = _print_list
_print_Tuple = _print_list
def _print_ImmutableDenseMatrix(self, expr):
return self.doprint(expr.tolist())
def _print_ImmutableSparseMatrix(self, expr):
def print_rule(pos, val):
return '{} -> {}'.format(
self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val))
def print_data():
items = sorted(expr.todok().items(), key=default_sort_key)
return '{' + \
', '.join(print_rule(k, v) for k, v in items) + \
'}'
def print_dims():
return self.doprint(expr.shape)
return 'SparseArray[{}, {}]'.format(print_data(), print_dims())
def _print_ImmutableDenseNDimArray(self, expr):
return self.doprint(expr.tolist())
def _print_ImmutableSparseNDimArray(self, expr):
def print_string_list(string_list):
return '{' + ', '.join(a for a in string_list) + '}'
def to_mathematica_index(*args):
"""Helper function to change Python style indexing to
Pathematica indexing.
Python indexing (0, 1 ... n-1)
-> Mathematica indexing (1, 2 ... n)
"""
return tuple(i + 1 for i in args)
def print_rule(pos, val):
"""Helper function to print a rule of Mathematica"""
return '{} -> {}'.format(self.doprint(pos), self.doprint(val))
def print_data():
"""Helper function to print data part of Mathematica
sparse array.
It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
from
https://reference.wolfram.com/language/ref/SparseArray.html
``data`` must be formatted with rule.
"""
return print_string_list(
[print_rule(
to_mathematica_index(*(expr._get_tuple_index(key))),
value)
for key, value in sorted(expr._sparse_array.items())]
)
def print_dims():
"""Helper function to print dimensions part of Mathematica
sparse array.
It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
from
https://reference.wolfram.com/language/ref/SparseArray.html
"""
return self.doprint(expr.shape)
return 'SparseArray[{}, {}]'.format(print_data(), print_dims())
def _print_Function(self, expr):
if expr.func.__name__ in self.known_functions:
cond_mfunc = self.known_functions[expr.func.__name__]
for cond, mfunc in cond_mfunc:
if cond(*expr.args):
return "%s[%s]" % (mfunc, self.stringify(expr.args, ", "))
elif expr.func.__name__ in self._rewriteable_functions:
# Simple rewrite to supported function possible
target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
return self._print(expr.rewrite(target_f))
return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ")
_print_MinMaxBase = _print_Function
def _print_LambertW(self, expr):
if len(expr.args) == 1:
return "ProductLog[{}]".format(self._print(expr.args[0]))
return "ProductLog[{}, {}]".format(
self._print(expr.args[1]), self._print(expr.args[0]))
def _print_Integral(self, expr):
if len(expr.variables) == 1 and not expr.limits[0][1:]:
args = [expr.args[0], expr.variables[0]]
else:
args = expr.args
return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]"
def _print_Sum(self, expr):
return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]"
def _print_Derivative(self, expr):
dexpr = expr.expr
dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count]
return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]"
def _get_comment(self, text):
return "(* {} *)".format(text)
def mathematica_code(expr, **settings):
r"""Converts an expr to a string of the Wolfram Mathematica code
Examples
========
>>> from sympy import mathematica_code as mcode, symbols, sin
>>> x = symbols('x')
>>> mcode(sin(x).series(x).removeO())
'(1/120)*x^5 - 1/6*x^3 + x'
"""
return MCodePrinter(settings).doprint(expr)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,510 @@
from sympy.core import S
from sympy.core.function import Lambda
from sympy.core.power import Pow
from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter
from .codeprinter import CodePrinter
_not_in_numpy = 'erf erfc factorial gamma loggamma'.split()
_in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy]
_known_functions_numpy = dict(_in_numpy, **{
'acos': 'arccos',
'acosh': 'arccosh',
'asin': 'arcsin',
'asinh': 'arcsinh',
'atan': 'arctan',
'atan2': 'arctan2',
'atanh': 'arctanh',
'exp2': 'exp2',
'sign': 'sign',
'logaddexp': 'logaddexp',
'logaddexp2': 'logaddexp2',
'isnan': 'isnan'
})
_known_constants_numpy = {
'Exp1': 'e',
'Pi': 'pi',
'EulerGamma': 'euler_gamma',
'NaN': 'nan',
'Infinity': 'inf',
}
_numpy_known_functions = {k: 'numpy.' + v for k, v in _known_functions_numpy.items()}
_numpy_known_constants = {k: 'numpy.' + v for k, v in _known_constants_numpy.items()}
class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
"""
Numpy printer which handles vectorized piecewise functions,
logical operators, etc.
"""
_module = 'numpy'
_kf = _numpy_known_functions
_kc = _numpy_known_constants
def __init__(self, settings=None):
"""
`settings` is passed to CodePrinter.__init__()
`module` specifies the array module to use, currently 'NumPy', 'CuPy'
or 'JAX'.
"""
self.language = "Python with {}".format(self._module)
self.printmethod = "_{}code".format(self._module)
self._kf = {**PythonCodePrinter._kf, **self._kf}
super().__init__(settings=settings)
def _print_seq(self, seq):
"General sequence printer: converts to tuple"
# Print tuples here instead of lists because numba supports
# tuples in nopython mode.
delimiter=', '
return '({},)'.format(delimiter.join(self._print(item) for item in seq))
def _print_NegativeInfinity(self, expr):
return '-' + self._print(S.Infinity)
def _print_MatMul(self, expr):
"Matrix multiplication printer"
if expr.as_coeff_matrices()[0] is not S.One:
expr_list = expr.as_coeff_matrices()[1]+[(expr.as_coeff_matrices()[0])]
return '({})'.format(').dot('.join(self._print(i) for i in expr_list))
return '({})'.format(').dot('.join(self._print(i) for i in expr.args))
def _print_MatPow(self, expr):
"Matrix power printer"
return '{}({}, {})'.format(self._module_format(self._module + '.linalg.matrix_power'),
self._print(expr.args[0]), self._print(expr.args[1]))
def _print_Inverse(self, expr):
"Matrix inverse printer"
return '{}({})'.format(self._module_format(self._module + '.linalg.inv'),
self._print(expr.args[0]))
def _print_DotProduct(self, expr):
# DotProduct allows any shape order, but numpy.dot does matrix
# multiplication, so we have to make sure it gets 1 x n by n x 1.
arg1, arg2 = expr.args
if arg1.shape[0] != 1:
arg1 = arg1.T
if arg2.shape[1] != 1:
arg2 = arg2.T
return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
self._print(arg1),
self._print(arg2))
def _print_MatrixSolve(self, expr):
return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
self._print(expr.matrix),
self._print(expr.vector))
def _print_ZeroMatrix(self, expr):
return '{}({})'.format(self._module_format(self._module + '.zeros'),
self._print(expr.shape))
def _print_OneMatrix(self, expr):
return '{}({})'.format(self._module_format(self._module + '.ones'),
self._print(expr.shape))
def _print_FunctionMatrix(self, expr):
from sympy.abc import i, j
lamda = expr.lamda
if not isinstance(lamda, Lambda):
lamda = Lambda((i, j), lamda(i, j))
return '{}(lambda {}: {}, {})'.format(self._module_format(self._module + '.fromfunction'),
', '.join(self._print(arg) for arg in lamda.args[0]),
self._print(lamda.args[1]), self._print(expr.shape))
def _print_HadamardProduct(self, expr):
func = self._module_format(self._module + '.multiply')
return ''.join('{}({}, '.format(func, self._print(arg)) \
for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
')' * (len(expr.args) - 1))
def _print_KroneckerProduct(self, expr):
func = self._module_format(self._module + '.kron')
return ''.join('{}({}, '.format(func, self._print(arg)) \
for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
')' * (len(expr.args) - 1))
def _print_Adjoint(self, expr):
return '{}({}({}))'.format(
self._module_format(self._module + '.conjugate'),
self._module_format(self._module + '.transpose'),
self._print(expr.args[0]))
def _print_DiagonalOf(self, expr):
vect = '{}({})'.format(
self._module_format(self._module + '.diag'),
self._print(expr.arg))
return '{}({}, (-1, 1))'.format(
self._module_format(self._module + '.reshape'), vect)
def _print_DiagMatrix(self, expr):
return '{}({})'.format(self._module_format(self._module + '.diagflat'),
self._print(expr.args[0]))
def _print_DiagonalMatrix(self, expr):
return '{}({}, {}({}, {}))'.format(self._module_format(self._module + '.multiply'),
self._print(expr.arg), self._module_format(self._module + '.eye'),
self._print(expr.shape[0]), self._print(expr.shape[1]))
def _print_Piecewise(self, expr):
"Piecewise function printer"
from sympy.logic.boolalg import ITE, simplify_logic
def print_cond(cond):
""" Problem having an ITE in the cond. """
if cond.has(ITE):
return self._print(simplify_logic(cond))
else:
return self._print(cond)
exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
# If [default_value, True] is a (expr, cond) sequence in a Piecewise object
# it will behave the same as passing the 'default' kwarg to select()
# *as long as* it is the last element in expr.args.
# If this is not the case, it may be triggered prematurely.
return '{}({}, {}, default={})'.format(
self._module_format(self._module + '.select'), conds, exprs,
self._print(S.NaN))
def _print_Relational(self, expr):
"Relational printer for Equality and Unequality"
op = {
'==' :'equal',
'!=' :'not_equal',
'<' :'less',
'<=' :'less_equal',
'>' :'greater',
'>=' :'greater_equal',
}
if expr.rel_op in op:
lhs = self._print(expr.lhs)
rhs = self._print(expr.rhs)
return '{op}({lhs}, {rhs})'.format(op=self._module_format(self._module + '.'+op[expr.rel_op]),
lhs=lhs, rhs=rhs)
return super()._print_Relational(expr)
def _print_And(self, expr):
"Logical And printer"
# We have to override LambdaPrinter because it uses Python 'and' keyword.
# If LambdaPrinter didn't define it, we could use StrPrinter's
# version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_and'), ','.join(self._print(i) for i in expr.args))
def _print_Or(self, expr):
"Logical Or printer"
# We have to override LambdaPrinter because it uses Python 'or' keyword.
# If LambdaPrinter didn't define it, we could use StrPrinter's
# version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_or'), ','.join(self._print(i) for i in expr.args))
def _print_Not(self, expr):
"Logical Not printer"
# We have to override LambdaPrinter because it uses Python 'not' keyword.
# If LambdaPrinter didn't define it, we would still have to define our
# own because StrPrinter doesn't define it.
return '{}({})'.format(self._module_format(self._module + '.logical_not'), ','.join(self._print(i) for i in expr.args))
def _print_Pow(self, expr, rational=False):
# XXX Workaround for negative integer power error
if expr.exp.is_integer and expr.exp.is_negative:
expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')
def _print_Min(self, expr):
return '{}({}.asarray([{}]), axis=0)'.format(self._module_format(self._module + '.amin'), self._module_format(self._module), ','.join(self._print(i) for i in expr.args))
def _print_Max(self, expr):
return '{}({}.asarray([{}]), axis=0)'.format(self._module_format(self._module + '.amax'), self._module_format(self._module), ','.join(self._print(i) for i in expr.args))
def _print_arg(self, expr):
return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0]))
def _print_im(self, expr):
return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0]))
def _print_Mod(self, expr):
return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
(self._print(arg) for arg in expr.args)))
def _print_re(self, expr):
return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0]))
def _print_sinc(self, expr):
return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi))
def _print_MatrixBase(self, expr):
func = self.known_functions.get(expr.__class__.__name__, None)
if func is None:
func = self._module_format(self._module + '.array')
return "%s(%s)" % (func, self._print(expr.tolist()))
def _print_Identity(self, expr):
shape = expr.shape
if all(dim.is_Integer for dim in shape):
return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0]))
else:
raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")
def _print_BlockMatrix(self, expr):
return '{}({})'.format(self._module_format(self._module + '.block'),
self._print(expr.args[0].tolist()))
def _print_NDimArray(self, expr):
if len(expr.shape) == 1:
return self._module + '.array(' + self._print(expr.args[0]) + ')'
if len(expr.shape) == 2:
return self._print(expr.tomatrix())
# Should be possible to extend to more dimensions
return super()._print_not_supported(self, expr)
_add = "add"
_einsum = "einsum"
_transpose = "transpose"
_ones = "ones"
_zeros = "zeros"
_print_lowergamma = CodePrinter._print_not_supported
_print_uppergamma = CodePrinter._print_not_supported
_print_fresnelc = CodePrinter._print_not_supported
_print_fresnels = CodePrinter._print_not_supported
for func in _numpy_known_functions:
setattr(NumPyPrinter, f'_print_{func}', _print_known_func)
for const in _numpy_known_constants:
setattr(NumPyPrinter, f'_print_{const}', _print_known_const)
_known_functions_scipy_special = {
'Ei': 'expi',
'erf': 'erf',
'erfc': 'erfc',
'besselj': 'jv',
'bessely': 'yv',
'besseli': 'iv',
'besselk': 'kv',
'cosm1': 'cosm1',
'powm1': 'powm1',
'factorial': 'factorial',
'gamma': 'gamma',
'loggamma': 'gammaln',
'digamma': 'psi',
'polygamma': 'polygamma',
'RisingFactorial': 'poch',
'jacobi': 'eval_jacobi',
'gegenbauer': 'eval_gegenbauer',
'chebyshevt': 'eval_chebyt',
'chebyshevu': 'eval_chebyu',
'legendre': 'eval_legendre',
'hermite': 'eval_hermite',
'laguerre': 'eval_laguerre',
'assoc_laguerre': 'eval_genlaguerre',
'beta': 'beta',
'LambertW' : 'lambertw',
}
_known_constants_scipy_constants = {
'GoldenRatio': 'golden_ratio',
'Pi': 'pi',
}
_scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()}
_scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()}
class SciPyPrinter(NumPyPrinter):
_kf = {**NumPyPrinter._kf, **_scipy_known_functions}
_kc = {**NumPyPrinter._kc, **_scipy_known_constants}
def __init__(self, settings=None):
super().__init__(settings=settings)
self.language = "Python with SciPy and NumPy"
def _print_SparseRepMatrix(self, expr):
i, j, data = [], [], []
for (r, c), v in expr.todok().items():
i.append(r)
j.append(c)
data.append(v)
return "{name}(({data}, ({i}, {j})), shape={shape})".format(
name=self._module_format('scipy.sparse.coo_matrix'),
data=data, i=i, j=j, shape=expr.shape
)
_print_ImmutableSparseMatrix = _print_SparseRepMatrix
# SciPy's lpmv has a different order of arguments from assoc_legendre
def _print_assoc_legendre(self, expr):
return "{0}({2}, {1}, {3})".format(
self._module_format('scipy.special.lpmv'),
self._print(expr.args[0]),
self._print(expr.args[1]),
self._print(expr.args[2]))
def _print_lowergamma(self, expr):
return "{0}({2})*{1}({2}, {3})".format(
self._module_format('scipy.special.gamma'),
self._module_format('scipy.special.gammainc'),
self._print(expr.args[0]),
self._print(expr.args[1]))
def _print_uppergamma(self, expr):
return "{0}({2})*{1}({2}, {3})".format(
self._module_format('scipy.special.gamma'),
self._module_format('scipy.special.gammaincc'),
self._print(expr.args[0]),
self._print(expr.args[1]))
def _print_betainc(self, expr):
betainc = self._module_format('scipy.special.betainc')
beta = self._module_format('scipy.special.beta')
args = [self._print(arg) for arg in expr.args]
return f"({betainc}({args[0]}, {args[1]}, {args[3]}) - {betainc}({args[0]}, {args[1]}, {args[2]})) \
* {beta}({args[0]}, {args[1]})"
def _print_betainc_regularized(self, expr):
return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
self._module_format('scipy.special.betainc'),
self._print(expr.args[0]),
self._print(expr.args[1]),
self._print(expr.args[2]),
self._print(expr.args[3]))
def _print_fresnels(self, expr):
return "{}({})[0]".format(
self._module_format("scipy.special.fresnel"),
self._print(expr.args[0]))
def _print_fresnelc(self, expr):
return "{}({})[1]".format(
self._module_format("scipy.special.fresnel"),
self._print(expr.args[0]))
def _print_airyai(self, expr):
return "{}({})[0]".format(
self._module_format("scipy.special.airy"),
self._print(expr.args[0]))
def _print_airyaiprime(self, expr):
return "{}({})[1]".format(
self._module_format("scipy.special.airy"),
self._print(expr.args[0]))
def _print_airybi(self, expr):
return "{}({})[2]".format(
self._module_format("scipy.special.airy"),
self._print(expr.args[0]))
def _print_airybiprime(self, expr):
return "{}({})[3]".format(
self._module_format("scipy.special.airy"),
self._print(expr.args[0]))
def _print_bernoulli(self, expr):
# scipy's bernoulli is inconsistent with SymPy's so rewrite
return self._print(expr._eval_rewrite_as_zeta(*expr.args))
def _print_harmonic(self, expr):
return self._print(expr._eval_rewrite_as_zeta(*expr.args))
def _print_Integral(self, e):
integration_vars, limits = _unpack_integral_limits(e)
if len(limits) == 1:
# nicer (but not necessary) to prefer quad over nquad for 1D case
module_str = self._module_format("scipy.integrate.quad")
limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
else:
module_str = self._module_format("scipy.integrate.nquad")
limit_str = "({})".format(", ".join(
"(%s, %s)" % tuple(map(self._print, l)) for l in limits))
return "{}(lambda {}: {}, {})[0]".format(
module_str,
", ".join(map(self._print, integration_vars)),
self._print(e.args[0]),
limit_str)
def _print_Si(self, expr):
return "{}({})[0]".format(
self._module_format("scipy.special.sici"),
self._print(expr.args[0]))
def _print_Ci(self, expr):
return "{}({})[1]".format(
self._module_format("scipy.special.sici"),
self._print(expr.args[0]))
for func in _scipy_known_functions:
setattr(SciPyPrinter, f'_print_{func}', _print_known_func)
for const in _scipy_known_constants:
setattr(SciPyPrinter, f'_print_{const}', _print_known_const)
_cupy_known_functions = {k : "cupy." + v for k, v in _known_functions_numpy.items()}
_cupy_known_constants = {k : "cupy." + v for k, v in _known_constants_numpy.items()}
class CuPyPrinter(NumPyPrinter):
"""
CuPy printer which handles vectorized piecewise functions,
logical operators, etc.
"""
_module = 'cupy'
_kf = _cupy_known_functions
_kc = _cupy_known_constants
def __init__(self, settings=None):
super().__init__(settings=settings)
for func in _cupy_known_functions:
setattr(CuPyPrinter, f'_print_{func}', _print_known_func)
for const in _cupy_known_constants:
setattr(CuPyPrinter, f'_print_{const}', _print_known_const)
_jax_known_functions = {k: 'jax.numpy.' + v for k, v in _known_functions_numpy.items()}
_jax_known_constants = {k: 'jax.numpy.' + v for k, v in _known_constants_numpy.items()}
class JaxPrinter(NumPyPrinter):
"""
JAX printer which handles vectorized piecewise functions,
logical operators, etc.
"""
_module = "jax.numpy"
_kf = _jax_known_functions
_kc = _jax_known_constants
def __init__(self, settings=None):
super().__init__(settings=settings)
self.printmethod = '_jaxcode'
# These need specific override to allow for the lack of "jax.numpy.reduce"
def _print_And(self, expr):
"Logical And printer"
return "{}({}.asarray([{}]), axis=0)".format(
self._module_format(self._module + ".all"),
self._module_format(self._module),
",".join(self._print(i) for i in expr.args),
)
def _print_Or(self, expr):
"Logical Or printer"
return "{}({}.asarray([{}]), axis=0)".format(
self._module_format(self._module + ".any"),
self._module_format(self._module),
",".join(self._print(i) for i in expr.args),
)
for func in _jax_known_functions:
setattr(JaxPrinter, f'_print_{func}', _print_known_func)
for const in _jax_known_constants:
setattr(JaxPrinter, f'_print_{const}', _print_known_const)

View File

@ -0,0 +1,715 @@
"""
Octave (and Matlab) code printer
The `OctaveCodePrinter` converts SymPy expressions into Octave expressions.
It uses a subset of the Octave language for Matlab compatibility.
A complete code generator, which uses `octave_code` extensively, can be found
in `sympy.utilities.codegen`. The `codegen` module can be used to generate
complete source code files.
"""
from __future__ import annotations
from typing import Any
from sympy.core import Mul, Pow, S, Rational
from sympy.core.mul import _keep_coeff
from sympy.core.numbers import equal_valued
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence, PRECEDENCE
from re import search
# List of known functions. First, those that have the same name in
# SymPy and Octave. This is almost certainly incomplete!
known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc",
"asin", "acos", "acot", "atan", "atan2", "asec", "acsc",
"sinh", "cosh", "tanh", "coth", "csch", "sech",
"asinh", "acosh", "atanh", "acoth", "asech", "acsch",
"erfc", "erfi", "erf", "erfinv", "erfcinv",
"besseli", "besselj", "besselk", "bessely",
"bernoulli", "beta", "euler", "exp", "factorial", "floor",
"fresnelc", "fresnels", "gamma", "harmonic", "log",
"polylog", "sign", "zeta", "legendre"]
# These functions have different names ("SymPy": "Octave"), more
# generally a mapping to (argument_conditions, octave_function).
known_fcns_src2 = {
"Abs": "abs",
"arg": "angle", # arg/angle ok in Octave but only angle in Matlab
"binomial": "bincoeff",
"ceiling": "ceil",
"chebyshevu": "chebyshevU",
"chebyshevt": "chebyshevT",
"Chi": "coshint",
"Ci": "cosint",
"conjugate": "conj",
"DiracDelta": "dirac",
"Heaviside": "heaviside",
"im": "imag",
"laguerre": "laguerreL",
"LambertW": "lambertw",
"li": "logint",
"loggamma": "gammaln",
"Max": "max",
"Min": "min",
"Mod": "mod",
"polygamma": "psi",
"re": "real",
"RisingFactorial": "pochhammer",
"Shi": "sinhint",
"Si": "sinint",
}
class OctaveCodePrinter(CodePrinter):
"""
A printer to convert expressions to strings of Octave/Matlab code.
"""
printmethod = "_octave"
language = "Octave"
_operators = {
'and': '&',
'or': '|',
'not': '~',
}
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 17,
'user_functions': {},
'contract': True,
'inline': True,
})
# Note: contract is for expressing tensors as loops (if True), or just
# assignment (if False). FIXME: this should be looked a more carefully
# for Octave.
def __init__(self, settings={}):
super().__init__(settings)
self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1))
self.known_functions.update(dict(known_fcns_src2))
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
return "%s;" % codestring
def _get_comment(self, text):
return "% {}".format(text)
def _declare_number_const(self, name, value):
return "{} = {};".format(name, value)
def _format_code(self, lines):
return self.indent_code(lines)
def _traverse_matrix_indices(self, mat):
# Octave uses Fortran order (column-major)
rows, cols = mat.shape
return ((i, j) for j in range(cols) for i in range(rows))
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
for i in indices:
# Octave arrays start at 1 and end at dimension
var, start, stop = map(self._print,
[i.label, i.lower + 1, i.upper + 1])
open_lines.append("for %s = %s:%s" % (var, start, stop))
close_lines.append("end")
return open_lines, close_lines
def _print_Mul(self, expr):
# print complex numbers nicely in Octave
if (expr.is_number and expr.is_imaginary and
(S.ImaginaryUnit*expr).is_Integer):
return "%si" % self._print(-S.ImaginaryUnit*expr)
# cribbed from str.py
prec = precedence(expr)
c, e = expr.as_coeff_Mul()
if c < 0:
expr = _keep_coeff(-c, e)
sign = "-"
else:
sign = ""
a = [] # items in the numerator
b = [] # items that are in the denominator (if any)
pow_paren = [] # Will collect all pow with more than one base element and exp = -1
if self.order not in ('old', 'none'):
args = expr.as_ordered_factors()
else:
# use make_args in case expr was something like -x -> x
args = Mul.make_args(expr)
# Gather args for numerator/denominator
for item in args:
if (item.is_commutative and item.is_Pow and item.exp.is_Rational
and item.exp.is_negative):
if item.exp != -1:
b.append(Pow(item.base, -item.exp, evaluate=False))
else:
if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160
pow_paren.append(item)
b.append(Pow(item.base, -item.exp))
elif item.is_Rational and item is not S.Infinity:
if item.p != 1:
a.append(Rational(item.p))
if item.q != 1:
b.append(Rational(item.q))
else:
a.append(item)
a = a or [S.One]
a_str = [self.parenthesize(x, prec) for x in a]
b_str = [self.parenthesize(x, prec) for x in b]
# To parenthesize Pow with exp = -1 and having more than one Symbol
for item in pow_paren:
if item.base in b:
b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]
# from here it differs from str.py to deal with "*" and ".*"
def multjoin(a, a_str):
# here we probably are assuming the constants will come first
r = a_str[0]
for i in range(1, len(a)):
mulsym = '*' if a[i-1].is_number else '.*'
r = r + mulsym + a_str[i]
return r
if not b:
return sign + multjoin(a, a_str)
elif len(b) == 1:
divsym = '/' if b[0].is_number else './'
return sign + multjoin(a, a_str) + divsym + b_str[0]
else:
divsym = '/' if all(bi.is_number for bi in b) else './'
return (sign + multjoin(a, a_str) +
divsym + "(%s)" % multjoin(b, b_str))
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_Pow(self, expr):
powsymbol = '^' if all(x.is_number for x in expr.args) else '.^'
PREC = precedence(expr)
if equal_valued(expr.exp, 0.5):
return "sqrt(%s)" % self._print(expr.base)
if expr.is_commutative:
if equal_valued(expr.exp, -0.5):
sym = '/' if expr.base.is_number else './'
return "1" + sym + "sqrt(%s)" % self._print(expr.base)
if equal_valued(expr.exp, -1):
sym = '/' if expr.base.is_number else './'
return "1" + sym + "%s" % self.parenthesize(expr.base, PREC)
return '%s%s%s' % (self.parenthesize(expr.base, PREC), powsymbol,
self.parenthesize(expr.exp, PREC))
def _print_MatPow(self, expr):
PREC = precedence(expr)
return '%s^%s' % (self.parenthesize(expr.base, PREC),
self.parenthesize(expr.exp, PREC))
def _print_MatrixSolve(self, expr):
PREC = precedence(expr)
return "%s \\ %s" % (self.parenthesize(expr.matrix, PREC),
self.parenthesize(expr.vector, PREC))
def _print_Pi(self, expr):
return 'pi'
def _print_ImaginaryUnit(self, expr):
return "1i"
def _print_Exp1(self, expr):
return "exp(1)"
def _print_GoldenRatio(self, expr):
# FIXME: how to do better, e.g., for octave_code(2*GoldenRatio)?
#return self._print((1+sqrt(S(5)))/2)
return "(1+sqrt(5))/2"
def _print_Assignment(self, expr):
from sympy.codegen.ast import Assignment
from sympy.functions.elementary.piecewise import Piecewise
from sympy.tensor.indexed import IndexedBase
# Copied from codeprinter, but remove special MatrixSymbol treatment
lhs = expr.lhs
rhs = expr.rhs
# We special case assignments that take multiple lines
if not self._settings["inline"] and isinstance(expr.rhs, Piecewise):
# Here we modify Piecewise so each expression is now
# an Assignment, and then continue on the print.
expressions = []
conditions = []
for (e, c) in rhs.args:
expressions.append(Assignment(lhs, e))
conditions.append(c)
temp = Piecewise(*zip(expressions, conditions))
return self._print(temp)
if self._settings["contract"] and (lhs.has(IndexedBase) or
rhs.has(IndexedBase)):
# Here we check if there is looping to be done, and if so
# print the required loops.
return self._doprint_loops(rhs, lhs)
else:
lhs_code = self._print(lhs)
rhs_code = self._print(rhs)
return self._get_statement("%s = %s" % (lhs_code, rhs_code))
def _print_Infinity(self, expr):
return 'inf'
def _print_NegativeInfinity(self, expr):
return '-inf'
def _print_NaN(self, expr):
return 'NaN'
def _print_list(self, expr):
return '{' + ', '.join(self._print(a) for a in expr) + '}'
_print_tuple = _print_list
_print_Tuple = _print_list
_print_List = _print_list
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
def _print_bool(self, expr):
return str(expr).lower()
# Could generate quadrature code for definite Integrals?
#_print_Integral = _print_not_supported
def _print_MatrixBase(self, A):
# Handle zero dimensions:
if (A.rows, A.cols) == (0, 0):
return '[]'
elif S.Zero in A.shape:
return 'zeros(%s, %s)' % (A.rows, A.cols)
elif (A.rows, A.cols) == (1, 1):
# Octave does not distinguish between scalars and 1x1 matrices
return self._print(A[0, 0])
return "[%s]" % "; ".join(" ".join([self._print(a) for a in A[r, :]])
for r in range(A.rows))
def _print_SparseRepMatrix(self, A):
from sympy.matrices import Matrix
L = A.col_list();
# make row vectors of the indices and entries
I = Matrix([[k[0] + 1 for k in L]])
J = Matrix([[k[1] + 1 for k in L]])
AIJ = Matrix([[k[2] for k in L]])
return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J),
self._print(AIJ), A.rows, A.cols)
def _print_MatrixElement(self, expr):
return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \
+ '(%s, %s)' % (expr.i + 1, expr.j + 1)
def _print_MatrixSlice(self, expr):
def strslice(x, lim):
l = x[0] + 1
h = x[1]
step = x[2]
lstr = self._print(l)
hstr = 'end' if h == lim else self._print(h)
if step == 1:
if l == 1 and h == lim:
return ':'
if l == h:
return lstr
else:
return lstr + ':' + hstr
else:
return ':'.join((lstr, self._print(step), hstr))
return (self._print(expr.parent) + '(' +
strslice(expr.rowslice, expr.parent.shape[0]) + ', ' +
strslice(expr.colslice, expr.parent.shape[1]) + ')')
def _print_Indexed(self, expr):
inds = [ self._print(i) for i in expr.indices ]
return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds))
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_KroneckerDelta(self, expr):
prec = PRECEDENCE["Pow"]
return "double(%s == %s)" % tuple(self.parenthesize(x, prec)
for x in expr.args)
def _print_HadamardProduct(self, expr):
return '.*'.join([self.parenthesize(arg, precedence(expr))
for arg in expr.args])
def _print_HadamardPower(self, expr):
PREC = precedence(expr)
return '.**'.join([
self.parenthesize(expr.base, PREC),
self.parenthesize(expr.exp, PREC)
])
def _print_Identity(self, expr):
shape = expr.shape
if len(shape) == 2 and shape[0] == shape[1]:
shape = [shape[0]]
s = ", ".join(self._print(n) for n in shape)
return "eye(" + s + ")"
def _print_lowergamma(self, expr):
# Octave implements regularized incomplete gamma function
return "(gammainc({1}, {0}).*gamma({0}))".format(
self._print(expr.args[0]), self._print(expr.args[1]))
def _print_uppergamma(self, expr):
return "(gammainc({1}, {0}, 'upper').*gamma({0}))".format(
self._print(expr.args[0]), self._print(expr.args[1]))
def _print_sinc(self, expr):
#Note: Divide by pi because Octave implements normalized sinc function.
return "sinc(%s)" % self._print(expr.args[0]/S.Pi)
def _print_hankel1(self, expr):
return "besselh(%s, 1, %s)" % (self._print(expr.order),
self._print(expr.argument))
def _print_hankel2(self, expr):
return "besselh(%s, 2, %s)" % (self._print(expr.order),
self._print(expr.argument))
# Note: as of 2015, Octave doesn't have spherical Bessel functions
def _print_jn(self, expr):
from sympy.functions import sqrt, besselj
x = expr.argument
expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x)
return self._print(expr2)
def _print_yn(self, expr):
from sympy.functions import sqrt, bessely
x = expr.argument
expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x)
return self._print(expr2)
def _print_airyai(self, expr):
return "airy(0, %s)" % self._print(expr.args[0])
def _print_airyaiprime(self, expr):
return "airy(1, %s)" % self._print(expr.args[0])
def _print_airybi(self, expr):
return "airy(2, %s)" % self._print(expr.args[0])
def _print_airybiprime(self, expr):
return "airy(3, %s)" % self._print(expr.args[0])
def _print_expint(self, expr):
mu, x = expr.args
if mu != 1:
return self._print_not_supported(expr)
return "expint(%s)" % self._print(x)
def _one_or_two_reversed_args(self, expr):
assert len(expr.args) <= 2
return '{name}({args})'.format(
name=self.known_functions[expr.__class__.__name__],
args=", ".join([self._print(x) for x in reversed(expr.args)])
)
_print_DiracDelta = _print_LambertW = _one_or_two_reversed_args
def _nested_binary_math_func(self, expr):
return '{name}({arg1}, {arg2})'.format(
name=self.known_functions[expr.__class__.__name__],
arg1=self._print(expr.args[0]),
arg2=self._print(expr.func(*expr.args[1:]))
)
_print_Max = _print_Min = _nested_binary_math_func
def _print_Piecewise(self, expr):
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
if self._settings["inline"]:
# Express each (cond, expr) pair in a nested Horner form:
# (condition) .* (expr) + (not cond) .* (<others>)
# Expressions that result in multiple statements won't work here.
ecpairs = ["({0}).*({1}) + (~({0})).*(".format
(self._print(c), self._print(e))
for e, c in expr.args[:-1]]
elast = "%s" % self._print(expr.args[-1].expr)
pw = " ...\n".join(ecpairs) + elast + ")"*len(ecpairs)
# Note: current need these outer brackets for 2*pw. Would be
# nicer to teach parenthesize() to do this for us when needed!
return "(" + pw + ")"
else:
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s)" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines.append("else")
else:
lines.append("elseif (%s)" % self._print(c))
code0 = self._print(e)
lines.append(code0)
if i == len(expr.args) - 1:
lines.append("end")
return "\n".join(lines)
def _print_zeta(self, expr):
if len(expr.args) == 1:
return "zeta(%s)" % self._print(expr.args[0])
else:
# Matlab two argument zeta is not equivalent to SymPy's
return self._print_not_supported(expr)
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
# code mostly copied from ccode
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ')
dec_regex = ('^end$', '^elseif ', '^else$')
# pre-strip left-space from the code
code = [ line.lstrip(' \t') for line in code ]
increase = [ int(any(search(re, line) for re in inc_regex))
for line in code ]
decrease = [ int(any(search(re, line) for re in dec_regex))
for line in code ]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def octave_code(expr, assign_to=None, **settings):
r"""Converts `expr` to a string of Octave (or Matlab) code.
The string uses a subset of the Octave language for Matlab compatibility.
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This can be helpful for
expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=16].
user_functions : dict, optional
A dictionary where keys are ``FunctionClass`` instances and values are
their string representations. Alternatively, the dictionary value can
be a list of tuples i.e. [(argument_test, cfunction_string)]. See
below for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
inline: bool, optional
If True, we try to create single-statement code instead of multiple
statements. [default=True].
Examples
========
>>> from sympy import octave_code, symbols, sin, pi
>>> x = symbols('x')
>>> octave_code(sin(x).series(x).removeO())
'x.^5/120 - x.^3/6 + x'
>>> from sympy import Rational, ceiling
>>> x, y, tau = symbols("x, y, tau")
>>> octave_code((2*tau)**Rational(7, 2))
'8*sqrt(2)*tau.^(7/2)'
Note that element-wise (Hadamard) operations are used by default between
symbols. This is because its very common in Octave to write "vectorized"
code. It is harmless if the values are scalars.
>>> octave_code(sin(pi*x*y), assign_to="s")
's = sin(pi*x.*y);'
If you need a matrix product "*" or matrix power "^", you can specify the
symbol as a ``MatrixSymbol``.
>>> from sympy import Symbol, MatrixSymbol
>>> n = Symbol('n', integer=True, positive=True)
>>> A = MatrixSymbol('A', n, n)
>>> octave_code(3*pi*A**3)
'(3*pi)*A^3'
This class uses several rules to decide which symbol to use a product.
Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*".
A HadamardProduct can be used to specify componentwise multiplication ".*"
of two MatrixSymbols. There is currently there is no easy way to specify
scalar symbols, so sometimes the code might have some minor cosmetic
issues. For example, suppose x and y are scalars and A is a Matrix, then
while a human programmer might write "(x^2*y)*A^3", we generate:
>>> octave_code(x**2*y*A**3)
'(x.^2.*y)*A^3'
Matrices are supported using Octave inline notation. When using
``assign_to`` with matrices, the name can be specified either as a string
or as a ``MatrixSymbol``. The dimensions must align in the latter case.
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([[x**2, sin(x), ceiling(x)]])
>>> octave_code(mat, assign_to='A')
'A = [x.^2 sin(x) ceil(x)];'
``Piecewise`` expressions are implemented with logical masking by default.
Alternatively, you can pass "inline=False" to use if-else conditionals.
Note that if the ``Piecewise`` lacks a default term, represented by
``(expr, True)`` then an error will be thrown. This is to prevent
generating an expression that may not evaluate to anything.
>>> from sympy import Piecewise
>>> pw = Piecewise((x + 1, x > 0), (x, True))
>>> octave_code(pw, assign_to=tau)
'tau = ((x > 0).*(x + 1) + (~(x > 0)).*(x));'
Note that any expression that can be generated normally can also exist
inside a Matrix:
>>> mat = Matrix([[x**2, pw, sin(x)]])
>>> octave_code(mat, assign_to='A')
'A = [x.^2 ((x > 0).*(x + 1) + (~(x > 0)).*(x)) sin(x)];'
Custom printing can be defined for certain types by passing a dictionary of
"type" : "function" to the ``user_functions`` kwarg. Alternatively, the
dictionary value can be a list of tuples i.e., [(argument_test,
cfunction_string)]. This can be used to call a custom Octave function.
>>> from sympy import Function
>>> f = Function('f')
>>> g = Function('g')
>>> custom_functions = {
... "f": "existing_octave_fcn",
... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"),
... (lambda x: not x.is_Matrix, "my_fcn")]
... }
>>> mat = Matrix([[1, x]])
>>> octave_code(f(x) + g(x) + g(mat), user_functions=custom_functions)
'existing_octave_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])'
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> octave_code(e.rhs, assign_to=e.lhs, contract=False)
'Dy(i) = (y(i + 1) - y(i))./(t(i + 1) - t(i));'
"""
return OctaveCodePrinter(settings).doprint(expr, assign_to)
def print_octave_code(expr, **settings):
"""Prints the Octave (or Matlab) representation of the given expression.
See `octave_code` for the meaning of the optional arguments.
"""
print(octave_code(expr, **settings))

View File

@ -0,0 +1,174 @@
"""A module providing information about the necessity of brackets"""
# Default precedence values for some basic types
PRECEDENCE = {
"Lambda": 1,
"Xor": 10,
"Or": 20,
"And": 30,
"Relational": 35,
"Add": 40,
"Mul": 50,
"Pow": 60,
"Func": 70,
"Not": 100,
"Atom": 1000,
"BitwiseOr": 36,
"BitwiseXor": 37,
"BitwiseAnd": 38
}
# A dictionary assigning precedence values to certain classes. These values are
# treated like they were inherited, so not every single class has to be named
# here.
# Do not use this with printers other than StrPrinter
PRECEDENCE_VALUES = {
"Equivalent": PRECEDENCE["Xor"],
"Xor": PRECEDENCE["Xor"],
"Implies": PRECEDENCE["Xor"],
"Or": PRECEDENCE["Or"],
"And": PRECEDENCE["And"],
"Add": PRECEDENCE["Add"],
"Pow": PRECEDENCE["Pow"],
"Relational": PRECEDENCE["Relational"],
"Sub": PRECEDENCE["Add"],
"Not": PRECEDENCE["Not"],
"Function" : PRECEDENCE["Func"],
"NegativeInfinity": PRECEDENCE["Add"],
"MatAdd": PRECEDENCE["Add"],
"MatPow": PRECEDENCE["Pow"],
"MatrixSolve": PRECEDENCE["Mul"],
"Mod": PRECEDENCE["Mul"],
"TensAdd": PRECEDENCE["Add"],
# As soon as `TensMul` is a subclass of `Mul`, remove this:
"TensMul": PRECEDENCE["Mul"],
"HadamardProduct": PRECEDENCE["Mul"],
"HadamardPower": PRECEDENCE["Pow"],
"KroneckerProduct": PRECEDENCE["Mul"],
"Equality": PRECEDENCE["Mul"],
"Unequality": PRECEDENCE["Mul"],
}
# Sometimes it's not enough to assign a fixed precedence value to a
# class. Then a function can be inserted in this dictionary that takes
# an instance of this class as argument and returns the appropriate
# precedence value.
# Precedence functions
def precedence_Mul(item):
if item.could_extract_minus_sign():
return PRECEDENCE["Add"]
return PRECEDENCE["Mul"]
def precedence_Rational(item):
if item.p < 0:
return PRECEDENCE["Add"]
return PRECEDENCE["Mul"]
def precedence_Integer(item):
if item.p < 0:
return PRECEDENCE["Add"]
return PRECEDENCE["Atom"]
def precedence_Float(item):
if item < 0:
return PRECEDENCE["Add"]
return PRECEDENCE["Atom"]
def precedence_PolyElement(item):
if item.is_generator:
return PRECEDENCE["Atom"]
elif item.is_ground:
return precedence(item.coeff(1))
elif item.is_term:
return PRECEDENCE["Mul"]
else:
return PRECEDENCE["Add"]
def precedence_FracElement(item):
if item.denom == 1:
return precedence_PolyElement(item.numer)
else:
return PRECEDENCE["Mul"]
def precedence_UnevaluatedExpr(item):
return precedence(item.args[0]) - 0.5
PRECEDENCE_FUNCTIONS = {
"Integer": precedence_Integer,
"Mul": precedence_Mul,
"Rational": precedence_Rational,
"Float": precedence_Float,
"PolyElement": precedence_PolyElement,
"FracElement": precedence_FracElement,
"UnevaluatedExpr": precedence_UnevaluatedExpr,
}
def precedence(item):
"""Returns the precedence of a given object.
This is the precedence for StrPrinter.
"""
if hasattr(item, "precedence"):
return item.precedence
if not isinstance(item, type):
for i in type(item).mro():
n = i.__name__
if n in PRECEDENCE_FUNCTIONS:
return PRECEDENCE_FUNCTIONS[n](item)
elif n in PRECEDENCE_VALUES:
return PRECEDENCE_VALUES[n]
return PRECEDENCE["Atom"]
PRECEDENCE_TRADITIONAL = PRECEDENCE.copy()
PRECEDENCE_TRADITIONAL['Integral'] = PRECEDENCE["Mul"]
PRECEDENCE_TRADITIONAL['Sum'] = PRECEDENCE["Mul"]
PRECEDENCE_TRADITIONAL['Product'] = PRECEDENCE["Mul"]
PRECEDENCE_TRADITIONAL['Limit'] = PRECEDENCE["Mul"]
PRECEDENCE_TRADITIONAL['Derivative'] = PRECEDENCE["Mul"]
PRECEDENCE_TRADITIONAL['TensorProduct'] = PRECEDENCE["Mul"]
PRECEDENCE_TRADITIONAL['Transpose'] = PRECEDENCE["Pow"]
PRECEDENCE_TRADITIONAL['Adjoint'] = PRECEDENCE["Pow"]
PRECEDENCE_TRADITIONAL['Dot'] = PRECEDENCE["Mul"] - 1
PRECEDENCE_TRADITIONAL['Cross'] = PRECEDENCE["Mul"] - 1
PRECEDENCE_TRADITIONAL['Gradient'] = PRECEDENCE["Mul"] - 1
PRECEDENCE_TRADITIONAL['Divergence'] = PRECEDENCE["Mul"] - 1
PRECEDENCE_TRADITIONAL['Curl'] = PRECEDENCE["Mul"] - 1
PRECEDENCE_TRADITIONAL['Laplacian'] = PRECEDENCE["Mul"] - 1
PRECEDENCE_TRADITIONAL['Union'] = PRECEDENCE['Xor']
PRECEDENCE_TRADITIONAL['Intersection'] = PRECEDENCE['Xor']
PRECEDENCE_TRADITIONAL['Complement'] = PRECEDENCE['Xor']
PRECEDENCE_TRADITIONAL['SymmetricDifference'] = PRECEDENCE['Xor']
PRECEDENCE_TRADITIONAL['ProductSet'] = PRECEDENCE['Xor']
def precedence_traditional(item):
"""Returns the precedence of a given object according to the
traditional rules of mathematics.
This is the precedence for the LaTeX and pretty printer.
"""
# Integral, Sum, Product, Limit have the precedence of Mul in LaTeX,
# the precedence of Atom for other printers:
from sympy.core.expr import UnevaluatedExpr
if isinstance(item, UnevaluatedExpr):
return precedence_traditional(item.args[0])
n = item.__class__.__name__
if n in PRECEDENCE_TRADITIONAL:
return PRECEDENCE_TRADITIONAL[n]
return precedence(item)

View File

@ -0,0 +1,12 @@
"""ASCII-ART 2D pretty-printer"""
from .pretty import (pretty, pretty_print, pprint, pprint_use_unicode,
pprint_try_use_unicode, pager_print)
# if unicode output is available -- let's use it
pprint_try_use_unicode()
__all__ = [
'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode',
'pprint_try_use_unicode', 'pager_print',
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,732 @@
"""Symbolic primitives + unicode/ASCII abstraction for pretty.py"""
import sys
import warnings
from string import ascii_lowercase, ascii_uppercase
import unicodedata
unicode_warnings = ''
def U(name):
"""
Get a unicode character by name or, None if not found.
This exists because older versions of Python use older unicode databases.
"""
try:
return unicodedata.lookup(name)
except KeyError:
global unicode_warnings
unicode_warnings += 'No \'%s\' in unicodedata\n' % name
return None
from sympy.printing.conventions import split_super_sub
from sympy.core.alphabets import greeks
from sympy.utilities.exceptions import sympy_deprecation_warning
# prefix conventions when constructing tables
# L - LATIN i
# G - GREEK beta
# D - DIGIT 0
# S - SYMBOL +
__all__ = ['greek_unicode', 'sub', 'sup', 'xsym', 'vobj', 'hobj', 'pretty_symbol',
'annotated', 'center_pad', 'center']
_use_unicode = False
def pretty_use_unicode(flag=None):
"""Set whether pretty-printer should use unicode by default"""
global _use_unicode
global unicode_warnings
if flag is None:
return _use_unicode
if flag and unicode_warnings:
# print warnings (if any) on first unicode usage
warnings.warn(unicode_warnings)
unicode_warnings = ''
use_unicode_prev = _use_unicode
_use_unicode = flag
return use_unicode_prev
def pretty_try_use_unicode():
"""See if unicode output is available and leverage it if possible"""
encoding = getattr(sys.stdout, 'encoding', None)
# this happens when e.g. stdout is redirected through a pipe, or is
# e.g. a cStringIO.StringO
if encoding is None:
return # sys.stdout has no encoding
symbols = []
# see if we can represent greek alphabet
symbols += greek_unicode.values()
# and atoms
symbols += atoms_table.values()
for s in symbols:
if s is None:
return # common symbols not present!
try:
s.encode(encoding)
except UnicodeEncodeError:
return
# all the characters were present and encodable
pretty_use_unicode(True)
def xstr(*args):
sympy_deprecation_warning(
"""
The sympy.printing.pretty.pretty_symbology.xstr() function is
deprecated. Use str() instead.
""",
deprecated_since_version="1.7",
active_deprecations_target="deprecated-pretty-printing-functions"
)
return str(*args)
# GREEK
g = lambda l: U('GREEK SMALL LETTER %s' % l.upper())
G = lambda l: U('GREEK CAPITAL LETTER %s' % l.upper())
greek_letters = list(greeks) # make a copy
# deal with Unicode's funny spelling of lambda
greek_letters[greek_letters.index('lambda')] = 'lamda'
# {} greek letter -> (g,G)
greek_unicode = {L: g(L) for L in greek_letters}
greek_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_letters)
# aliases
greek_unicode['lambda'] = greek_unicode['lamda']
greek_unicode['Lambda'] = greek_unicode['Lamda']
greek_unicode['varsigma'] = '\N{GREEK SMALL LETTER FINAL SIGMA}'
# BOLD
b = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper())
B = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper())
bold_unicode = {l: b(l) for l in ascii_lowercase}
bold_unicode.update((L, B(L)) for L in ascii_uppercase)
# GREEK BOLD
gb = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper())
GB = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper())
greek_bold_letters = list(greeks) # make a copy, not strictly required here
# deal with Unicode's funny spelling of lambda
greek_bold_letters[greek_bold_letters.index('lambda')] = 'lamda'
# {} greek letter -> (g,G)
greek_bold_unicode = {L: g(L) for L in greek_bold_letters}
greek_bold_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_bold_letters)
greek_bold_unicode['lambda'] = greek_unicode['lamda']
greek_bold_unicode['Lambda'] = greek_unicode['Lamda']
greek_bold_unicode['varsigma'] = '\N{MATHEMATICAL BOLD SMALL FINAL SIGMA}'
digit_2txt = {
'0': 'ZERO',
'1': 'ONE',
'2': 'TWO',
'3': 'THREE',
'4': 'FOUR',
'5': 'FIVE',
'6': 'SIX',
'7': 'SEVEN',
'8': 'EIGHT',
'9': 'NINE',
}
symb_2txt = {
'+': 'PLUS SIGN',
'-': 'MINUS',
'=': 'EQUALS SIGN',
'(': 'LEFT PARENTHESIS',
')': 'RIGHT PARENTHESIS',
'[': 'LEFT SQUARE BRACKET',
']': 'RIGHT SQUARE BRACKET',
'{': 'LEFT CURLY BRACKET',
'}': 'RIGHT CURLY BRACKET',
# non-std
'{}': 'CURLY BRACKET',
'sum': 'SUMMATION',
'int': 'INTEGRAL',
}
# SUBSCRIPT & SUPERSCRIPT
LSUB = lambda letter: U('LATIN SUBSCRIPT SMALL LETTER %s' % letter.upper())
GSUB = lambda letter: U('GREEK SUBSCRIPT SMALL LETTER %s' % letter.upper())
DSUB = lambda digit: U('SUBSCRIPT %s' % digit_2txt[digit])
SSUB = lambda symb: U('SUBSCRIPT %s' % symb_2txt[symb])
LSUP = lambda letter: U('SUPERSCRIPT LATIN SMALL LETTER %s' % letter.upper())
DSUP = lambda digit: U('SUPERSCRIPT %s' % digit_2txt[digit])
SSUP = lambda symb: U('SUPERSCRIPT %s' % symb_2txt[symb])
sub = {} # symb -> subscript symbol
sup = {} # symb -> superscript symbol
# latin subscripts
for l in 'aeioruvxhklmnpst':
sub[l] = LSUB(l)
for l in 'in':
sup[l] = LSUP(l)
for gl in ['beta', 'gamma', 'rho', 'phi', 'chi']:
sub[gl] = GSUB(gl)
for d in [str(i) for i in range(10)]:
sub[d] = DSUB(d)
sup[d] = DSUP(d)
for s in '+-=()':
sub[s] = SSUB(s)
sup[s] = SSUP(s)
# Variable modifiers
# TODO: Make brackets adjust to height of contents
modifier_dict = {
# Accents
'mathring': lambda s: center_accent(s, '\N{COMBINING RING ABOVE}'),
'ddddot': lambda s: center_accent(s, '\N{COMBINING FOUR DOTS ABOVE}'),
'dddot': lambda s: center_accent(s, '\N{COMBINING THREE DOTS ABOVE}'),
'ddot': lambda s: center_accent(s, '\N{COMBINING DIAERESIS}'),
'dot': lambda s: center_accent(s, '\N{COMBINING DOT ABOVE}'),
'check': lambda s: center_accent(s, '\N{COMBINING CARON}'),
'breve': lambda s: center_accent(s, '\N{COMBINING BREVE}'),
'acute': lambda s: center_accent(s, '\N{COMBINING ACUTE ACCENT}'),
'grave': lambda s: center_accent(s, '\N{COMBINING GRAVE ACCENT}'),
'tilde': lambda s: center_accent(s, '\N{COMBINING TILDE}'),
'hat': lambda s: center_accent(s, '\N{COMBINING CIRCUMFLEX ACCENT}'),
'bar': lambda s: center_accent(s, '\N{COMBINING OVERLINE}'),
'vec': lambda s: center_accent(s, '\N{COMBINING RIGHT ARROW ABOVE}'),
'prime': lambda s: s+'\N{PRIME}',
'prm': lambda s: s+'\N{PRIME}',
# # Faces -- these are here for some compatibility with latex printing
# 'bold': lambda s: s,
# 'bm': lambda s: s,
# 'cal': lambda s: s,
# 'scr': lambda s: s,
# 'frak': lambda s: s,
# Brackets
'norm': lambda s: '\N{DOUBLE VERTICAL LINE}'+s+'\N{DOUBLE VERTICAL LINE}',
'avg': lambda s: '\N{MATHEMATICAL LEFT ANGLE BRACKET}'+s+'\N{MATHEMATICAL RIGHT ANGLE BRACKET}',
'abs': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}',
'mag': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}',
}
# VERTICAL OBJECTS
HUP = lambda symb: U('%s UPPER HOOK' % symb_2txt[symb])
CUP = lambda symb: U('%s UPPER CORNER' % symb_2txt[symb])
MID = lambda symb: U('%s MIDDLE PIECE' % symb_2txt[symb])
EXT = lambda symb: U('%s EXTENSION' % symb_2txt[symb])
HLO = lambda symb: U('%s LOWER HOOK' % symb_2txt[symb])
CLO = lambda symb: U('%s LOWER CORNER' % symb_2txt[symb])
TOP = lambda symb: U('%s TOP' % symb_2txt[symb])
BOT = lambda symb: U('%s BOTTOM' % symb_2txt[symb])
# {} '(' -> (extension, start, end, middle) 1-character
_xobj_unicode = {
# vertical symbols
# (( ext, top, bot, mid ), c1)
'(': (( EXT('('), HUP('('), HLO('(') ), '('),
')': (( EXT(')'), HUP(')'), HLO(')') ), ')'),
'[': (( EXT('['), CUP('['), CLO('[') ), '['),
']': (( EXT(']'), CUP(']'), CLO(']') ), ']'),
'{': (( EXT('{}'), HUP('{'), HLO('{'), MID('{') ), '{'),
'}': (( EXT('{}'), HUP('}'), HLO('}'), MID('}') ), '}'),
'|': U('BOX DRAWINGS LIGHT VERTICAL'),
'Tee': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'),
'UpTack': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL'),
'corner_up_centre'
'(_ext': U('LEFT PARENTHESIS EXTENSION'),
')_ext': U('RIGHT PARENTHESIS EXTENSION'),
'(_lower_hook': U('LEFT PARENTHESIS LOWER HOOK'),
')_lower_hook': U('RIGHT PARENTHESIS LOWER HOOK'),
'(_upper_hook': U('LEFT PARENTHESIS UPPER HOOK'),
')_upper_hook': U('RIGHT PARENTHESIS UPPER HOOK'),
'<': ((U('BOX DRAWINGS LIGHT VERTICAL'),
U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'),
U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT')), '<'),
'>': ((U('BOX DRAWINGS LIGHT VERTICAL'),
U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'),
U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), '>'),
'lfloor': (( EXT('['), EXT('['), CLO('[') ), U('LEFT FLOOR')),
'rfloor': (( EXT(']'), EXT(']'), CLO(']') ), U('RIGHT FLOOR')),
'lceil': (( EXT('['), CUP('['), EXT('[') ), U('LEFT CEILING')),
'rceil': (( EXT(']'), CUP(']'), EXT(']') ), U('RIGHT CEILING')),
'int': (( EXT('int'), U('TOP HALF INTEGRAL'), U('BOTTOM HALF INTEGRAL') ), U('INTEGRAL')),
'sum': (( U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), '_', U('OVERLINE'), U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), U('N-ARY SUMMATION')),
# horizontal objects
#'-': '-',
'-': U('BOX DRAWINGS LIGHT HORIZONTAL'),
'_': U('LOW LINE'),
# We used to use this, but LOW LINE looks better for roots, as it's a
# little lower (i.e., it lines up with the / perfectly. But perhaps this
# one would still be wanted for some cases?
# '_': U('HORIZONTAL SCAN LINE-9'),
# diagonal objects '\' & '/' ?
'/': U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'),
'\\': U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'),
}
_xobj_ascii = {
# vertical symbols
# (( ext, top, bot, mid ), c1)
'(': (( '|', '/', '\\' ), '('),
')': (( '|', '\\', '/' ), ')'),
# XXX this looks ugly
# '[': (( '|', '-', '-' ), '['),
# ']': (( '|', '-', '-' ), ']'),
# XXX not so ugly :(
'[': (( '[', '[', '[' ), '['),
']': (( ']', ']', ']' ), ']'),
'{': (( '|', '/', '\\', '<' ), '{'),
'}': (( '|', '\\', '/', '>' ), '}'),
'|': '|',
'<': (( '|', '/', '\\' ), '<'),
'>': (( '|', '\\', '/' ), '>'),
'int': ( ' | ', ' /', '/ ' ),
# horizontal objects
'-': '-',
'_': '_',
# diagonal objects '\' & '/' ?
'/': '/',
'\\': '\\',
}
def xobj(symb, length):
"""Construct spatial object of given length.
return: [] of equal-length strings
"""
if length <= 0:
raise ValueError("Length should be greater than 0")
# TODO robustify when no unicodedat available
if _use_unicode:
_xobj = _xobj_unicode
else:
_xobj = _xobj_ascii
vinfo = _xobj[symb]
c1 = top = bot = mid = None
if not isinstance(vinfo, tuple): # 1 entry
ext = vinfo
else:
if isinstance(vinfo[0], tuple): # (vlong), c1
vlong = vinfo[0]
c1 = vinfo[1]
else: # (vlong), c1
vlong = vinfo
ext = vlong[0]
try:
top = vlong[1]
bot = vlong[2]
mid = vlong[3]
except IndexError:
pass
if c1 is None:
c1 = ext
if top is None:
top = ext
if bot is None:
bot = ext
if mid is not None:
if (length % 2) == 0:
# even height, but we have to print it somehow anyway...
# XXX is it ok?
length += 1
else:
mid = ext
if length == 1:
return c1
res = []
next = (length - 2)//2
nmid = (length - 2) - next*2
res += [top]
res += [ext]*next
res += [mid]*nmid
res += [ext]*next
res += [bot]
return res
def vobj(symb, height):
"""Construct vertical object of a given height
see: xobj
"""
return '\n'.join( xobj(symb, height) )
def hobj(symb, width):
"""Construct horizontal object of a given width
see: xobj
"""
return ''.join( xobj(symb, width) )
# RADICAL
# n -> symbol
root = {
2: U('SQUARE ROOT'), # U('RADICAL SYMBOL BOTTOM')
3: U('CUBE ROOT'),
4: U('FOURTH ROOT'),
}
# RATIONAL
VF = lambda txt: U('VULGAR FRACTION %s' % txt)
# (p,q) -> symbol
frac = {
(1, 2): VF('ONE HALF'),
(1, 3): VF('ONE THIRD'),
(2, 3): VF('TWO THIRDS'),
(1, 4): VF('ONE QUARTER'),
(3, 4): VF('THREE QUARTERS'),
(1, 5): VF('ONE FIFTH'),
(2, 5): VF('TWO FIFTHS'),
(3, 5): VF('THREE FIFTHS'),
(4, 5): VF('FOUR FIFTHS'),
(1, 6): VF('ONE SIXTH'),
(5, 6): VF('FIVE SIXTHS'),
(1, 8): VF('ONE EIGHTH'),
(3, 8): VF('THREE EIGHTHS'),
(5, 8): VF('FIVE EIGHTHS'),
(7, 8): VF('SEVEN EIGHTHS'),
}
# atom symbols
_xsym = {
'==': ('=', '='),
'<': ('<', '<'),
'>': ('>', '>'),
'<=': ('<=', U('LESS-THAN OR EQUAL TO')),
'>=': ('>=', U('GREATER-THAN OR EQUAL TO')),
'!=': ('!=', U('NOT EQUAL TO')),
':=': (':=', ':='),
'+=': ('+=', '+='),
'-=': ('-=', '-='),
'*=': ('*=', '*='),
'/=': ('/=', '/='),
'%=': ('%=', '%='),
'*': ('*', U('DOT OPERATOR')),
'-->': ('-->', U('EM DASH') + U('EM DASH') +
U('BLACK RIGHT-POINTING TRIANGLE') if U('EM DASH')
and U('BLACK RIGHT-POINTING TRIANGLE') else None),
'==>': ('==>', U('BOX DRAWINGS DOUBLE HORIZONTAL') +
U('BOX DRAWINGS DOUBLE HORIZONTAL') +
U('BLACK RIGHT-POINTING TRIANGLE') if
U('BOX DRAWINGS DOUBLE HORIZONTAL') and
U('BOX DRAWINGS DOUBLE HORIZONTAL') and
U('BLACK RIGHT-POINTING TRIANGLE') else None),
'.': ('*', U('RING OPERATOR')),
}
def xsym(sym):
"""get symbology for a 'character'"""
op = _xsym[sym]
if _use_unicode:
return op[1]
else:
return op[0]
# SYMBOLS
atoms_table = {
# class how-to-display
'Exp1': U('SCRIPT SMALL E'),
'Pi': U('GREEK SMALL LETTER PI'),
'Infinity': U('INFINITY'),
'NegativeInfinity': U('INFINITY') and ('-' + U('INFINITY')), # XXX what to do here
#'ImaginaryUnit': U('GREEK SMALL LETTER IOTA'),
#'ImaginaryUnit': U('MATHEMATICAL ITALIC SMALL I'),
'ImaginaryUnit': U('DOUBLE-STRUCK ITALIC SMALL I'),
'EmptySet': U('EMPTY SET'),
'Naturals': U('DOUBLE-STRUCK CAPITAL N'),
'Naturals0': (U('DOUBLE-STRUCK CAPITAL N') and
(U('DOUBLE-STRUCK CAPITAL N') +
U('SUBSCRIPT ZERO'))),
'Integers': U('DOUBLE-STRUCK CAPITAL Z'),
'Rationals': U('DOUBLE-STRUCK CAPITAL Q'),
'Reals': U('DOUBLE-STRUCK CAPITAL R'),
'Complexes': U('DOUBLE-STRUCK CAPITAL C'),
'Universe': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL U'),
'IdentityMatrix': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL I'),
'ZeroMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ZERO'),
'OneMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ONE'),
'Differential': U('DOUBLE-STRUCK ITALIC SMALL D'),
'Union': U('UNION'),
'ElementOf': U('ELEMENT OF'),
'SmallElementOf': U('SMALL ELEMENT OF'),
'SymmetricDifference': U('INCREMENT'),
'Intersection': U('INTERSECTION'),
'Ring': U('RING OPERATOR'),
'Multiplication': U('MULTIPLICATION SIGN'),
'TensorProduct': U('N-ARY CIRCLED TIMES OPERATOR'),
'Dots': U('HORIZONTAL ELLIPSIS'),
'Modifier Letter Low Ring':U('Modifier Letter Low Ring'),
'EmptySequence': 'EmptySequence',
'SuperscriptPlus': U('SUPERSCRIPT PLUS SIGN'),
'SuperscriptMinus': U('SUPERSCRIPT MINUS'),
'Dagger': U('DAGGER'),
'Degree': U('DEGREE SIGN'),
#Logic Symbols
'And': U('LOGICAL AND'),
'Or': U('LOGICAL OR'),
'Not': U('NOT SIGN'),
'Nor': U('NOR'),
'Nand': U('NAND'),
'Xor': U('XOR'),
'Equiv': U('LEFT RIGHT DOUBLE ARROW'),
'NotEquiv': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'),
'Implies': U('LEFT RIGHT DOUBLE ARROW'),
'NotImplies': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'),
'Arrow': U('RIGHTWARDS ARROW'),
'ArrowFromBar': U('RIGHTWARDS ARROW FROM BAR'),
'NotArrow': U('RIGHTWARDS ARROW WITH STROKE'),
'Tautology': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'),
'Contradiction': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL')
}
def pretty_atom(atom_name, default=None, printer=None):
"""return pretty representation of an atom"""
if _use_unicode:
if printer is not None and atom_name == 'ImaginaryUnit' and printer._settings['imaginary_unit'] == 'j':
return U('DOUBLE-STRUCK ITALIC SMALL J')
else:
return atoms_table[atom_name]
else:
if default is not None:
return default
raise KeyError('only unicode') # send it default printer
def pretty_symbol(symb_name, bold_name=False):
"""return pretty representation of a symbol"""
# let's split symb_name into symbol + index
# UC: beta1
# UC: f_beta
if not _use_unicode:
return symb_name
name, sups, subs = split_super_sub(symb_name)
def translate(s, bold_name) :
if bold_name:
gG = greek_bold_unicode.get(s)
else:
gG = greek_unicode.get(s)
if gG is not None:
return gG
for key in sorted(modifier_dict.keys(), key=lambda k:len(k), reverse=True) :
if s.lower().endswith(key) and len(s)>len(key):
return modifier_dict[key](translate(s[:-len(key)], bold_name))
if bold_name:
return ''.join([bold_unicode[c] for c in s])
return s
name = translate(name, bold_name)
# Let's prettify sups/subs. If it fails at one of them, pretty sups/subs are
# not used at all.
def pretty_list(l, mapping):
result = []
for s in l:
pretty = mapping.get(s)
if pretty is None:
try: # match by separate characters
pretty = ''.join([mapping[c] for c in s])
except (TypeError, KeyError):
return None
result.append(pretty)
return result
pretty_sups = pretty_list(sups, sup)
if pretty_sups is not None:
pretty_subs = pretty_list(subs, sub)
else:
pretty_subs = None
# glue the results into one string
if pretty_subs is None: # nice formatting of sups/subs did not work
if subs:
name += '_'+'_'.join([translate(s, bold_name) for s in subs])
if sups:
name += '__'+'__'.join([translate(s, bold_name) for s in sups])
return name
else:
sups_result = ' '.join(pretty_sups)
subs_result = ' '.join(pretty_subs)
return ''.join([name, sups_result, subs_result])
def annotated(letter):
"""
Return a stylised drawing of the letter ``letter``, together with
information on how to put annotations (super- and subscripts to the
left and to the right) on it.
See pretty.py functions _print_meijerg, _print_hyper on how to use this
information.
"""
ucode_pics = {
'F': (2, 0, 2, 0, '\N{BOX DRAWINGS LIGHT DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n'
'\N{BOX DRAWINGS LIGHT VERTICAL AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n'
'\N{BOX DRAWINGS LIGHT UP}'),
'G': (3, 0, 3, 1, '\N{BOX DRAWINGS LIGHT ARC DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC DOWN AND LEFT}\n'
'\N{BOX DRAWINGS LIGHT VERTICAL}\N{BOX DRAWINGS LIGHT RIGHT}\N{BOX DRAWINGS LIGHT DOWN AND LEFT}\n'
'\N{BOX DRAWINGS LIGHT ARC UP AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC UP AND LEFT}')
}
ascii_pics = {
'F': (3, 0, 3, 0, ' _\n|_\n|\n'),
'G': (3, 0, 3, 1, ' __\n/__\n\\_|')
}
if _use_unicode:
return ucode_pics[letter]
else:
return ascii_pics[letter]
_remove_combining = dict.fromkeys(list(range(ord('\N{COMBINING GRAVE ACCENT}'), ord('\N{COMBINING LATIN SMALL LETTER X}')))
+ list(range(ord('\N{COMBINING LEFT HARPOON ABOVE}'), ord('\N{COMBINING ASTERISK ABOVE}'))))
def is_combining(sym):
"""Check whether symbol is a unicode modifier. """
return ord(sym) in _remove_combining
def center_accent(string, accent):
"""
Returns a string with accent inserted on the middle character. Useful to
put combining accents on symbol names, including multi-character names.
Parameters
==========
string : string
The string to place the accent in.
accent : string
The combining accent to insert
References
==========
.. [1] https://en.wikipedia.org/wiki/Combining_character
.. [2] https://en.wikipedia.org/wiki/Combining_Diacritical_Marks
"""
# Accent is placed on the previous character, although it may not always look
# like that depending on console
midpoint = len(string) // 2 + 1
firstpart = string[:midpoint]
secondpart = string[midpoint:]
return firstpart + accent + secondpart
def line_width(line):
"""Unicode combining symbols (modifiers) are not ever displayed as
separate symbols and thus should not be counted
"""
return len(line.translate(_remove_combining))
def is_subscriptable_in_unicode(subscript):
"""
Checks whether a string is subscriptable in unicode or not.
Parameters
==========
subscript: the string which needs to be checked
Examples
========
>>> from sympy.printing.pretty.pretty_symbology import is_subscriptable_in_unicode
>>> is_subscriptable_in_unicode('abc')
False
>>> is_subscriptable_in_unicode('123')
True
"""
return all(character in sub for character in subscript)
def center_pad(wstring, wtarget, fillchar=' '):
"""
Return the padding strings necessary to center a string of
wstring characters wide in a wtarget wide space.
The line_width wstring should always be less or equal to wtarget
or else a ValueError will be raised.
"""
if wstring > wtarget:
raise ValueError('not enough space for string')
wdelta = wtarget - wstring
wleft = wdelta // 2 # favor left '1 '
wright = wdelta - wleft
left = fillchar * wleft
right = fillchar * wright
return left, right
def center(string, width, fillchar=' '):
"""Return a centered string of length determined by `line_width`
that uses `fillchar` for padding.
"""
left, right = center_pad(line_width(string), width, fillchar)
return ''.join([left, string, right])

View File

@ -0,0 +1,537 @@
"""Prettyprinter by Jurjen Bos.
(I hate spammers: mail me at pietjepuk314 at the reverse of ku.oc.oohay).
All objects have a method that create a "stringPict",
that can be used in the str method for pretty printing.
Updates by Jason Gedge (email <my last name> at cs mun ca)
- terminal_string() method
- minor fixes and changes (mostly to prettyForm)
TODO:
- Allow left/center/right alignment options for above/below and
top/center/bottom alignment options for left/right
"""
import shutil
from .pretty_symbology import hobj, vobj, xsym, xobj, pretty_use_unicode, line_width, center
from sympy.utilities.exceptions import sympy_deprecation_warning
_GLOBAL_WRAP_LINE = None
class stringPict:
"""An ASCII picture.
The pictures are represented as a list of equal length strings.
"""
#special value for stringPict.below
LINE = 'line'
def __init__(self, s, baseline=0):
"""Initialize from string.
Multiline strings are centered.
"""
self.s = s
#picture is a string that just can be printed
self.picture = stringPict.equalLengths(s.splitlines())
#baseline is the line number of the "base line"
self.baseline = baseline
self.binding = None
@staticmethod
def equalLengths(lines):
# empty lines
if not lines:
return ['']
width = max(line_width(line) for line in lines)
return [center(line, width) for line in lines]
def height(self):
"""The height of the picture in characters."""
return len(self.picture)
def width(self):
"""The width of the picture in characters."""
return line_width(self.picture[0])
@staticmethod
def next(*args):
"""Put a string of stringPicts next to each other.
Returns string, baseline arguments for stringPict.
"""
#convert everything to stringPicts
objects = []
for arg in args:
if isinstance(arg, str):
arg = stringPict(arg)
objects.append(arg)
#make a list of pictures, with equal height and baseline
newBaseline = max(obj.baseline for obj in objects)
newHeightBelowBaseline = max(
obj.height() - obj.baseline
for obj in objects)
newHeight = newBaseline + newHeightBelowBaseline
pictures = []
for obj in objects:
oneEmptyLine = [' '*obj.width()]
basePadding = newBaseline - obj.baseline
totalPadding = newHeight - obj.height()
pictures.append(
oneEmptyLine * basePadding +
obj.picture +
oneEmptyLine * (totalPadding - basePadding))
result = [''.join(lines) for lines in zip(*pictures)]
return '\n'.join(result), newBaseline
def right(self, *args):
r"""Put pictures next to this one.
Returns string, baseline arguments for stringPict.
(Multiline) strings are allowed, and are given a baseline of 0.
Examples
========
>>> from sympy.printing.pretty.stringpict import stringPict
>>> print(stringPict("10").right(" + ",stringPict("1\r-\r2",1))[0])
1
10 + -
2
"""
return stringPict.next(self, *args)
def left(self, *args):
"""Put pictures (left to right) at left.
Returns string, baseline arguments for stringPict.
"""
return stringPict.next(*(args + (self,)))
@staticmethod
def stack(*args):
"""Put pictures on top of each other,
from top to bottom.
Returns string, baseline arguments for stringPict.
The baseline is the baseline of the second picture.
Everything is centered.
Baseline is the baseline of the second picture.
Strings are allowed.
The special value stringPict.LINE is a row of '-' extended to the width.
"""
#convert everything to stringPicts; keep LINE
objects = []
for arg in args:
if arg is not stringPict.LINE and isinstance(arg, str):
arg = stringPict(arg)
objects.append(arg)
#compute new width
newWidth = max(
obj.width()
for obj in objects
if obj is not stringPict.LINE)
lineObj = stringPict(hobj('-', newWidth))
#replace LINE with proper lines
for i, obj in enumerate(objects):
if obj is stringPict.LINE:
objects[i] = lineObj
#stack the pictures, and center the result
newPicture = [center(line, newWidth) for obj in objects for line in obj.picture]
newBaseline = objects[0].height() + objects[1].baseline
return '\n'.join(newPicture), newBaseline
def below(self, *args):
"""Put pictures under this picture.
Returns string, baseline arguments for stringPict.
Baseline is baseline of top picture
Examples
========
>>> from sympy.printing.pretty.stringpict import stringPict
>>> print(stringPict("x+3").below(
... stringPict.LINE, '3')[0]) #doctest: +NORMALIZE_WHITESPACE
x+3
---
3
"""
s, baseline = stringPict.stack(self, *args)
return s, self.baseline
def above(self, *args):
"""Put pictures above this picture.
Returns string, baseline arguments for stringPict.
Baseline is baseline of bottom picture.
"""
string, baseline = stringPict.stack(*(args + (self,)))
baseline = len(string.splitlines()) - self.height() + self.baseline
return string, baseline
def parens(self, left='(', right=')', ifascii_nougly=False):
"""Put parentheses around self.
Returns string, baseline arguments for stringPict.
left or right can be None or empty string which means 'no paren from
that side'
"""
h = self.height()
b = self.baseline
# XXX this is a hack -- ascii parens are ugly!
if ifascii_nougly and not pretty_use_unicode():
h = 1
b = 0
res = self
if left:
lparen = stringPict(vobj(left, h), baseline=b)
res = stringPict(*lparen.right(self))
if right:
rparen = stringPict(vobj(right, h), baseline=b)
res = stringPict(*res.right(rparen))
return ('\n'.join(res.picture), res.baseline)
def leftslash(self):
"""Precede object by a slash of the proper size.
"""
# XXX not used anywhere ?
height = max(
self.baseline,
self.height() - 1 - self.baseline)*2 + 1
slash = '\n'.join(
' '*(height - i - 1) + xobj('/', 1) + ' '*i
for i in range(height)
)
return self.left(stringPict(slash, height//2))
def root(self, n=None):
"""Produce a nice root symbol.
Produces ugly results for big n inserts.
"""
# XXX not used anywhere
# XXX duplicate of root drawing in pretty.py
#put line over expression
result = self.above('_'*self.width())
#construct right half of root symbol
height = self.height()
slash = '\n'.join(
' ' * (height - i - 1) + '/' + ' ' * i
for i in range(height)
)
slash = stringPict(slash, height - 1)
#left half of root symbol
if height > 2:
downline = stringPict('\\ \n \\', 1)
else:
downline = stringPict('\\')
#put n on top, as low as possible
if n is not None and n.width() > downline.width():
downline = downline.left(' '*(n.width() - downline.width()))
downline = downline.above(n)
#build root symbol
root = downline.right(slash)
#glue it on at the proper height
#normally, the root symbel is as high as self
#which is one less than result
#this moves the root symbol one down
#if the root became higher, the baseline has to grow too
root.baseline = result.baseline - result.height() + root.height()
return result.left(root)
def render(self, * args, **kwargs):
"""Return the string form of self.
Unless the argument line_break is set to False, it will
break the expression in a form that can be printed
on the terminal without being broken up.
"""
if _GLOBAL_WRAP_LINE is not None:
kwargs["wrap_line"] = _GLOBAL_WRAP_LINE
if kwargs["wrap_line"] is False:
return "\n".join(self.picture)
if kwargs["num_columns"] is not None:
# Read the argument num_columns if it is not None
ncols = kwargs["num_columns"]
else:
# Attempt to get a terminal width
ncols = self.terminal_width()
if ncols <= 0:
ncols = 80
# If smaller than the terminal width, no need to correct
if self.width() <= ncols:
return type(self.picture[0])(self)
"""
Break long-lines in a visually pleasing format.
without overflow indicators | with overflow indicators
| 2 2 3 | | 2 2 3 ↪|
|6*x *y + 4*x*y + | |6*x *y + 4*x*y + ↪|
| | | |
| 3 4 4 | |↪ 3 4 4 |
|4*y*x + x + y | |↪ 4*y*x + x + y |
|a*c*e + a*c*f + a*d | |a*c*e + a*c*f + a*d ↪|
|*e + a*d*f + b*c*e | | |
|+ b*c*f + b*d*e + b | |↪ *e + a*d*f + b*c* ↪|
|*d*f | | |
| | |↪ e + b*c*f + b*d*e ↪|
| | | |
| | |↪ + b*d*f |
"""
overflow_first = ""
if kwargs["use_unicode"] or pretty_use_unicode():
overflow_start = "\N{RIGHTWARDS ARROW WITH HOOK} "
overflow_end = " \N{RIGHTWARDS ARROW WITH HOOK}"
else:
overflow_start = "> "
overflow_end = " >"
def chunks(line):
"""Yields consecutive chunks of line_width ncols"""
prefix = overflow_first
width, start = line_width(prefix + overflow_end), 0
for i, x in enumerate(line):
wx = line_width(x)
# Only flush the screen when the current character overflows.
# This way, combining marks can be appended even when width == ncols.
if width + wx > ncols:
yield prefix + line[start:i] + overflow_end
prefix = overflow_start
width, start = line_width(prefix + overflow_end), i
width += wx
yield prefix + line[start:]
# Concurrently assemble chunks of all lines into individual screens
pictures = zip(*map(chunks, self.picture))
# Join lines of each screen into sub-pictures
pictures = ["\n".join(picture) for picture in pictures]
# Add spacers between sub-pictures
return "\n\n".join(pictures)
def terminal_width(self):
"""Return the terminal width if possible, otherwise return 0.
"""
size = shutil.get_terminal_size(fallback=(0, 0))
return size.columns
def __eq__(self, o):
if isinstance(o, str):
return '\n'.join(self.picture) == o
elif isinstance(o, stringPict):
return o.picture == self.picture
return False
def __hash__(self):
return super().__hash__()
def __str__(self):
return '\n'.join(self.picture)
def __repr__(self):
return "stringPict(%r,%d)" % ('\n'.join(self.picture), self.baseline)
def __getitem__(self, index):
return self.picture[index]
def __len__(self):
return len(self.s)
class prettyForm(stringPict):
"""
Extension of the stringPict class that knows about basic math applications,
optimizing double minus signs.
"Binding" is interpreted as follows::
ATOM this is an atom: never needs to be parenthesized
FUNC this is a function application: parenthesize if added (?)
DIV this is a division: make wider division if divided
POW this is a power: only parenthesize if exponent
MUL this is a multiplication: parenthesize if powered
ADD this is an addition: parenthesize if multiplied or powered
NEG this is a negative number: optimize if added, parenthesize if
multiplied or powered
OPEN this is an open object: parenthesize if added, multiplied, or
powered (example: Piecewise)
"""
ATOM, FUNC, DIV, POW, MUL, ADD, NEG, OPEN = range(8)
def __init__(self, s, baseline=0, binding=0, unicode=None):
"""Initialize from stringPict and binding power."""
stringPict.__init__(self, s, baseline)
self.binding = binding
if unicode is not None:
sympy_deprecation_warning(
"""
The unicode argument to prettyForm is deprecated. Only the s
argument (the first positional argument) should be passed.
""",
deprecated_since_version="1.7",
active_deprecations_target="deprecated-pretty-printing-functions")
self._unicode = unicode or s
@property
def unicode(self):
sympy_deprecation_warning(
"""
The prettyForm.unicode attribute is deprecated. Use the
prettyForm.s attribute instead.
""",
deprecated_since_version="1.7",
active_deprecations_target="deprecated-pretty-printing-functions")
return self._unicode
# Note: code to handle subtraction is in _print_Add
def __add__(self, *others):
"""Make a pretty addition.
Addition of negative numbers is simplified.
"""
arg = self
if arg.binding > prettyForm.NEG:
arg = stringPict(*arg.parens())
result = [arg]
for arg in others:
#add parentheses for weak binders
if arg.binding > prettyForm.NEG:
arg = stringPict(*arg.parens())
#use existing minus sign if available
if arg.binding != prettyForm.NEG:
result.append(' + ')
result.append(arg)
return prettyForm(binding=prettyForm.ADD, *stringPict.next(*result))
def __truediv__(self, den, slashed=False):
"""Make a pretty division; stacked or slashed.
"""
if slashed:
raise NotImplementedError("Can't do slashed fraction yet")
num = self
if num.binding == prettyForm.DIV:
num = stringPict(*num.parens())
if den.binding == prettyForm.DIV:
den = stringPict(*den.parens())
if num.binding==prettyForm.NEG:
num = num.right(" ")[0]
return prettyForm(binding=prettyForm.DIV, *stringPict.stack(
num,
stringPict.LINE,
den))
def __mul__(self, *others):
"""Make a pretty multiplication.
Parentheses are needed around +, - and neg.
"""
quantity = {
'degree': "\N{DEGREE SIGN}"
}
if len(others) == 0:
return self # We aren't actually multiplying... So nothing to do here.
# add parens on args that need them
arg = self
if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG:
arg = stringPict(*arg.parens())
result = [arg]
for arg in others:
if arg.picture[0] not in quantity.values():
result.append(xsym('*'))
#add parentheses for weak binders
if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG:
arg = stringPict(*arg.parens())
result.append(arg)
len_res = len(result)
for i in range(len_res):
if i < len_res - 1 and result[i] == '-1' and result[i + 1] == xsym('*'):
# substitute -1 by -, like in -1*x -> -x
result.pop(i)
result.pop(i)
result.insert(i, '-')
if result[0][0] == '-':
# if there is a - sign in front of all
# This test was failing to catch a prettyForm.__mul__(prettyForm("-1", 0, 6)) being negative
bin = prettyForm.NEG
if result[0] == '-':
right = result[1]
if right.picture[right.baseline][0] == '-':
result[0] = '- '
else:
bin = prettyForm.MUL
return prettyForm(binding=bin, *stringPict.next(*result))
def __repr__(self):
return "prettyForm(%r,%d,%d)" % (
'\n'.join(self.picture),
self.baseline,
self.binding)
def __pow__(self, b):
"""Make a pretty power.
"""
a = self
use_inline_func_form = False
if b.binding == prettyForm.POW:
b = stringPict(*b.parens())
if a.binding > prettyForm.FUNC:
a = stringPict(*a.parens())
elif a.binding == prettyForm.FUNC:
# heuristic for when to use inline power
if b.height() > 1:
a = stringPict(*a.parens())
else:
use_inline_func_form = True
if use_inline_func_form:
# 2
# sin + + (x)
b.baseline = a.prettyFunc.baseline + b.height()
func = stringPict(*a.prettyFunc.right(b))
return prettyForm(*func.right(a.prettyArgs))
else:
# 2 <-- top
# (x+y) <-- bot
top = stringPict(*b.left(' '*a.width()))
bot = stringPict(*a.right(' '*b.width()))
return prettyForm(binding=prettyForm.POW, *bot.above(top))
simpleFunctions = ["sin", "cos", "tan"]
@staticmethod
def apply(function, *args):
"""Functions of one or more variables.
"""
if function in prettyForm.simpleFunctions:
#simple function: use only space if possible
assert len(
args) == 1, "Simple function %s must have 1 argument" % function
arg = args[0].__pretty__()
if arg.binding <= prettyForm.DIV:
#optimization: no parentheses necessary
return prettyForm(binding=prettyForm.FUNC, *arg.left(function + ' '))
argumentList = []
for arg in args:
argumentList.append(',')
argumentList.append(arg.__pretty__())
argumentList = stringPict(*stringPict.next(*argumentList[1:]))
argumentList = stringPict(*argumentList.parens())
return prettyForm(binding=prettyForm.ATOM, *argumentList.left(function))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,390 @@
import os
from os.path import join
import shutil
import tempfile
try:
from subprocess import STDOUT, CalledProcessError, check_output
except ImportError:
pass
from sympy.utilities.decorator import doctest_depends_on
from sympy.utilities.misc import debug
from .latex import latex
__doctest_requires__ = {('preview',): ['pyglet']}
def _check_output_no_window(*args, **kwargs):
# Avoid showing a cmd.exe window when running this
# on Windows
if os.name == 'nt':
creation_flag = 0x08000000 # CREATE_NO_WINDOW
else:
creation_flag = 0 # Default value
return check_output(*args, creationflags=creation_flag, **kwargs)
def system_default_viewer(fname, fmt):
""" Open fname with the default system viewer.
In practice, it is impossible for python to know when the system viewer is
done. For this reason, we ensure the passed file will not be deleted under
it, and this function does not attempt to block.
"""
# copy to a new temporary file that will not be deleted
with tempfile.NamedTemporaryFile(prefix='sympy-preview-',
suffix=os.path.splitext(fname)[1],
delete=False) as temp_f:
with open(fname, 'rb') as f:
shutil.copyfileobj(f, temp_f)
import platform
if platform.system() == 'Darwin':
import subprocess
subprocess.call(('open', temp_f.name))
elif platform.system() == 'Windows':
os.startfile(temp_f.name)
else:
import subprocess
subprocess.call(('xdg-open', temp_f.name))
def pyglet_viewer(fname, fmt):
try:
from pyglet import window, image, gl
from pyglet.window import key
from pyglet.image.codecs import ImageDecodeException
except ImportError:
raise ImportError("pyglet is required for preview.\n visit https://pyglet.org/")
try:
img = image.load(fname)
except ImageDecodeException:
raise ValueError("pyglet preview does not work for '{}' files.".format(fmt))
offset = 25
config = gl.Config(double_buffer=False)
win = window.Window(
width=img.width + 2*offset,
height=img.height + 2*offset,
caption="SymPy",
resizable=False,
config=config
)
win.set_vsync(False)
try:
def on_close():
win.has_exit = True
win.on_close = on_close
def on_key_press(symbol, modifiers):
if symbol in [key.Q, key.ESCAPE]:
on_close()
win.on_key_press = on_key_press
def on_expose():
gl.glClearColor(1.0, 1.0, 1.0, 1.0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
img.blit(
(win.width - img.width) / 2,
(win.height - img.height) / 2
)
win.on_expose = on_expose
while not win.has_exit:
win.dispatch_events()
win.flip()
except KeyboardInterrupt:
pass
win.close()
def _get_latex_main(expr, *, preamble=None, packages=(), extra_preamble=None,
euler=True, fontsize=None, **latex_settings):
"""
Generate string of a LaTeX document rendering ``expr``.
"""
if preamble is None:
actual_packages = packages + ("amsmath", "amsfonts")
if euler:
actual_packages += ("euler",)
package_includes = "\n" + "\n".join(["\\usepackage{%s}" % p
for p in actual_packages])
if extra_preamble:
package_includes += extra_preamble
if not fontsize:
fontsize = "12pt"
elif isinstance(fontsize, int):
fontsize = "{}pt".format(fontsize)
preamble = r"""\documentclass[varwidth,%s]{standalone}
%s
\begin{document}
""" % (fontsize, package_includes)
else:
if packages or extra_preamble:
raise ValueError("The \"packages\" or \"extra_preamble\" keywords"
"must not be set if a "
"custom LaTeX preamble was specified")
if isinstance(expr, str):
latex_string = expr
else:
latex_string = ('$\\displaystyle ' +
latex(expr, mode='plain', **latex_settings) +
'$')
return preamble + '\n' + latex_string + '\n\n' + r"\end{document}"
@doctest_depends_on(exe=('latex', 'dvipng'), modules=('pyglet',),
disable_viewers=('evince', 'gimp', 'superior-dvi-viewer'))
def preview(expr, output='png', viewer=None, euler=True, packages=(),
filename=None, outputbuffer=None, preamble=None, dvioptions=None,
outputTexFile=None, extra_preamble=None, fontsize=None,
**latex_settings):
r"""
View expression or LaTeX markup in PNG, DVI, PostScript or PDF form.
If the expr argument is an expression, it will be exported to LaTeX and
then compiled using the available TeX distribution. The first argument,
'expr', may also be a LaTeX string. The function will then run the
appropriate viewer for the given output format or use the user defined
one. By default png output is generated.
By default pretty Euler fonts are used for typesetting (they were used to
typeset the well known "Concrete Mathematics" book). For that to work, you
need the 'eulervm.sty' LaTeX style (in Debian/Ubuntu, install the
texlive-fonts-extra package). If you prefer default AMS fonts or your
system lacks 'eulervm' LaTeX package then unset the 'euler' keyword
argument.
To use viewer auto-detection, lets say for 'png' output, issue
>>> from sympy import symbols, preview, Symbol
>>> x, y = symbols("x,y")
>>> preview(x + y, output='png')
This will choose 'pyglet' by default. To select a different one, do
>>> preview(x + y, output='png', viewer='gimp')
The 'png' format is considered special. For all other formats the rules
are slightly different. As an example we will take 'dvi' output format. If
you would run
>>> preview(x + y, output='dvi')
then 'view' will look for available 'dvi' viewers on your system
(predefined in the function, so it will try evince, first, then kdvi and
xdvi). If nothing is found, it will fall back to using a system file
association (via ``open`` and ``xdg-open``). To always use your system file
association without searching for the above readers, use
>>> from sympy.printing.preview import system_default_viewer
>>> preview(x + y, output='dvi', viewer=system_default_viewer)
If this still does not find the viewer you want, it can be set explicitly.
>>> preview(x + y, output='dvi', viewer='superior-dvi-viewer')
This will skip auto-detection and will run user specified
'superior-dvi-viewer'. If ``view`` fails to find it on your system it will
gracefully raise an exception.
You may also enter ``'file'`` for the viewer argument. Doing so will cause
this function to return a file object in read-only mode, if ``filename``
is unset. However, if it was set, then 'preview' writes the generated
file to this filename instead.
There is also support for writing to a ``io.BytesIO`` like object, which
needs to be passed to the ``outputbuffer`` argument.
>>> from io import BytesIO
>>> obj = BytesIO()
>>> preview(x + y, output='png', viewer='BytesIO',
... outputbuffer=obj)
The LaTeX preamble can be customized by setting the 'preamble' keyword
argument. This can be used, e.g., to set a different font size, use a
custom documentclass or import certain set of LaTeX packages.
>>> preamble = "\\documentclass[10pt]{article}\n" \
... "\\usepackage{amsmath,amsfonts}\\begin{document}"
>>> preview(x + y, output='png', preamble=preamble)
It is also possible to use the standard preamble and provide additional
information to the preamble using the ``extra_preamble`` keyword argument.
>>> from sympy import sin
>>> extra_preamble = "\\renewcommand{\\sin}{\\cos}"
>>> preview(sin(x), output='png', extra_preamble=extra_preamble)
If the value of 'output' is different from 'dvi' then command line
options can be set ('dvioptions' argument) for the execution of the
'dvi'+output conversion tool. These options have to be in the form of a
list of strings (see ``subprocess.Popen``).
Additional keyword args will be passed to the :func:`~sympy.printing.latex.latex` call,
e.g., the ``symbol_names`` flag.
>>> phidd = Symbol('phidd')
>>> preview(phidd, symbol_names={phidd: r'\ddot{\varphi}'})
For post-processing the generated TeX File can be written to a file by
passing the desired filename to the 'outputTexFile' keyword
argument. To write the TeX code to a file named
``"sample.tex"`` and run the default png viewer to display the resulting
bitmap, do
>>> preview(x + y, outputTexFile="sample.tex")
"""
# pyglet is the default for png
if viewer is None and output == "png":
try:
import pyglet # noqa: F401
except ImportError:
pass
else:
viewer = pyglet_viewer
# look up a known application
if viewer is None:
# sorted in order from most pretty to most ugly
# very discussable, but indeed 'gv' looks awful :)
candidates = {
"dvi": [ "evince", "okular", "kdvi", "xdvi" ],
"ps": [ "evince", "okular", "gsview", "gv" ],
"pdf": [ "evince", "okular", "kpdf", "acroread", "xpdf", "gv" ],
}
for candidate in candidates.get(output, []):
path = shutil.which(candidate)
if path is not None:
viewer = path
break
# otherwise, use the system default for file association
if viewer is None:
viewer = system_default_viewer
if viewer == "file":
if filename is None:
raise ValueError("filename has to be specified if viewer=\"file\"")
elif viewer == "BytesIO":
if outputbuffer is None:
raise ValueError("outputbuffer has to be a BytesIO "
"compatible object if viewer=\"BytesIO\"")
elif not callable(viewer) and not shutil.which(viewer):
raise OSError("Unrecognized viewer: %s" % viewer)
latex_main = _get_latex_main(expr, preamble=preamble, packages=packages,
euler=euler, extra_preamble=extra_preamble,
fontsize=fontsize, **latex_settings)
debug("Latex code:")
debug(latex_main)
with tempfile.TemporaryDirectory() as workdir:
with open(join(workdir, 'texput.tex'), 'w', encoding='utf-8') as fh:
fh.write(latex_main)
if outputTexFile is not None:
shutil.copyfile(join(workdir, 'texput.tex'), outputTexFile)
if not shutil.which('latex'):
raise RuntimeError("latex program is not installed")
try:
_check_output_no_window(
['latex', '-halt-on-error', '-interaction=nonstopmode',
'texput.tex'],
cwd=workdir,
stderr=STDOUT)
except CalledProcessError as e:
raise RuntimeError(
"'latex' exited abnormally with the following output:\n%s" %
e.output)
src = "texput.%s" % (output)
if output != "dvi":
# in order of preference
commandnames = {
"ps": ["dvips"],
"pdf": ["dvipdfmx", "dvipdfm", "dvipdf"],
"png": ["dvipng"],
"svg": ["dvisvgm"],
}
try:
cmd_variants = commandnames[output]
except KeyError:
raise ValueError("Invalid output format: %s" % output) from None
# find an appropriate command
for cmd_variant in cmd_variants:
cmd_path = shutil.which(cmd_variant)
if cmd_path:
cmd = [cmd_path]
break
else:
if len(cmd_variants) > 1:
raise RuntimeError("None of %s are installed" % ", ".join(cmd_variants))
else:
raise RuntimeError("%s is not installed" % cmd_variants[0])
defaultoptions = {
"dvipng": ["-T", "tight", "-z", "9", "--truecolor"],
"dvisvgm": ["--no-fonts"],
}
commandend = {
"dvips": ["-o", src, "texput.dvi"],
"dvipdf": ["texput.dvi", src],
"dvipdfm": ["-o", src, "texput.dvi"],
"dvipdfmx": ["-o", src, "texput.dvi"],
"dvipng": ["-o", src, "texput.dvi"],
"dvisvgm": ["-o", src, "texput.dvi"],
}
if dvioptions is not None:
cmd.extend(dvioptions)
else:
cmd.extend(defaultoptions.get(cmd_variant, []))
cmd.extend(commandend[cmd_variant])
try:
_check_output_no_window(cmd, cwd=workdir, stderr=STDOUT)
except CalledProcessError as e:
raise RuntimeError(
"'%s' exited abnormally with the following output:\n%s" %
(' '.join(cmd), e.output))
if viewer == "file":
shutil.move(join(workdir, src), filename)
elif viewer == "BytesIO":
with open(join(workdir, src), 'rb') as fh:
outputbuffer.write(fh.read())
elif callable(viewer):
viewer(join(workdir, src), fmt=output)
else:
try:
_check_output_no_window(
[viewer, src], cwd=workdir, stderr=STDOUT)
except CalledProcessError as e:
raise RuntimeError(
"'%s %s' exited abnormally with the following output:\n%s" %
(viewer, src, e.output))

View File

@ -0,0 +1,396 @@
"""Printing subsystem driver
SymPy's printing system works the following way: Any expression can be
passed to a designated Printer who then is responsible to return an
adequate representation of that expression.
**The basic concept is the following:**
1. Let the object print itself if it knows how.
2. Take the best fitting method defined in the printer.
3. As fall-back use the emptyPrinter method for the printer.
Which Method is Responsible for Printing?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The whole printing process is started by calling ``.doprint(expr)`` on the printer
which you want to use. This method looks for an appropriate method which can
print the given expression in the given style that the printer defines.
While looking for the method, it follows these steps:
1. **Let the object print itself if it knows how.**
The printer looks for a specific method in every object. The name of that method
depends on the specific printer and is defined under ``Printer.printmethod``.
For example, StrPrinter calls ``_sympystr`` and LatexPrinter calls ``_latex``.
Look at the documentation of the printer that you want to use.
The name of the method is specified there.
This was the original way of doing printing in sympy. Every class had
its own latex, mathml, str and repr methods, but it turned out that it
is hard to produce a high quality printer, if all the methods are spread
out that far. Therefore all printing code was combined into the different
printers, which works great for built-in SymPy objects, but not that
good for user defined classes where it is inconvenient to patch the
printers.
2. **Take the best fitting method defined in the printer.**
The printer loops through expr classes (class + its bases), and tries
to dispatch the work to ``_print_<EXPR_CLASS>``
e.g., suppose we have the following class hierarchy::
Basic
|
Atom
|
Number
|
Rational
then, for ``expr=Rational(...)``, the Printer will try
to call printer methods in the order as shown in the figure below::
p._print(expr)
|
|-- p._print_Rational(expr)
|
|-- p._print_Number(expr)
|
|-- p._print_Atom(expr)
|
`-- p._print_Basic(expr)
if ``._print_Rational`` method exists in the printer, then it is called,
and the result is returned back. Otherwise, the printer tries to call
``._print_Number`` and so on.
3. **As a fall-back use the emptyPrinter method for the printer.**
As fall-back ``self.emptyPrinter`` will be called with the expression. If
not defined in the Printer subclass this will be the same as ``str(expr)``.
.. _printer_example:
Example of Custom Printer
^^^^^^^^^^^^^^^^^^^^^^^^^
In the example below, we have a printer which prints the derivative of a function
in a shorter form.
.. code-block:: python
from sympy.core.symbol import Symbol
from sympy.printing.latex import LatexPrinter, print_latex
from sympy.core.function import UndefinedFunction, Function
class MyLatexPrinter(LatexPrinter):
\"\"\"Print derivative of a function of symbols in a shorter form.
\"\"\"
def _print_Derivative(self, expr):
function, *vars = expr.args
if not isinstance(type(function), UndefinedFunction) or \\
not all(isinstance(i, Symbol) for i in vars):
return super()._print_Derivative(expr)
# If you want the printer to work correctly for nested
# expressions then use self._print() instead of str() or latex().
# See the example of nested modulo below in the custom printing
# method section.
return "{}_{{{}}}".format(
self._print(Symbol(function.func.__name__)),
''.join(self._print(i) for i in vars))
def print_my_latex(expr):
\"\"\" Most of the printers define their own wrappers for print().
These wrappers usually take printer settings. Our printer does not have
any settings.
\"\"\"
print(MyLatexPrinter().doprint(expr))
y = Symbol("y")
x = Symbol("x")
f = Function("f")
expr = f(x, y).diff(x, y)
# Print the expression using the normal latex printer and our custom
# printer.
print_latex(expr)
print_my_latex(expr)
The output of the code above is::
\\frac{\\partial^{2}}{\\partial x\\partial y} f{\\left(x,y \\right)}
f_{xy}
.. _printer_method_example:
Example of Custom Printing Method
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the example below, the latex printing of the modulo operator is modified.
This is done by overriding the method ``_latex`` of ``Mod``.
>>> from sympy import Symbol, Mod, Integer, print_latex
>>> # Always use printer._print()
>>> class ModOp(Mod):
... def _latex(self, printer):
... a, b = [printer._print(i) for i in self.args]
... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b)
Comparing the output of our custom operator to the builtin one:
>>> x = Symbol('x')
>>> m = Symbol('m')
>>> print_latex(Mod(x, m))
x \\bmod m
>>> print_latex(ModOp(x, m))
\\operatorname{Mod}{\\left(x, m\\right)}
Common mistakes
~~~~~~~~~~~~~~~
It's important to always use ``self._print(obj)`` to print subcomponents of
an expression when customizing a printer. Mistakes include:
1. Using ``self.doprint(obj)`` instead:
>>> # This example does not work properly, as only the outermost call may use
>>> # doprint.
>>> class ModOpModeWrong(Mod):
... def _latex(self, printer):
... a, b = [printer.doprint(i) for i in self.args]
... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b)
This fails when the ``mode`` argument is passed to the printer:
>>> print_latex(ModOp(x, m), mode='inline') # ok
$\\operatorname{Mod}{\\left(x, m\\right)}$
>>> print_latex(ModOpModeWrong(x, m), mode='inline') # bad
$\\operatorname{Mod}{\\left($x$, $m$\\right)}$
2. Using ``str(obj)`` instead:
>>> class ModOpNestedWrong(Mod):
... def _latex(self, printer):
... a, b = [str(i) for i in self.args]
... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b)
This fails on nested objects:
>>> # Nested modulo.
>>> print_latex(ModOp(ModOp(x, m), Integer(7))) # ok
\\operatorname{Mod}{\\left(\\operatorname{Mod}{\\left(x, m\\right)}, 7\\right)}
>>> print_latex(ModOpNestedWrong(ModOpNestedWrong(x, m), Integer(7))) # bad
\\operatorname{Mod}{\\left(ModOpNestedWrong(x, m), 7\\right)}
3. Using ``LatexPrinter()._print(obj)`` instead.
>>> from sympy.printing.latex import LatexPrinter
>>> class ModOpSettingsWrong(Mod):
... def _latex(self, printer):
... a, b = [LatexPrinter()._print(i) for i in self.args]
... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b)
This causes all the settings to be discarded in the subobjects. As an
example, the ``full_prec`` setting which shows floats to full precision is
ignored:
>>> from sympy import Float
>>> print_latex(ModOp(Float(1) * x, m), full_prec=True) # ok
\\operatorname{Mod}{\\left(1.00000000000000 x, m\\right)}
>>> print_latex(ModOpSettingsWrong(Float(1) * x, m), full_prec=True) # bad
\\operatorname{Mod}{\\left(1.0 x, m\\right)}
"""
from __future__ import annotations
import sys
from typing import Any, Type
import inspect
from contextlib import contextmanager
from functools import cmp_to_key, update_wrapper
from sympy.core.add import Add
from sympy.core.basic import Basic
from sympy.core.function import AppliedUndef, UndefinedFunction, Function
@contextmanager
def printer_context(printer, **kwargs):
original = printer._context.copy()
try:
printer._context.update(kwargs)
yield
finally:
printer._context = original
class Printer:
""" Generic printer
Its job is to provide infrastructure for implementing new printers easily.
If you want to define your custom Printer or your custom printing method
for your custom class then see the example above: printer_example_ .
"""
_global_settings: dict[str, Any] = {}
_default_settings: dict[str, Any] = {}
printmethod = None # type: str
@classmethod
def _get_initial_settings(cls):
settings = cls._default_settings.copy()
for key, val in cls._global_settings.items():
if key in cls._default_settings:
settings[key] = val
return settings
def __init__(self, settings=None):
self._str = str
self._settings = self._get_initial_settings()
self._context = {} # mutable during printing
if settings is not None:
self._settings.update(settings)
if len(self._settings) > len(self._default_settings):
for key in self._settings:
if key not in self._default_settings:
raise TypeError("Unknown setting '%s'." % key)
# _print_level is the number of times self._print() was recursively
# called. See StrPrinter._print_Float() for an example of usage
self._print_level = 0
@classmethod
def set_global_settings(cls, **settings):
"""Set system-wide printing settings. """
for key, val in settings.items():
if val is not None:
cls._global_settings[key] = val
@property
def order(self):
if 'order' in self._settings:
return self._settings['order']
else:
raise AttributeError("No order defined.")
def doprint(self, expr):
"""Returns printer's representation for expr (as a string)"""
return self._str(self._print(expr))
def _print(self, expr, **kwargs) -> str:
"""Internal dispatcher
Tries the following concepts to print an expression:
1. Let the object print itself if it knows how.
2. Take the best fitting method defined in the printer.
3. As fall-back use the emptyPrinter method for the printer.
"""
self._print_level += 1
try:
# If the printer defines a name for a printing method
# (Printer.printmethod) and the object knows for itself how it
# should be printed, use that method.
if self.printmethod and hasattr(expr, self.printmethod):
if not (isinstance(expr, type) and issubclass(expr, Basic)):
return getattr(expr, self.printmethod)(self, **kwargs)
# See if the class of expr is known, or if one of its super
# classes is known, and use that print function
# Exception: ignore the subclasses of Undefined, so that, e.g.,
# Function('gamma') does not get dispatched to _print_gamma
classes = type(expr).__mro__
if AppliedUndef in classes:
classes = classes[classes.index(AppliedUndef):]
if UndefinedFunction in classes:
classes = classes[classes.index(UndefinedFunction):]
# Another exception: if someone subclasses a known function, e.g.,
# gamma, and changes the name, then ignore _print_gamma
if Function in classes:
i = classes.index(Function)
classes = tuple(c for c in classes[:i] if \
c.__name__ == classes[0].__name__ or \
c.__name__.endswith("Base")) + classes[i:]
for cls in classes:
printmethodname = '_print_' + cls.__name__
printmethod = getattr(self, printmethodname, None)
if printmethod is not None:
return printmethod(expr, **kwargs)
# Unknown object, fall back to the emptyPrinter.
return self.emptyPrinter(expr)
finally:
self._print_level -= 1
def emptyPrinter(self, expr):
return str(expr)
def _as_ordered_terms(self, expr, order=None):
"""A compatibility function for ordering terms in Add. """
order = order or self.order
if order == 'old':
return sorted(Add.make_args(expr), key=cmp_to_key(Basic._compare_pretty))
elif order == 'none':
return list(expr.args)
else:
return expr.as_ordered_terms(order=order)
class _PrintFunction:
"""
Function wrapper to replace ``**settings`` in the signature with printer defaults
"""
def __init__(self, f, print_cls: Type[Printer]):
# find all the non-setting arguments
params = list(inspect.signature(f).parameters.values())
assert params.pop(-1).kind == inspect.Parameter.VAR_KEYWORD
self.__other_params = params
self.__print_cls = print_cls
update_wrapper(self, f)
def __reduce__(self):
# Since this is used as a decorator, it replaces the original function.
# The default pickling will try to pickle self.__wrapped__ and fail
# because the wrapped function can't be retrieved by name.
return self.__wrapped__.__qualname__
def __call__(self, *args, **kwargs):
return self.__wrapped__(*args, **kwargs)
@property
def __signature__(self) -> inspect.Signature:
settings = self.__print_cls._get_initial_settings()
return inspect.Signature(
parameters=self.__other_params + [
inspect.Parameter(k, inspect.Parameter.KEYWORD_ONLY, default=v)
for k, v in settings.items()
],
return_annotation=self.__wrapped__.__annotations__.get('return', inspect.Signature.empty) # type:ignore
)
def print_function(print_cls):
""" A decorator to replace kwargs with the printer settings in __signature__ """
def decorator(f):
if sys.version_info < (3, 9):
# We have to create a subclass so that `help` actually shows the docstring in older Python versions.
# IPython and Sphinx do not need this, only a raw Python console.
cls = type(f'{f.__qualname__}_PrintFunction', (_PrintFunction,), {"__doc__": f.__doc__})
else:
cls = _PrintFunction
return cls(f, print_cls)
return decorator

View File

@ -0,0 +1,772 @@
"""
Python code printers
This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code.
"""
from collections import defaultdict
from itertools import chain
from sympy.core import S
from sympy.core.mod import Mod
from .precedence import precedence
from .codeprinter import CodePrinter
_kw = {
'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in',
'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
'with', 'yield', 'None', 'False', 'nonlocal', 'True'
}
_known_functions = {
'Abs': 'abs',
'Min': 'min',
'Max': 'max',
}
_known_functions_math = {
'acos': 'acos',
'acosh': 'acosh',
'asin': 'asin',
'asinh': 'asinh',
'atan': 'atan',
'atan2': 'atan2',
'atanh': 'atanh',
'ceiling': 'ceil',
'cos': 'cos',
'cosh': 'cosh',
'erf': 'erf',
'erfc': 'erfc',
'exp': 'exp',
'expm1': 'expm1',
'factorial': 'factorial',
'floor': 'floor',
'gamma': 'gamma',
'hypot': 'hypot',
'isnan': 'isnan',
'loggamma': 'lgamma',
'log': 'log',
'ln': 'log',
'log10': 'log10',
'log1p': 'log1p',
'log2': 'log2',
'sin': 'sin',
'sinh': 'sinh',
'Sqrt': 'sqrt',
'tan': 'tan',
'tanh': 'tanh'
} # Not used from ``math``: [copysign isclose isfinite isinf ldexp frexp pow modf
# radians trunc fmod fsum gcd degrees fabs]
_known_constants_math = {
'Exp1': 'e',
'Pi': 'pi',
'E': 'e',
'Infinity': 'inf',
'NaN': 'nan',
'ComplexInfinity': 'nan'
}
def _print_known_func(self, expr):
known = self.known_functions[expr.__class__.__name__]
return '{name}({args})'.format(name=self._module_format(known),
args=', '.join((self._print(arg) for arg in expr.args)))
def _print_known_const(self, expr):
known = self.known_constants[expr.__class__.__name__]
return self._module_format(known)
class AbstractPythonCodePrinter(CodePrinter):
printmethod = "_pythoncode"
language = "Python"
reserved_words = _kw
modules = None # initialized to a set in __init__
tab = ' '
_kf = dict(chain(
_known_functions.items(),
[(k, 'math.' + v) for k, v in _known_functions_math.items()]
))
_kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
_operators = {'and': 'and', 'or': 'or', 'not': 'not'}
_default_settings = dict(
CodePrinter._default_settings,
user_functions={},
precision=17,
inline=True,
fully_qualified_modules=True,
contract=False,
standard='python3',
)
def __init__(self, settings=None):
super().__init__(settings)
# Python standard handler
std = self._settings['standard']
if std is None:
import sys
std = 'python{}'.format(sys.version_info.major)
if std != 'python3':
raise ValueError('Only Python 3 is supported.')
self.standard = std
self.module_imports = defaultdict(set)
# Known functions and constants handler
self.known_functions = dict(self._kf, **(settings or {}).get(
'user_functions', {}))
self.known_constants = dict(self._kc, **(settings or {}).get(
'user_constants', {}))
def _declare_number_const(self, name, value):
return "%s = %s" % (name, value)
def _module_format(self, fqn, register=True):
parts = fqn.split('.')
if register and len(parts) > 1:
self.module_imports['.'.join(parts[:-1])].add(parts[-1])
if self._settings['fully_qualified_modules']:
return fqn
else:
return fqn.split('(')[0].split('[')[0].split('.')[-1]
def _format_code(self, lines):
return lines
def _get_statement(self, codestring):
return "{}".format(codestring)
def _get_comment(self, text):
return " # {}".format(text)
def _expand_fold_binary_op(self, op, args):
"""
This method expands a fold on binary operations.
``functools.reduce`` is an example of a folded operation.
For example, the expression
`A + B + C + D`
is folded into
`((A + B) + C) + D`
"""
if len(args) == 1:
return self._print(args[0])
else:
return "%s(%s, %s)" % (
self._module_format(op),
self._expand_fold_binary_op(op, args[:-1]),
self._print(args[-1]),
)
def _expand_reduce_binary_op(self, op, args):
"""
This method expands a reductin on binary operations.
Notice: this is NOT the same as ``functools.reduce``.
For example, the expression
`A + B + C + D`
is reduced into:
`(A + B) + (C + D)`
"""
if len(args) == 1:
return self._print(args[0])
else:
N = len(args)
Nhalf = N // 2
return "%s(%s, %s)" % (
self._module_format(op),
self._expand_reduce_binary_op(args[:Nhalf]),
self._expand_reduce_binary_op(args[Nhalf:]),
)
def _print_NaN(self, expr):
return "float('nan')"
def _print_Infinity(self, expr):
return "float('inf')"
def _print_NegativeInfinity(self, expr):
return "float('-inf')"
def _print_ComplexInfinity(self, expr):
return self._print_NaN(expr)
def _print_Mod(self, expr):
PREC = precedence(expr)
return ('{} % {}'.format(*(self.parenthesize(x, PREC) for x in expr.args)))
def _print_Piecewise(self, expr):
result = []
i = 0
for arg in expr.args:
e = arg.expr
c = arg.cond
if i == 0:
result.append('(')
result.append('(')
result.append(self._print(e))
result.append(')')
result.append(' if ')
result.append(self._print(c))
result.append(' else ')
i += 1
result = result[:-1]
if result[-1] == 'True':
result = result[:-2]
result.append(')')
else:
result.append(' else None)')
return ''.join(result)
def _print_Relational(self, expr):
"Relational printer for Equality and Unequality"
op = {
'==' :'equal',
'!=' :'not_equal',
'<' :'less',
'<=' :'less_equal',
'>' :'greater',
'>=' :'greater_equal',
}
if expr.rel_op in op:
lhs = self._print(expr.lhs)
rhs = self._print(expr.rhs)
return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
return super()._print_Relational(expr)
def _print_ITE(self, expr):
from sympy.functions.elementary.piecewise import Piecewise
return self._print(expr.rewrite(Piecewise))
def _print_Sum(self, expr):
loops = (
'for {i} in range({a}, {b}+1)'.format(
i=self._print(i),
a=self._print(a),
b=self._print(b))
for i, a, b in expr.limits)
return '(builtins.sum({function} {loops}))'.format(
function=self._print(expr.function),
loops=' '.join(loops))
def _print_ImaginaryUnit(self, expr):
return '1j'
def _print_KroneckerDelta(self, expr):
a, b = expr.args
return '(1 if {a} == {b} else 0)'.format(
a = self._print(a),
b = self._print(b)
)
def _print_MatrixBase(self, expr):
name = expr.__class__.__name__
func = self.known_functions.get(name, name)
return "%s(%s)" % (func, self._print(expr.tolist()))
_print_SparseRepMatrix = \
_print_MutableSparseMatrix = \
_print_ImmutableSparseMatrix = \
_print_Matrix = \
_print_DenseMatrix = \
_print_MutableDenseMatrix = \
_print_ImmutableMatrix = \
_print_ImmutableDenseMatrix = \
lambda self, expr: self._print_MatrixBase(expr)
def _indent_codestring(self, codestring):
return '\n'.join([self.tab + line for line in codestring.split('\n')])
def _print_FunctionDefinition(self, fd):
body = '\n'.join((self._print(arg) for arg in fd.body))
return "def {name}({parameters}):\n{body}".format(
name=self._print(fd.name),
parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
body=self._indent_codestring(body)
)
def _print_While(self, whl):
body = '\n'.join((self._print(arg) for arg in whl.body))
return "while {cond}:\n{body}".format(
cond=self._print(whl.condition),
body=self._indent_codestring(body)
)
def _print_Declaration(self, decl):
return '%s = %s' % (
self._print(decl.variable.symbol),
self._print(decl.variable.value)
)
def _print_BreakToken(self, bt):
return 'break'
def _print_Return(self, ret):
arg, = ret.args
return 'return %s' % self._print(arg)
def _print_Raise(self, rs):
arg, = rs.args
return 'raise %s' % self._print(arg)
def _print_RuntimeError_(self, re):
message, = re.args
return "RuntimeError(%s)" % self._print(message)
def _print_Print(self, prnt):
print_args = ', '.join((self._print(arg) for arg in prnt.print_args))
from sympy.codegen.ast import none
if prnt.format_string != none:
print_args = '{} % ({}), end=""'.format(
self._print(prnt.format_string),
print_args
)
if prnt.file != None: # Must be '!= None', cannot be 'is not None'
print_args += ', file=%s' % self._print(prnt.file)
return 'print(%s)' % print_args
def _print_Stream(self, strm):
if str(strm.name) == 'stdout':
return self._module_format('sys.stdout')
elif str(strm.name) == 'stderr':
return self._module_format('sys.stderr')
else:
return self._print(strm.name)
def _print_NoneToken(self, arg):
return 'None'
def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'):
"""Printing helper function for ``Pow``
Notes
=====
This preprocesses the ``sqrt`` as math formatter and prints division
Examples
========
>>> from sympy import sqrt
>>> from sympy.printing.pycode import PythonCodePrinter
>>> from sympy.abc import x
Python code printer automatically looks up ``math.sqrt``.
>>> printer = PythonCodePrinter()
>>> printer._hprint_Pow(sqrt(x), rational=True)
'x**(1/2)'
>>> printer._hprint_Pow(sqrt(x), rational=False)
'math.sqrt(x)'
>>> printer._hprint_Pow(1/sqrt(x), rational=True)
'x**(-1/2)'
>>> printer._hprint_Pow(1/sqrt(x), rational=False)
'1/math.sqrt(x)'
>>> printer._hprint_Pow(1/x, rational=False)
'1/x'
>>> printer._hprint_Pow(1/x, rational=True)
'x**(-1)'
Using sqrt from numpy or mpmath
>>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt')
'numpy.sqrt(x)'
>>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt')
'mpmath.sqrt(x)'
See Also
========
sympy.printing.str.StrPrinter._print_Pow
"""
PREC = precedence(expr)
if expr.exp == S.Half and not rational:
func = self._module_format(sqrt)
arg = self._print(expr.base)
return '{func}({arg})'.format(func=func, arg=arg)
if expr.is_commutative and not rational:
if -expr.exp is S.Half:
func = self._module_format(sqrt)
num = self._print(S.One)
arg = self._print(expr.base)
return f"{num}/{func}({arg})"
if expr.exp is S.NegativeOne:
num = self._print(S.One)
arg = self.parenthesize(expr.base, PREC, strict=False)
return f"{num}/{arg}"
base_str = self.parenthesize(expr.base, PREC, strict=False)
exp_str = self.parenthesize(expr.exp, PREC, strict=False)
return "{}**{}".format(base_str, exp_str)
class ArrayPrinter:
def _arrayify(self, indexed):
from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
try:
return convert_indexed_to_array(indexed)
except Exception:
return indexed
def _get_einsum_string(self, subranks, contraction_indices):
letters = self._get_letter_generator_for_einsum()
contraction_string = ""
counter = 0
d = {j: min(i) for i in contraction_indices for j in i}
indices = []
for rank_arg in subranks:
lindices = []
for i in range(rank_arg):
if counter in d:
lindices.append(d[counter])
else:
lindices.append(counter)
counter += 1
indices.append(lindices)
mapping = {}
letters_free = []
letters_dum = []
for i in indices:
for j in i:
if j not in mapping:
l = next(letters)
mapping[j] = l
else:
l = mapping[j]
contraction_string += l
if j in d:
if l not in letters_dum:
letters_dum.append(l)
else:
letters_free.append(l)
contraction_string += ","
contraction_string = contraction_string[:-1]
return contraction_string, letters_free, letters_dum
def _get_letter_generator_for_einsum(self):
for i in range(97, 123):
yield chr(i)
for i in range(65, 91):
yield chr(i)
raise ValueError("out of letters")
def _print_ArrayTensorProduct(self, expr):
letters = self._get_letter_generator_for_einsum()
contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks])
return '%s("%s", %s)' % (
self._module_format(self._module + "." + self._einsum),
contraction_string,
", ".join([self._print(arg) for arg in expr.args])
)
def _print_ArrayContraction(self, expr):
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
base = expr.expr
contraction_indices = expr.contraction_indices
if isinstance(base, ArrayTensorProduct):
elems = ",".join(["%s" % (self._print(arg)) for arg in base.args])
ranks = base.subranks
else:
elems = self._print(base)
ranks = [len(base.shape)]
contraction_string, letters_free, letters_dum = self._get_einsum_string(ranks, contraction_indices)
if not contraction_indices:
return self._print(base)
if isinstance(base, ArrayTensorProduct):
elems = ",".join(["%s" % (self._print(arg)) for arg in base.args])
else:
elems = self._print(base)
return "%s(\"%s\", %s)" % (
self._module_format(self._module + "." + self._einsum),
"{}->{}".format(contraction_string, "".join(sorted(letters_free))),
elems,
)
def _print_ArrayDiagonal(self, expr):
from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
diagonal_indices = list(expr.diagonal_indices)
if isinstance(expr.expr, ArrayTensorProduct):
subranks = expr.expr.subranks
elems = expr.expr.args
else:
subranks = expr.subranks
elems = [expr.expr]
diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices)
elems = [self._print(i) for i in elems]
return '%s("%s", %s)' % (
self._module_format(self._module + "." + self._einsum),
"{}->{}".format(diagonal_string, "".join(letters_free+letters_dum)),
", ".join(elems)
)
def _print_PermuteDims(self, expr):
return "%s(%s, %s)" % (
self._module_format(self._module + "." + self._transpose),
self._print(expr.expr),
self._print(expr.permutation.array_form),
)
def _print_ArrayAdd(self, expr):
return self._expand_fold_binary_op(self._module + "." + self._add, expr.args)
def _print_OneArray(self, expr):
return "%s((%s,))" % (
self._module_format(self._module+ "." + self._ones),
','.join(map(self._print,expr.args))
)
def _print_ZeroArray(self, expr):
return "%s((%s,))" % (
self._module_format(self._module+ "." + self._zeros),
','.join(map(self._print,expr.args))
)
def _print_Assignment(self, expr):
#XXX: maybe this needs to happen at a higher level e.g. at _print or
#doprint?
lhs = self._print(self._arrayify(expr.lhs))
rhs = self._print(self._arrayify(expr.rhs))
return "%s = %s" % ( lhs, rhs )
def _print_IndexedBase(self, expr):
return self._print_ArraySymbol(expr)
class PythonCodePrinter(AbstractPythonCodePrinter):
def _print_sign(self, e):
return '(0.0 if {e} == 0 else {f}(1, {e}))'.format(
f=self._module_format('math.copysign'), e=self._print(e.args[0]))
def _print_Not(self, expr):
PREC = precedence(expr)
return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
def _print_IndexedBase(self, expr):
return expr.name
def _print_Indexed(self, expr):
base = expr.args[0]
index = expr.args[1:]
return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index]))
def _print_Pow(self, expr, rational=False):
return self._hprint_Pow(expr, rational=rational)
def _print_Rational(self, expr):
return '{}/{}'.format(expr.p, expr.q)
def _print_Half(self, expr):
return self._print_Rational(expr)
def _print_frac(self, expr):
return self._print_Mod(Mod(expr.args[0], 1))
def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if name in self.reserved_words:
if self._settings['error_on_reserved']:
msg = ('This expression includes the symbol "{}" which is a '
'reserved keyword in this language.')
raise ValueError(msg.format(name))
return name + self._settings['reserved_word_suffix']
elif '{' in name: # Remove curly braces from subscripted variables
return name.replace('{', '').replace('}', '')
else:
return name
_print_lowergamma = CodePrinter._print_not_supported
_print_uppergamma = CodePrinter._print_not_supported
_print_fresnelc = CodePrinter._print_not_supported
_print_fresnels = CodePrinter._print_not_supported
for k in PythonCodePrinter._kf:
setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func)
for k in _known_constants_math:
setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const)
def pycode(expr, **settings):
""" Converts an expr to a string of Python code
Parameters
==========
expr : Expr
A SymPy expression.
fully_qualified_modules : bool
Whether or not to write out full module names of functions
(``math.sin`` vs. ``sin``). default: ``True``.
standard : str or None, optional
Only 'python3' (default) is supported.
This parameter may be removed in the future.
Examples
========
>>> from sympy import pycode, tan, Symbol
>>> pycode(tan(Symbol('x')) + 1)
'math.tan(x) + 1'
"""
return PythonCodePrinter(settings).doprint(expr)
_not_in_mpmath = 'log1p log2'.split()
_in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath]
_known_functions_mpmath = dict(_in_mpmath, **{
'beta': 'beta',
'frac': 'frac',
'fresnelc': 'fresnelc',
'fresnels': 'fresnels',
'sign': 'sign',
'loggamma': 'loggamma',
'hyper': 'hyper',
'meijerg': 'meijerg',
'besselj': 'besselj',
'bessely': 'bessely',
'besseli': 'besseli',
'besselk': 'besselk',
})
_known_constants_mpmath = {
'Exp1': 'e',
'Pi': 'pi',
'GoldenRatio': 'phi',
'EulerGamma': 'euler',
'Catalan': 'catalan',
'NaN': 'nan',
'Infinity': 'inf',
'NegativeInfinity': 'ninf'
}
def _unpack_integral_limits(integral_expr):
""" helper function for _print_Integral that
- accepts an Integral expression
- returns a tuple of
- a list variables of integration
- a list of tuples of the upper and lower limits of integration
"""
integration_vars = []
limits = []
for integration_range in integral_expr.limits:
if len(integration_range) == 3:
integration_var, lower_limit, upper_limit = integration_range
else:
raise NotImplementedError("Only definite integrals are supported")
integration_vars.append(integration_var)
limits.append((lower_limit, upper_limit))
return integration_vars, limits
class MpmathPrinter(PythonCodePrinter):
"""
Lambda printer for mpmath which maintains precision for floats
"""
printmethod = "_mpmathcode"
language = "Python with mpmath"
_kf = dict(chain(
_known_functions.items(),
[(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()]
))
_kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()}
def _print_Float(self, e):
# XXX: This does not handle setting mpmath.mp.dps. It is assumed that
# the caller of the lambdified function will have set it to sufficient
# precision to match the Floats in the expression.
# Remove 'mpz' if gmpy is installed.
args = str(tuple(map(int, e._mpf_)))
return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args)
def _print_Rational(self, e):
return "{func}({p})/{func}({q})".format(
func=self._module_format('mpmath.mpf'),
q=self._print(e.q),
p=self._print(e.p)
)
def _print_Half(self, e):
return self._print_Rational(e)
def _print_uppergamma(self, e):
return "{}({}, {}, {})".format(
self._module_format('mpmath.gammainc'),
self._print(e.args[0]),
self._print(e.args[1]),
self._module_format('mpmath.inf'))
def _print_lowergamma(self, e):
return "{}({}, 0, {})".format(
self._module_format('mpmath.gammainc'),
self._print(e.args[0]),
self._print(e.args[1]))
def _print_log2(self, e):
return '{0}({1})/{0}(2)'.format(
self._module_format('mpmath.log'), self._print(e.args[0]))
def _print_log1p(self, e):
return '{}({})'.format(
self._module_format('mpmath.log1p'), self._print(e.args[0]))
def _print_Pow(self, expr, rational=False):
return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt')
def _print_Integral(self, e):
integration_vars, limits = _unpack_integral_limits(e)
return "{}(lambda {}: {}, {})".format(
self._module_format("mpmath.quad"),
", ".join(map(self._print, integration_vars)),
self._print(e.args[0]),
", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits))
for k in MpmathPrinter._kf:
setattr(MpmathPrinter, '_print_%s' % k, _print_known_func)
for k in _known_constants_mpmath:
setattr(MpmathPrinter, '_print_%s' % k, _print_known_const)
class SymPyPrinter(AbstractPythonCodePrinter):
language = "Python with SymPy"
_default_settings = dict(
AbstractPythonCodePrinter._default_settings,
strict=False # any class name will per definition be what we target in SymPyPrinter.
)
def _print_Function(self, expr):
mod = expr.func.__module__ or ''
return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__),
', '.join((self._print(arg) for arg in expr.args)))
def _print_Pow(self, expr, rational=False):
return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')

View File

@ -0,0 +1,92 @@
import keyword as kw
import sympy
from .repr import ReprPrinter
from .str import StrPrinter
# A list of classes that should be printed using StrPrinter
STRPRINT = ("Add", "Infinity", "Integer", "Mul", "NegativeInfinity", "Pow")
class PythonPrinter(ReprPrinter, StrPrinter):
"""A printer which converts an expression into its Python interpretation."""
def __init__(self, settings=None):
super().__init__(settings)
self.symbols = []
self.functions = []
# Create print methods for classes that should use StrPrinter instead
# of ReprPrinter.
for name in STRPRINT:
f_name = "_print_%s" % name
f = getattr(StrPrinter, f_name)
setattr(PythonPrinter, f_name, f)
def _print_Function(self, expr):
func = expr.func.__name__
if not hasattr(sympy, func) and func not in self.functions:
self.functions.append(func)
return StrPrinter._print_Function(self, expr)
# procedure (!) for defining symbols which have be defined in print_python()
def _print_Symbol(self, expr):
symbol = self._str(expr)
if symbol not in self.symbols:
self.symbols.append(symbol)
return StrPrinter._print_Symbol(self, expr)
def _print_module(self, expr):
raise ValueError('Modules in the expression are unacceptable')
def python(expr, **settings):
"""Return Python interpretation of passed expression
(can be passed to the exec() function without any modifications)"""
printer = PythonPrinter(settings)
exprp = printer.doprint(expr)
result = ''
# Returning found symbols and functions
renamings = {}
for symbolname in printer.symbols:
# Remove curly braces from subscripted variables
if '{' in symbolname:
newsymbolname = symbolname.replace('{', '').replace('}', '')
renamings[sympy.Symbol(symbolname)] = newsymbolname
else:
newsymbolname = symbolname
# Escape symbol names that are reserved Python keywords
if kw.iskeyword(newsymbolname):
while True:
newsymbolname += "_"
if (newsymbolname not in printer.symbols and
newsymbolname not in printer.functions):
renamings[sympy.Symbol(
symbolname)] = sympy.Symbol(newsymbolname)
break
result += newsymbolname + ' = Symbol(\'' + symbolname + '\')\n'
for functionname in printer.functions:
newfunctionname = functionname
# Escape function names that are reserved Python keywords
if kw.iskeyword(newfunctionname):
while True:
newfunctionname += "_"
if (newfunctionname not in printer.symbols and
newfunctionname not in printer.functions):
renamings[sympy.Function(
functionname)] = sympy.Function(newfunctionname)
break
result += newfunctionname + ' = Function(\'' + functionname + '\')\n'
if renamings:
exprp = expr.subs(renamings)
result += 'e = ' + printer._str(exprp)
return result
def print_python(expr, **settings):
"""Print output of python() function"""
print(python(expr, **settings))

View File

@ -0,0 +1,405 @@
"""
R code printer
The RCodePrinter converts single SymPy expressions into single R expressions,
using the functions defined in math.h where possible.
"""
from __future__ import annotations
from typing import Any
from sympy.core.numbers import equal_valued
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence, PRECEDENCE
from sympy.sets.fancysets import Range
# dictionary mapping SymPy function to (argument_conditions, C_function).
# Used in RCodePrinter._print_Function(self)
known_functions = {
#"Abs": [(lambda x: not x.is_integer, "fabs")],
"Abs": "abs",
"sin": "sin",
"cos": "cos",
"tan": "tan",
"asin": "asin",
"acos": "acos",
"atan": "atan",
"atan2": "atan2",
"exp": "exp",
"log": "log",
"erf": "erf",
"sinh": "sinh",
"cosh": "cosh",
"tanh": "tanh",
"asinh": "asinh",
"acosh": "acosh",
"atanh": "atanh",
"floor": "floor",
"ceiling": "ceiling",
"sign": "sign",
"Max": "max",
"Min": "min",
"factorial": "factorial",
"gamma": "gamma",
"digamma": "digamma",
"trigamma": "trigamma",
"beta": "beta",
"sqrt": "sqrt", # To enable automatic rewrite
}
# These are the core reserved words in the R language. Taken from:
# https://cran.r-project.org/doc/manuals/r-release/R-lang.html#Reserved-words
reserved_words = ['if',
'else',
'repeat',
'while',
'function',
'for',
'in',
'next',
'break',
'TRUE',
'FALSE',
'NULL',
'Inf',
'NaN',
'NA',
'NA_integer_',
'NA_real_',
'NA_complex_',
'NA_character_',
'volatile']
class RCodePrinter(CodePrinter):
"""A printer to convert SymPy expressions to strings of R code"""
printmethod = "_rcode"
language = "R"
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 15,
'user_functions': {},
'contract': True,
'dereference': set(),
})
_operators = {
'and': '&',
'or': '|',
'not': '!',
}
_relationals: dict[str, str] = {}
def __init__(self, settings={}):
CodePrinter.__init__(self, settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
self._dereference = set(settings.get('dereference', []))
self.reserved_words = set(reserved_words)
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
return "%s;" % codestring
def _get_comment(self, text):
return "// {}".format(text)
def _declare_number_const(self, name, value):
return "{} = {};".format(name, value)
def _format_code(self, lines):
return self.indent_code(lines)
def _traverse_matrix_indices(self, mat):
rows, cols = mat.shape
return ((i, j) for i in range(rows) for j in range(cols))
def _get_loop_opening_ending(self, indices):
"""Returns a tuple (open_lines, close_lines) containing lists of codelines
"""
open_lines = []
close_lines = []
loopstart = "for (%(var)s in %(start)s:%(end)s){"
for i in indices:
# R arrays start at 1 and end at dimension
open_lines.append(loopstart % {
'var': self._print(i.label),
'start': self._print(i.lower+1),
'end': self._print(i.upper + 1)})
close_lines.append("}")
return open_lines, close_lines
def _print_Pow(self, expr):
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
if equal_valued(expr.exp, -1):
return '1.0/%s' % (self.parenthesize(expr.base, PREC))
elif equal_valued(expr.exp, 0.5):
return 'sqrt(%s)' % self._print(expr.base)
else:
return '%s^%s' % (self.parenthesize(expr.base, PREC),
self.parenthesize(expr.exp, PREC))
def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)
return '%d.0/%d.0' % (p, q)
def _print_Indexed(self, expr):
inds = [ self._print(i) for i in expr.indices ]
return "%s[%s]" % (self._print(expr.base.label), ", ".join(inds))
def _print_Idx(self, expr):
return self._print(expr.label)
def _print_Exp1(self, expr):
return "exp(1)"
def _print_Pi(self, expr):
return 'pi'
def _print_Infinity(self, expr):
return 'Inf'
def _print_NegativeInfinity(self, expr):
return '-Inf'
def _print_Assignment(self, expr):
from sympy.codegen.ast import Assignment
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.tensor.indexed import IndexedBase
lhs = expr.lhs
rhs = expr.rhs
# We special case assignments that take multiple lines
#if isinstance(expr.rhs, Piecewise):
# from sympy.functions.elementary.piecewise import Piecewise
# # Here we modify Piecewise so each expression is now
# # an Assignment, and then continue on the print.
# expressions = []
# conditions = []
# for (e, c) in rhs.args:
# expressions.append(Assignment(lhs, e))
# conditions.append(c)
# temp = Piecewise(*zip(expressions, conditions))
# return self._print(temp)
#elif isinstance(lhs, MatrixSymbol):
if isinstance(lhs, MatrixSymbol):
# Here we form an Assignment for each element in the array,
# printing each one.
lines = []
for (i, j) in self._traverse_matrix_indices(lhs):
temp = Assignment(lhs[i, j], rhs[i, j])
code0 = self._print(temp)
lines.append(code0)
return "\n".join(lines)
elif self._settings["contract"] and (lhs.has(IndexedBase) or
rhs.has(IndexedBase)):
# Here we check if there is looping to be done, and if so
# print the required loops.
return self._doprint_loops(rhs, lhs)
else:
lhs_code = self._print(lhs)
rhs_code = self._print(rhs)
return self._get_statement("%s = %s" % (lhs_code, rhs_code))
def _print_Piecewise(self, expr):
# This method is called only for inline if constructs
# Top level piecewise is handled in doprint()
if expr.args[-1].cond == True:
last_line = "%s" % self._print(expr.args[-1].expr)
else:
last_line = "ifelse(%s,%s,NA)" % (self._print(expr.args[-1].cond), self._print(expr.args[-1].expr))
code=last_line
for e, c in reversed(expr.args[:-1]):
code= "ifelse(%s,%s," % (self._print(c), self._print(e))+code+")"
return(code)
def _print_ITE(self, expr):
from sympy.functions import Piecewise
return self._print(expr.rewrite(Piecewise))
def _print_MatrixElement(self, expr):
return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"],
strict=True), expr.j + expr.i*expr.parent.shape[1])
def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if expr in self._dereference:
return '(*{})'.format(name)
else:
return name
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_AugmentedAssignment(self, expr):
lhs_code = self._print(expr.lhs)
op = expr.op
rhs_code = self._print(expr.rhs)
return "{} {} {};".format(lhs_code, op, rhs_code)
def _print_For(self, expr):
target = self._print(expr.target)
if isinstance(expr.iterable, Range):
start, stop, step = expr.iterable.args
else:
raise NotImplementedError("Only iterable currently supported is Range")
body = self._print(expr.body)
return 'for({target} in seq(from={start}, to={stop}, by={step}){{\n{body}\n}}'.format(target=target, start=start,
stop=stop-1, step=step, body=body)
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_token = ('{', '(', '{\n', '(\n')
dec_token = ('}', ')')
code = [ line.lstrip(' \t') for line in code ]
increase = [ int(any(map(line.endswith, inc_token))) for line in code ]
decrease = [ int(any(map(line.startswith, dec_token)))
for line in code ]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def rcode(expr, assign_to=None, **settings):
"""Converts an expr to a string of r code
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
line-wrapping, or for expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=15].
user_functions : dict, optional
A dictionary where the keys are string representations of either
``FunctionClass`` or ``UndefinedFunction`` instances and the values
are their desired R string representations. Alternatively, the
dictionary value can be a list of tuples i.e. [(argument_test,
rfunction_string)] or [(argument_test, rfunction_formater)]. See below
for examples.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
Examples
========
>>> from sympy import rcode, symbols, Rational, sin, ceiling, Abs, Function
>>> x, tau = symbols("x, tau")
>>> rcode((2*tau)**Rational(7, 2))
'8*sqrt(2)*tau^(7.0/2.0)'
>>> rcode(sin(x), assign_to="s")
's = sin(x);'
Simple custom printing can be defined for certain types by passing a
dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
Alternatively, the dictionary value can be a list of tuples i.e.
[(argument_test, cfunction_string)].
>>> custom_functions = {
... "ceiling": "CEIL",
... "Abs": [(lambda x: not x.is_integer, "fabs"),
... (lambda x: x.is_integer, "ABS")],
... "func": "f"
... }
>>> func = Function('func')
>>> rcode(func(Abs(x) + ceiling(x)), user_functions=custom_functions)
'f(fabs(x) + CEIL(x))'
or if the R-function takes a subset of the original arguments:
>>> rcode(2**x + 3**x, user_functions={'Pow': [
... (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e),
... (lambda b, e: b != 2, 'pow')]})
'exp2(x) + pow(3, x)'
``Piecewise`` expressions are converted into conditionals. If an
``assign_to`` variable is provided an if statement is created, otherwise
the ternary operator is used. Note that if the ``Piecewise`` lacks a
default term, represented by ``(expr, True)`` then an error will be thrown.
This is to prevent generating an expression that may not evaluate to
anything.
>>> from sympy import Piecewise
>>> expr = Piecewise((x + 1, x > 0), (x, True))
>>> print(rcode(expr, assign_to=tau))
tau = ifelse(x > 0,x + 1,x);
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> rcode(e.rhs, assign_to=e.lhs, contract=False)
'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
must be provided to ``assign_to``. Note that any expression that can be
generated normally can also exist inside a Matrix:
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
>>> A = MatrixSymbol('A', 3, 1)
>>> print(rcode(mat, A))
A[0] = x^2;
A[1] = ifelse(x > 0,x + 1,x);
A[2] = sin(x);
"""
return RCodePrinter(settings).doprint(expr, assign_to)
def print_rcode(expr, **settings):
"""Prints R representation of the given expression."""
print(rcode(expr, **settings))

View File

@ -0,0 +1,338 @@
"""
A Printer for generating executable code.
The most important function here is srepr that returns a string so that the
relation eval(srepr(expr))=expr holds in an appropriate environment.
"""
from __future__ import annotations
from typing import Any
from sympy.core.function import AppliedUndef
from sympy.core.mul import Mul
from mpmath.libmp import repr_dps, to_str as mlib_to_str
from .printer import Printer, print_function
class ReprPrinter(Printer):
printmethod = "_sympyrepr"
_default_settings: dict[str, Any] = {
"order": None,
"perm_cyclic" : True,
}
def reprify(self, args, sep):
"""
Prints each item in `args` and joins them with `sep`.
"""
return sep.join([self.doprint(item) for item in args])
def emptyPrinter(self, expr):
"""
The fallback printer.
"""
if isinstance(expr, str):
return expr
elif hasattr(expr, "__srepr__"):
return expr.__srepr__()
elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"):
l = []
for o in expr.args:
l.append(self._print(o))
return expr.__class__.__name__ + '(%s)' % ', '.join(l)
elif hasattr(expr, "__module__") and hasattr(expr, "__name__"):
return "<'%s.%s'>" % (expr.__module__, expr.__name__)
else:
return str(expr)
def _print_Add(self, expr, order=None):
args = self._as_ordered_terms(expr, order=order)
args = map(self._print, args)
clsname = type(expr).__name__
return clsname + "(%s)" % ", ".join(args)
def _print_Cycle(self, expr):
return expr.__repr__()
def _print_Permutation(self, expr):
from sympy.combinatorics.permutations import Permutation, Cycle
from sympy.utilities.exceptions import sympy_deprecation_warning
perm_cyclic = Permutation.print_cyclic
if perm_cyclic is not None:
sympy_deprecation_warning(
f"""
Setting Permutation.print_cyclic is deprecated. Instead use
init_printing(perm_cyclic={perm_cyclic}).
""",
deprecated_since_version="1.6",
active_deprecations_target="deprecated-permutation-print_cyclic",
stacklevel=7,
)
else:
perm_cyclic = self._settings.get("perm_cyclic", True)
if perm_cyclic:
if not expr.size:
return 'Permutation()'
# before taking Cycle notation, see if the last element is
# a singleton and move it to the head of the string
s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):]
last = s.rfind('(')
if not last == 0 and ',' not in s[last:]:
s = s[last:] + s[:last]
return 'Permutation%s' %s
else:
s = expr.support()
if not s:
if expr.size < 5:
return 'Permutation(%s)' % str(expr.array_form)
return 'Permutation([], size=%s)' % expr.size
trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size
use = full = str(expr.array_form)
if len(trim) < len(full):
use = trim
return 'Permutation(%s)' % use
def _print_Function(self, expr):
r = self._print(expr.func)
r += '(%s)' % ', '.join([self._print(a) for a in expr.args])
return r
def _print_Heaviside(self, expr):
# Same as _print_Function but uses pargs to suppress default value for
# 2nd arg.
r = self._print(expr.func)
r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs])
return r
def _print_FunctionClass(self, expr):
if issubclass(expr, AppliedUndef):
return 'Function(%r)' % (expr.__name__)
else:
return expr.__name__
def _print_Half(self, expr):
return 'Rational(1, 2)'
def _print_RationalConstant(self, expr):
return str(expr)
def _print_AtomicExpr(self, expr):
return str(expr)
def _print_NumberSymbol(self, expr):
return str(expr)
def _print_Integer(self, expr):
return 'Integer(%i)' % expr.p
def _print_Complexes(self, expr):
return 'Complexes'
def _print_Integers(self, expr):
return 'Integers'
def _print_Naturals(self, expr):
return 'Naturals'
def _print_Naturals0(self, expr):
return 'Naturals0'
def _print_Rationals(self, expr):
return 'Rationals'
def _print_Reals(self, expr):
return 'Reals'
def _print_EmptySet(self, expr):
return 'EmptySet'
def _print_UniversalSet(self, expr):
return 'UniversalSet'
def _print_EmptySequence(self, expr):
return 'EmptySequence'
def _print_list(self, expr):
return "[%s]" % self.reprify(expr, ", ")
def _print_dict(self, expr):
sep = ", "
dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()]
return "{%s}" % sep.join(dict_kvs)
def _print_set(self, expr):
if not expr:
return "set()"
return "{%s}" % self.reprify(expr, ", ")
def _print_MatrixBase(self, expr):
# special case for some empty matrices
if (expr.rows == 0) ^ (expr.cols == 0):
return '%s(%s, %s, %s)' % (expr.__class__.__name__,
self._print(expr.rows),
self._print(expr.cols),
self._print([]))
l = []
for i in range(expr.rows):
l.append([])
for j in range(expr.cols):
l[-1].append(expr[i, j])
return '%s(%s)' % (expr.__class__.__name__, self._print(l))
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
def _print_NaN(self, expr):
return "nan"
def _print_Mul(self, expr, order=None):
if self.order not in ('old', 'none'):
args = expr.as_ordered_factors()
else:
# use make_args in case expr was something like -x -> x
args = Mul.make_args(expr)
args = map(self._print, args)
clsname = type(expr).__name__
return clsname + "(%s)" % ", ".join(args)
def _print_Rational(self, expr):
return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q))
def _print_PythonRational(self, expr):
return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q)
def _print_Fraction(self, expr):
return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator))
def _print_Float(self, expr):
r = mlib_to_str(expr._mpf_, repr_dps(expr._prec))
return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec)
def _print_Sum2(self, expr):
return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i),
self._print(expr.a), self._print(expr.b))
def _print_Str(self, s):
return "%s(%s)" % (s.__class__.__name__, self._print(s.name))
def _print_Symbol(self, expr):
d = expr._assumptions_orig
# print the dummy_index like it was an assumption
if expr.is_Dummy:
d['dummy_index'] = expr.dummy_index
if d == {}:
return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name))
else:
attr = ['%s=%s' % (k, v) for k, v in d.items()]
return "%s(%s, %s)" % (expr.__class__.__name__,
self._print(expr.name), ', '.join(attr))
def _print_CoordinateSymbol(self, expr):
d = expr._assumptions.generator
if d == {}:
return "%s(%s, %s)" % (
expr.__class__.__name__,
self._print(expr.coord_sys),
self._print(expr.index)
)
else:
attr = ['%s=%s' % (k, v) for k, v in d.items()]
return "%s(%s, %s, %s)" % (
expr.__class__.__name__,
self._print(expr.coord_sys),
self._print(expr.index),
', '.join(attr)
)
def _print_Predicate(self, expr):
return "Q.%s" % expr.name
def _print_AppliedPredicate(self, expr):
# will be changed to just expr.args when args overriding is removed
args = expr._args
return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", "))
def _print_str(self, expr):
return repr(expr)
def _print_tuple(self, expr):
if len(expr) == 1:
return "(%s,)" % self._print(expr[0])
else:
return "(%s)" % self.reprify(expr, ", ")
def _print_WildFunction(self, expr):
return "%s('%s')" % (expr.__class__.__name__, expr.name)
def _print_AlgebraicNumber(self, expr):
return "%s(%s, %s)" % (expr.__class__.__name__,
self._print(expr.root), self._print(expr.coeffs()))
def _print_PolyRing(self, ring):
return "%s(%s, %s, %s)" % (ring.__class__.__name__,
self._print(ring.symbols), self._print(ring.domain), self._print(ring.order))
def _print_FracField(self, field):
return "%s(%s, %s, %s)" % (field.__class__.__name__,
self._print(field.symbols), self._print(field.domain), self._print(field.order))
def _print_PolyElement(self, poly):
terms = list(poly.terms())
terms.sort(key=poly.ring.order, reverse=True)
return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms))
def _print_FracElement(self, frac):
numer_terms = list(frac.numer.terms())
numer_terms.sort(key=frac.field.order, reverse=True)
denom_terms = list(frac.denom.terms())
denom_terms.sort(key=frac.field.order, reverse=True)
numer = self._print(numer_terms)
denom = self._print(denom_terms)
return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom)
def _print_FractionField(self, domain):
cls = domain.__class__.__name__
field = self._print(domain.field)
return "%s(%s)" % (cls, field)
def _print_PolynomialRingBase(self, ring):
cls = ring.__class__.__name__
dom = self._print(ring.domain)
gens = ', '.join(map(self._print, ring.gens))
order = str(ring.order)
if order != ring.default_order:
orderstr = ", order=" + order
else:
orderstr = ""
return "%s(%s, %s%s)" % (cls, dom, gens, orderstr)
def _print_DMP(self, p):
cls = p.__class__.__name__
rep = self._print(p.to_list())
dom = self._print(p.dom)
return "%s(%s, %s)" % (cls, rep, dom)
def _print_MonogenicFiniteExtension(self, ext):
# The expanded tree shown by srepr(ext.modulus)
# is not practical.
return "FiniteExtension(%s)" % str(ext.modulus)
def _print_ExtensionElement(self, f):
rep = self._print(f.rep)
ext = self._print(f.ext)
return "ExtElem(%s, %s)" % (rep, ext)
@print_function(ReprPrinter)
def srepr(expr, **settings):
"""return expr in repr form"""
return ReprPrinter(settings).doprint(expr)

View File

@ -0,0 +1,619 @@
"""
Rust code printer
The `RustCodePrinter` converts SymPy expressions into Rust expressions.
A complete code generator, which uses `rust_code` extensively, can be found
in `sympy.utilities.codegen`. The `codegen` module can be used to generate
complete source code files.
"""
# Possible Improvement
#
# * make sure we follow Rust Style Guidelines_
# * make use of pattern matching
# * better support for reference
# * generate generic code and use trait to make sure they have specific methods
# * use crates_ to get more math support
# - num_
# + BigInt_, BigUint_
# + Complex_
# + Rational64_, Rational32_, BigRational_
#
# .. _crates: https://crates.io/
# .. _Guidelines: https://github.com/rust-lang/rust/tree/master/src/doc/style
# .. _num: http://rust-num.github.io/num/num/
# .. _BigInt: http://rust-num.github.io/num/num/bigint/struct.BigInt.html
# .. _BigUint: http://rust-num.github.io/num/num/bigint/struct.BigUint.html
# .. _Complex: http://rust-num.github.io/num/num/complex/struct.Complex.html
# .. _Rational32: http://rust-num.github.io/num/num/rational/type.Rational32.html
# .. _Rational64: http://rust-num.github.io/num/num/rational/type.Rational64.html
# .. _BigRational: http://rust-num.github.io/num/num/rational/type.BigRational.html
from __future__ import annotations
from typing import Any
from sympy.core import S, Rational, Float, Lambda
from sympy.core.numbers import equal_valued
from sympy.printing.codeprinter import CodePrinter
# Rust's methods for integer and float can be found at here :
#
# * `Rust - Primitive Type f64 <https://doc.rust-lang.org/std/primitive.f64.html>`_
# * `Rust - Primitive Type i64 <https://doc.rust-lang.org/std/primitive.i64.html>`_
#
# Function Style :
#
# 1. args[0].func(args[1:]), method with arguments
# 2. args[0].func(), method without arguments
# 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp())
# 4. func(args), function with arguments
# dictionary mapping SymPy function to (argument_conditions, Rust_function).
# Used in RustCodePrinter._print_Function(self)
# f64 method in Rust
known_functions = {
# "": "is_nan",
# "": "is_infinite",
# "": "is_finite",
# "": "is_normal",
# "": "classify",
"floor": "floor",
"ceiling": "ceil",
# "": "round",
# "": "trunc",
# "": "fract",
"Abs": "abs",
"sign": "signum",
# "": "is_sign_positive",
# "": "is_sign_negative",
# "": "mul_add",
"Pow": [(lambda base, exp: equal_valued(exp, -1), "recip", 2), # 1.0/x
(lambda base, exp: equal_valued(exp, 0.5), "sqrt", 2), # x ** 0.5
(lambda base, exp: equal_valued(exp, -0.5), "sqrt().recip", 2), # 1/(x ** 0.5)
(lambda base, exp: exp == Rational(1, 3), "cbrt", 2), # x ** (1/3)
(lambda base, exp: equal_valued(base, 2), "exp2", 3), # 2 ** x
(lambda base, exp: exp.is_integer, "powi", 1), # x ** y, for i32
(lambda base, exp: not exp.is_integer, "powf", 1)], # x ** y, for f64
"exp": [(lambda exp: True, "exp", 2)], # e ** x
"log": "ln",
# "": "log", # number.log(base)
# "": "log2",
# "": "log10",
# "": "to_degrees",
# "": "to_radians",
"Max": "max",
"Min": "min",
# "": "hypot", # (x**2 + y**2) ** 0.5
"sin": "sin",
"cos": "cos",
"tan": "tan",
"asin": "asin",
"acos": "acos",
"atan": "atan",
"atan2": "atan2",
# "": "sin_cos",
# "": "exp_m1", # e ** x - 1
# "": "ln_1p", # ln(1 + x)
"sinh": "sinh",
"cosh": "cosh",
"tanh": "tanh",
"asinh": "asinh",
"acosh": "acosh",
"atanh": "atanh",
"sqrt": "sqrt", # To enable automatic rewrites
}
# i64 method in Rust
# known_functions_i64 = {
# "": "min_value",
# "": "max_value",
# "": "from_str_radix",
# "": "count_ones",
# "": "count_zeros",
# "": "leading_zeros",
# "": "trainling_zeros",
# "": "rotate_left",
# "": "rotate_right",
# "": "swap_bytes",
# "": "from_be",
# "": "from_le",
# "": "to_be", # to big endian
# "": "to_le", # to little endian
# "": "checked_add",
# "": "checked_sub",
# "": "checked_mul",
# "": "checked_div",
# "": "checked_rem",
# "": "checked_neg",
# "": "checked_shl",
# "": "checked_shr",
# "": "checked_abs",
# "": "saturating_add",
# "": "saturating_sub",
# "": "saturating_mul",
# "": "wrapping_add",
# "": "wrapping_sub",
# "": "wrapping_mul",
# "": "wrapping_div",
# "": "wrapping_rem",
# "": "wrapping_neg",
# "": "wrapping_shl",
# "": "wrapping_shr",
# "": "wrapping_abs",
# "": "overflowing_add",
# "": "overflowing_sub",
# "": "overflowing_mul",
# "": "overflowing_div",
# "": "overflowing_rem",
# "": "overflowing_neg",
# "": "overflowing_shl",
# "": "overflowing_shr",
# "": "overflowing_abs",
# "Pow": "pow",
# "Abs": "abs",
# "sign": "signum",
# "": "is_positive",
# "": "is_negnative",
# }
# These are the core reserved words in the Rust language. Taken from:
# http://doc.rust-lang.org/grammar.html#keywords
reserved_words = ['abstract',
'alignof',
'as',
'become',
'box',
'break',
'const',
'continue',
'crate',
'do',
'else',
'enum',
'extern',
'false',
'final',
'fn',
'for',
'if',
'impl',
'in',
'let',
'loop',
'macro',
'match',
'mod',
'move',
'mut',
'offsetof',
'override',
'priv',
'proc',
'pub',
'pure',
'ref',
'return',
'Self',
'self',
'sizeof',
'static',
'struct',
'super',
'trait',
'true',
'type',
'typeof',
'unsafe',
'unsized',
'use',
'virtual',
'where',
'while',
'yield']
class RustCodePrinter(CodePrinter):
"""A printer to convert SymPy expressions to strings of Rust code"""
printmethod = "_rust_code"
language = "Rust"
_default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
'precision': 17,
'user_functions': {},
'contract': True,
'dereference': set(),
})
def __init__(self, settings={}):
CodePrinter.__init__(self, settings)
self.known_functions = dict(known_functions)
userfuncs = settings.get('user_functions', {})
self.known_functions.update(userfuncs)
self._dereference = set(settings.get('dereference', []))
self.reserved_words = set(reserved_words)
def _rate_index_position(self, p):
return p*5
def _get_statement(self, codestring):
return "%s;" % codestring
def _get_comment(self, text):
return "// %s" % text
def _declare_number_const(self, name, value):
return "const %s: f64 = %s;" % (name, value)
def _format_code(self, lines):
return self.indent_code(lines)
def _traverse_matrix_indices(self, mat):
rows, cols = mat.shape
return ((i, j) for i in range(rows) for j in range(cols))
def _get_loop_opening_ending(self, indices):
open_lines = []
close_lines = []
loopstart = "for %(var)s in %(start)s..%(end)s {"
for i in indices:
# Rust arrays start at 0 and end at dimension-1
open_lines.append(loopstart % {
'var': self._print(i),
'start': self._print(i.lower),
'end': self._print(i.upper + 1)})
close_lines.append("}")
return open_lines, close_lines
def _print_caller_var(self, expr):
if len(expr.args) > 1:
# for something like `sin(x + y + z)`,
# make sure we can get '(x + y + z).sin()'
# instead of 'x + y + z.sin()'
return '(' + self._print(expr) + ')'
elif expr.is_number:
return self._print(expr, _type=True)
else:
return self._print(expr)
def _print_Function(self, expr):
"""
basic function for printing `Function`
Function Style :
1. args[0].func(args[1:]), method with arguments
2. args[0].func(), method without arguments
3. args[1].func(), method without arguments (e.g. (e, x) => x.exp())
4. func(args), function with arguments
"""
if expr.func.__name__ in self.known_functions:
cond_func = self.known_functions[expr.func.__name__]
func = None
style = 1
if isinstance(cond_func, str):
func = cond_func
else:
for cond, func, style in cond_func:
if cond(*expr.args):
break
if func is not None:
if style == 1:
ret = "%(var)s.%(method)s(%(args)s)" % {
'var': self._print_caller_var(expr.args[0]),
'method': func,
'args': self.stringify(expr.args[1:], ", ") if len(expr.args) > 1 else ''
}
elif style == 2:
ret = "%(var)s.%(method)s()" % {
'var': self._print_caller_var(expr.args[0]),
'method': func,
}
elif style == 3:
ret = "%(var)s.%(method)s()" % {
'var': self._print_caller_var(expr.args[1]),
'method': func,
}
else:
ret = "%(func)s(%(args)s)" % {
'func': func,
'args': self.stringify(expr.args, ", "),
}
return ret
elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
# inlined function
return self._print(expr._imp_(*expr.args))
elif expr.func.__name__ in self._rewriteable_functions:
# Simple rewrite to supported function possible
target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
return self._print(expr.rewrite(target_f))
else:
return self._print_not_supported(expr)
def _print_Pow(self, expr):
if expr.base.is_integer and not expr.exp.is_integer:
expr = type(expr)(Float(expr.base), expr.exp)
return self._print(expr)
return self._print_Function(expr)
def _print_Float(self, expr, _type=False):
ret = super()._print_Float(expr)
if _type:
return ret + '_f64'
else:
return ret
def _print_Integer(self, expr, _type=False):
ret = super()._print_Integer(expr)
if _type:
return ret + '_i32'
else:
return ret
def _print_Rational(self, expr):
p, q = int(expr.p), int(expr.q)
return '%d_f64/%d.0' % (p, q)
def _print_Relational(self, expr):
lhs_code = self._print(expr.lhs)
rhs_code = self._print(expr.rhs)
op = expr.rel_op
return "{} {} {}".format(lhs_code, op, rhs_code)
def _print_Indexed(self, expr):
# calculate index for 1d array
dims = expr.shape
elem = S.Zero
offset = S.One
for i in reversed(range(expr.rank)):
elem += expr.indices[i]*offset
offset *= dims[i]
return "%s[%s]" % (self._print(expr.base.label), self._print(elem))
def _print_Idx(self, expr):
return expr.label.name
def _print_Dummy(self, expr):
return expr.name
def _print_Exp1(self, expr, _type=False):
return "E"
def _print_Pi(self, expr, _type=False):
return 'PI'
def _print_Infinity(self, expr, _type=False):
return 'INFINITY'
def _print_NegativeInfinity(self, expr, _type=False):
return 'NEG_INFINITY'
def _print_BooleanTrue(self, expr, _type=False):
return "true"
def _print_BooleanFalse(self, expr, _type=False):
return "false"
def _print_bool(self, expr, _type=False):
return str(expr).lower()
def _print_NaN(self, expr, _type=False):
return "NAN"
def _print_Piecewise(self, expr):
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
# function may not return a result.
raise ValueError("All Piecewise expressions must contain an "
"(expr, True) statement to be used as a default "
"condition. Without one, the generated "
"expression may not evaluate to anything under "
"some condition.")
lines = []
for i, (e, c) in enumerate(expr.args):
if i == 0:
lines.append("if (%s) {" % self._print(c))
elif i == len(expr.args) - 1 and c == True:
lines[-1] += " else {"
else:
lines[-1] += " else if (%s) {" % self._print(c)
code0 = self._print(e)
lines.append(code0)
lines.append("}")
if self._settings['inline']:
return " ".join(lines)
else:
return "\n".join(lines)
def _print_ITE(self, expr):
from sympy.functions import Piecewise
return self._print(expr.rewrite(Piecewise, deep=False))
def _print_MatrixBase(self, A):
if A.cols == 1:
return "[%s]" % ", ".join(self._print(a) for a in A)
else:
raise ValueError("Full Matrix Support in Rust need Crates (https://crates.io/keywords/matrix).")
def _print_SparseRepMatrix(self, mat):
# do not allow sparse matrices to be made dense
return self._print_not_supported(mat)
def _print_MatrixElement(self, expr):
return "%s[%s]" % (expr.parent,
expr.j + expr.i*expr.parent.shape[1])
def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if expr in self._dereference:
return '(*%s)' % name
else:
return name
def _print_Assignment(self, expr):
from sympy.tensor.indexed import IndexedBase
lhs = expr.lhs
rhs = expr.rhs
if self._settings["contract"] and (lhs.has(IndexedBase) or
rhs.has(IndexedBase)):
# Here we check if there is looping to be done, and if so
# print the required loops.
return self._doprint_loops(rhs, lhs)
else:
lhs_code = self._print(lhs)
rhs_code = self._print(rhs)
return self._get_statement("%s = %s" % (lhs_code, rhs_code))
def indent_code(self, code):
"""Accepts a string of code or a list of code lines"""
if isinstance(code, str):
code_lines = self.indent_code(code.splitlines(True))
return ''.join(code_lines)
tab = " "
inc_token = ('{', '(', '{\n', '(\n')
dec_token = ('}', ')')
code = [ line.lstrip(' \t') for line in code ]
increase = [ int(any(map(line.endswith, inc_token))) for line in code ]
decrease = [ int(any(map(line.startswith, dec_token)))
for line in code ]
pretty = []
level = 0
for n, line in enumerate(code):
if line in ('', '\n'):
pretty.append(line)
continue
level -= decrease[n]
pretty.append("%s%s" % (tab*level, line))
level += increase[n]
return pretty
def rust_code(expr, assign_to=None, **settings):
"""Converts an expr to a string of Rust code
Parameters
==========
expr : Expr
A SymPy expression to be converted.
assign_to : optional
When given, the argument is used as the name of the variable to which
the expression is assigned. Can be a string, ``Symbol``,
``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
line-wrapping, or for expressions that generate multi-line statements.
precision : integer, optional
The precision for numbers such as pi [default=15].
user_functions : dict, optional
A dictionary where the keys are string representations of either
``FunctionClass`` or ``UndefinedFunction`` instances and the values
are their desired C string representations. Alternatively, the
dictionary value can be a list of tuples i.e. [(argument_test,
cfunction_string)]. See below for examples.
dereference : iterable, optional
An iterable of symbols that should be dereferenced in the printed code
expression. These would be values passed by address to the function.
For example, if ``dereference=[a]``, the resulting code would print
``(*a)`` instead of ``a``.
human : bool, optional
If True, the result is a single string that may contain some constant
declarations for the number symbols. If False, the same information is
returned in a tuple of (symbols_to_declare, not_supported_functions,
code_text). [default=True].
contract: bool, optional
If True, ``Indexed`` instances are assumed to obey tensor contraction
rules and the corresponding nested loops over indices are generated.
Setting contract=False will not generate loops, instead the user is
responsible to provide values for the indices in the code.
[default=True].
Examples
========
>>> from sympy import rust_code, symbols, Rational, sin, ceiling, Abs, Function
>>> x, tau = symbols("x, tau")
>>> rust_code((2*tau)**Rational(7, 2))
'8*1.4142135623731*tau.powf(7_f64/2.0)'
>>> rust_code(sin(x), assign_to="s")
's = x.sin();'
Simple custom printing can be defined for certain types by passing a
dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
Alternatively, the dictionary value can be a list of tuples i.e.
[(argument_test, cfunction_string)].
>>> custom_functions = {
... "ceiling": "CEIL",
... "Abs": [(lambda x: not x.is_integer, "fabs", 4),
... (lambda x: x.is_integer, "ABS", 4)],
... "func": "f"
... }
>>> func = Function('func')
>>> rust_code(func(Abs(x) + ceiling(x)), user_functions=custom_functions)
'(fabs(x) + x.CEIL()).f()'
``Piecewise`` expressions are converted into conditionals. If an
``assign_to`` variable is provided an if statement is created, otherwise
the ternary operator is used. Note that if the ``Piecewise`` lacks a
default term, represented by ``(expr, True)`` then an error will be thrown.
This is to prevent generating an expression that may not evaluate to
anything.
>>> from sympy import Piecewise
>>> expr = Piecewise((x + 1, x > 0), (x, True))
>>> print(rust_code(expr, tau))
tau = if (x > 0) {
x + 1
} else {
x
};
Support for loops is provided through ``Indexed`` types. With
``contract=True`` these expressions will be turned into loops, whereas
``contract=False`` will just print the assignment expression that should be
looped over:
>>> from sympy import Eq, IndexedBase, Idx
>>> len_y = 5
>>> y = IndexedBase('y', shape=(len_y,))
>>> t = IndexedBase('t', shape=(len_y,))
>>> Dy = IndexedBase('Dy', shape=(len_y-1,))
>>> i = Idx('i', len_y-1)
>>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
>>> rust_code(e.rhs, assign_to=e.lhs, contract=False)
'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
must be provided to ``assign_to``. Note that any expression that can be
generated normally can also exist inside a Matrix:
>>> from sympy import Matrix, MatrixSymbol
>>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
>>> A = MatrixSymbol('A', 3, 1)
>>> print(rust_code(mat, A))
A = [x.powi(2), if (x > 0) {
x + 1
} else {
x
}, x.sin()];
"""
return RustCodePrinter(settings).doprint(expr, assign_to)
def print_rust_code(expr, **settings):
"""Prints Rust representation of the given expression."""
print(rust_code(expr, **settings))

View File

@ -0,0 +1,583 @@
import typing
import sympy
from sympy.core import Add, Mul
from sympy.core import Symbol, Expr, Float, Rational, Integer, Basic
from sympy.core.function import UndefinedFunction, Function
from sympy.core.relational import Relational, Unequality, Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import exp, log, Pow
from sympy.functions.elementary.hyperbolic import sinh, cosh, tanh
from sympy.functions.elementary.miscellaneous import Min, Max
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import sin, cos, tan, asin, acos, atan, atan2
from sympy.logic.boolalg import And, Or, Xor, Implies, Boolean
from sympy.logic.boolalg import BooleanTrue, BooleanFalse, BooleanFunction, Not, ITE
from sympy.printing.printer import Printer
from sympy.sets import Interval
from mpmath.libmp.libmpf import prec_to_dps, to_str as mlib_to_str
from sympy.assumptions.assume import AppliedPredicate
from sympy.assumptions.relation.binrel import AppliedBinaryRelation
from sympy.assumptions.ask import Q
from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate
class SMTLibPrinter(Printer):
printmethod = "_smtlib"
# based on dReal, an automated reasoning tool for solving problems that can be encoded as first-order logic formulas over the real numbers.
# dReal's special strength is in handling problems that involve a wide range of nonlinear real functions.
_default_settings: dict = {
'precision': None,
'known_types': {
bool: 'Bool',
int: 'Int',
float: 'Real'
},
'known_constants': {
# pi: 'MY_VARIABLE_PI_DECLARED_ELSEWHERE',
},
'known_functions': {
Add: '+',
Mul: '*',
Equality: '=',
LessThan: '<=',
GreaterThan: '>=',
StrictLessThan: '<',
StrictGreaterThan: '>',
EqualityPredicate(): '=',
LessThanPredicate(): '<=',
GreaterThanPredicate(): '>=',
StrictLessThanPredicate(): '<',
StrictGreaterThanPredicate(): '>',
exp: 'exp',
log: 'log',
Abs: 'abs',
sin: 'sin',
cos: 'cos',
tan: 'tan',
asin: 'arcsin',
acos: 'arccos',
atan: 'arctan',
atan2: 'arctan2',
sinh: 'sinh',
cosh: 'cosh',
tanh: 'tanh',
Min: 'min',
Max: 'max',
Pow: 'pow',
And: 'and',
Or: 'or',
Xor: 'xor',
Not: 'not',
ITE: 'ite',
Implies: '=>',
}
}
symbol_table: dict
def __init__(self, settings: typing.Optional[dict] = None,
symbol_table=None):
settings = settings or {}
self.symbol_table = symbol_table or {}
Printer.__init__(self, settings)
self._precision = self._settings['precision']
self._known_types = dict(self._settings['known_types'])
self._known_constants = dict(self._settings['known_constants'])
self._known_functions = dict(self._settings['known_functions'])
for _ in self._known_types.values(): assert self._is_legal_name(_)
for _ in self._known_constants.values(): assert self._is_legal_name(_)
# for _ in self._known_functions.values(): assert self._is_legal_name(_) # +, *, <, >, etc.
def _is_legal_name(self, s: str):
if not s: return False
if s[0].isnumeric(): return False
return all(_.isalnum() or _ == '_' for _ in s)
def _s_expr(self, op: str, args: typing.Union[list, tuple]) -> str:
args_str = ' '.join(
a if isinstance(a, str)
else self._print(a)
for a in args
)
return f'({op} {args_str})'
def _print_Function(self, e):
if e in self._known_functions:
op = self._known_functions[e]
elif type(e) in self._known_functions:
op = self._known_functions[type(e)]
elif type(type(e)) == UndefinedFunction:
op = e.name
elif isinstance(e, AppliedBinaryRelation) and e.function in self._known_functions:
op = self._known_functions[e.function]
return self._s_expr(op, e.arguments)
else:
op = self._known_functions[e] # throw KeyError
return self._s_expr(op, e.args)
def _print_Relational(self, e: Relational):
return self._print_Function(e)
def _print_BooleanFunction(self, e: BooleanFunction):
return self._print_Function(e)
def _print_Expr(self, e: Expr):
return self._print_Function(e)
def _print_Unequality(self, e: Unequality):
if type(e) in self._known_functions:
return self._print_Relational(e) # default
else:
eq_op = self._known_functions[Equality]
not_op = self._known_functions[Not]
return self._s_expr(not_op, [self._s_expr(eq_op, e.args)])
def _print_Piecewise(self, e: Piecewise):
def _print_Piecewise_recursive(args: typing.Union[list, tuple]):
e, c = args[0]
if len(args) == 1:
assert (c is True) or isinstance(c, BooleanTrue)
return self._print(e)
else:
ite = self._known_functions[ITE]
return self._s_expr(ite, [
c, e, _print_Piecewise_recursive(args[1:])
])
return _print_Piecewise_recursive(e.args)
def _print_Interval(self, e: Interval):
if e.start.is_infinite and e.end.is_infinite:
return ''
elif e.start.is_infinite != e.end.is_infinite:
raise ValueError(f'One-sided intervals (`{e}`) are not supported in SMT.')
else:
return f'[{e.start}, {e.end}]'
def _print_AppliedPredicate(self, e: AppliedPredicate):
if e.function == Q.positive:
rel = Q.gt(e.arguments[0],0)
elif e.function == Q.negative:
rel = Q.lt(e.arguments[0], 0)
elif e.function == Q.zero:
rel = Q.eq(e.arguments[0], 0)
elif e.function == Q.nonpositive:
rel = Q.le(e.arguments[0], 0)
elif e.function == Q.nonnegative:
rel = Q.ge(e.arguments[0], 0)
elif e.function == Q.nonzero:
rel = Q.ne(e.arguments[0], 0)
else:
raise ValueError(f"Predicate (`{e}`) is not handled.")
return self._print_AppliedBinaryRelation(rel)
def _print_AppliedBinaryRelation(self, e: AppliedPredicate):
if e.function == Q.ne:
return self._print_Unequality(Unequality(*e.arguments))
else:
return self._print_Function(e)
# todo: Sympy does not support quantifiers yet as of 2022, but quantifiers can be handy in SMT.
# For now, users can extend this class and build in their own quantifier support.
# See `test_quantifier_extensions()` in test_smtlib.py for an example of how this might look.
# def _print_ForAll(self, e: ForAll):
# return self._s('forall', [
# self._s('', [
# self._s(sym.name, [self._type_name(sym), Interval(start, end)])
# for sym, start, end in e.limits
# ]),
# e.function
# ])
def _print_BooleanTrue(self, x: BooleanTrue):
return 'true'
def _print_BooleanFalse(self, x: BooleanFalse):
return 'false'
def _print_Float(self, x: Float):
dps = prec_to_dps(x._prec)
str_real = mlib_to_str(x._mpf_, dps, strip_zeros=True, min_fixed=None, max_fixed=None)
if 'e' in str_real:
(mant, exp) = str_real.split('e')
if exp[0] == '+':
exp = exp[1:]
mul = self._known_functions[Mul]
pow = self._known_functions[Pow]
return r"(%s %s (%s 10 %s))" % (mul, mant, pow, exp)
elif str_real in ["+inf", "-inf"]:
raise ValueError("Infinite values are not supported in SMT.")
else:
return str_real
def _print_float(self, x: float):
return self._print(Float(x))
def _print_Rational(self, x: Rational):
return self._s_expr('/', [x.p, x.q])
def _print_Integer(self, x: Integer):
assert x.q == 1
return str(x.p)
def _print_int(self, x: int):
return str(x)
def _print_Symbol(self, x: Symbol):
assert self._is_legal_name(x.name)
return x.name
def _print_NumberSymbol(self, x):
name = self._known_constants.get(x)
if name:
return name
else:
f = x.evalf(self._precision) if self._precision else x.evalf()
return self._print_Float(f)
def _print_UndefinedFunction(self, x):
assert self._is_legal_name(x.name)
return x.name
def _print_Exp1(self, x):
return (
self._print_Function(exp(1, evaluate=False))
if exp in self._known_functions else
self._print_NumberSymbol(x)
)
def emptyPrinter(self, expr):
raise NotImplementedError(f'Cannot convert `{repr(expr)}` of type `{type(expr)}` to SMT.')
def smtlib_code(
expr,
auto_assert=True, auto_declare=True,
precision=None,
symbol_table=None,
known_types=None, known_constants=None, known_functions=None,
prefix_expressions=None, suffix_expressions=None,
log_warn=None
):
r"""Converts ``expr`` to a string of smtlib code.
Parameters
==========
expr : Expr | List[Expr]
A SymPy expression or system to be converted.
auto_assert : bool, optional
If false, do not modify expr and produce only the S-Expression equivalent of expr.
If true, assume expr is a system and assert each boolean element.
auto_declare : bool, optional
If false, do not produce declarations for the symbols used in expr.
If true, prepend all necessary declarations for variables used in expr based on symbol_table.
precision : integer, optional
The ``evalf(..)`` precision for numbers such as pi.
symbol_table : dict, optional
A dictionary where keys are ``Symbol`` or ``Function`` instances and values are their Python type i.e. ``bool``, ``int``, ``float``, or ``Callable[...]``.
If incomplete, an attempt will be made to infer types from ``expr``.
known_types: dict, optional
A dictionary where keys are ``bool``, ``int``, ``float`` etc. and values are their corresponding SMT type names.
If not given, a partial listing compatible with several solvers will be used.
known_functions : dict, optional
A dictionary where keys are ``Function``, ``Relational``, ``BooleanFunction``, or ``Expr`` instances and values are their SMT string representations.
If not given, a partial listing optimized for dReal solver (but compatible with others) will be used.
known_constants: dict, optional
A dictionary where keys are ``NumberSymbol`` instances and values are their SMT variable names.
When using this feature, extra caution must be taken to avoid naming collisions between user symbols and listed constants.
If not given, constants will be expanded inline i.e. ``3.14159`` instead of ``MY_SMT_VARIABLE_FOR_PI``.
prefix_expressions: list, optional
A list of lists of ``str`` and/or expressions to convert into SMTLib and prefix to the output.
suffix_expressions: list, optional
A list of lists of ``str`` and/or expressions to convert into SMTLib and postfix to the output.
log_warn: lambda function, optional
A function to record all warnings during potentially risky operations.
Soundness is a core value in SMT solving, so it is good to log all assumptions made.
Examples
========
>>> from sympy import smtlib_code, symbols, sin, Eq
>>> x = symbols('x')
>>> smtlib_code(sin(x).series(x).removeO(), log_warn=print)
Could not infer type of `x`. Defaulting to float.
Non-Boolean expression `x**5/120 - x**3/6 + x` will not be asserted. Converting to SMTLib verbatim.
'(declare-const x Real)\n(+ x (* (/ -1 6) (pow x 3)) (* (/ 1 120) (pow x 5)))'
>>> from sympy import Rational
>>> x, y, tau = symbols("x, y, tau")
>>> smtlib_code((2*tau)**Rational(7, 2), log_warn=print)
Could not infer type of `tau`. Defaulting to float.
Non-Boolean expression `8*sqrt(2)*tau**(7/2)` will not be asserted. Converting to SMTLib verbatim.
'(declare-const tau Real)\n(* 8 (pow 2 (/ 1 2)) (pow tau (/ 7 2)))'
``Piecewise`` expressions are implemented with ``ite`` expressions by default.
Note that if the ``Piecewise`` lacks a default term, represented by
``(expr, True)`` then an error will be thrown. This is to prevent
generating an expression that may not evaluate to anything.
>>> from sympy import Piecewise
>>> pw = Piecewise((x + 1, x > 0), (x, True))
>>> smtlib_code(Eq(pw, 3), symbol_table={x: float}, log_warn=print)
'(declare-const x Real)\n(assert (= (ite (> x 0) (+ 1 x) x) 3))'
Custom printing can be defined for certain types by passing a dictionary of
PythonType : "SMT Name" to the ``known_types``, ``known_constants``, and ``known_functions`` kwargs.
>>> from typing import Callable
>>> from sympy import Function, Add
>>> f = Function('f')
>>> g = Function('g')
>>> smt_builtin_funcs = { # functions our SMT solver will understand
... f: "existing_smtlib_fcn",
... Add: "sum",
... }
>>> user_def_funcs = { # functions defined by the user must have their types specified explicitly
... g: Callable[[int], float],
... }
>>> smtlib_code(f(x) + g(x), symbol_table=user_def_funcs, known_functions=smt_builtin_funcs, log_warn=print)
Non-Boolean expression `f(x) + g(x)` will not be asserted. Converting to SMTLib verbatim.
'(declare-const x Int)\n(declare-fun g (Int) Real)\n(sum (existing_smtlib_fcn x) (g x))'
"""
log_warn = log_warn or (lambda _: None)
if not isinstance(expr, list): expr = [expr]
expr = [
sympy.sympify(_, strict=True, evaluate=False, convert_xor=False)
for _ in expr
]
if not symbol_table: symbol_table = {}
symbol_table = _auto_infer_smtlib_types(
*expr, symbol_table=symbol_table
)
# See [FALLBACK RULES]
# Need SMTLibPrinter to populate known_functions and known_constants first.
settings = {}
if precision: settings['precision'] = precision
del precision
if known_types: settings['known_types'] = known_types
del known_types
if known_functions: settings['known_functions'] = known_functions
del known_functions
if known_constants: settings['known_constants'] = known_constants
del known_constants
if not prefix_expressions: prefix_expressions = []
if not suffix_expressions: suffix_expressions = []
p = SMTLibPrinter(settings, symbol_table)
del symbol_table
# [FALLBACK RULES]
for e in expr:
for sym in e.atoms(Symbol, Function):
if (
sym.is_Symbol and
sym not in p._known_constants and
sym not in p.symbol_table
):
log_warn(f"Could not infer type of `{sym}`. Defaulting to float.")
p.symbol_table[sym] = float
if (
sym.is_Function and
type(sym) not in p._known_functions and
type(sym) not in p.symbol_table and
not sym.is_Piecewise
): raise TypeError(
f"Unknown type of undefined function `{sym}`. "
f"Must be mapped to ``str`` in known_functions or mapped to ``Callable[..]`` in symbol_table."
)
declarations = []
if auto_declare:
constants = {sym.name: sym for e in expr for sym in e.free_symbols
if sym not in p._known_constants}
functions = {fnc.name: fnc for e in expr for fnc in e.atoms(Function)
if type(fnc) not in p._known_functions and not fnc.is_Piecewise}
declarations = \
[
_auto_declare_smtlib(sym, p, log_warn)
for sym in constants.values()
] + [
_auto_declare_smtlib(fnc, p, log_warn)
for fnc in functions.values()
]
declarations = [decl for decl in declarations if decl]
if auto_assert:
expr = [_auto_assert_smtlib(e, p, log_warn) for e in expr]
# return SMTLibPrinter().doprint(expr)
return '\n'.join([
# ';; PREFIX EXPRESSIONS',
*[
e if isinstance(e, str) else p.doprint(e)
for e in prefix_expressions
],
# ';; DECLARATIONS',
*sorted(e for e in declarations),
# ';; EXPRESSIONS',
*[
e if isinstance(e, str) else p.doprint(e)
for e in expr
],
# ';; SUFFIX EXPRESSIONS',
*[
e if isinstance(e, str) else p.doprint(e)
for e in suffix_expressions
],
])
def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
if sym.is_Symbol:
type_signature = p.symbol_table[sym]
assert isinstance(type_signature, type)
type_signature = p._known_types[type_signature]
return p._s_expr('declare-const', [sym, type_signature])
elif sym.is_Function:
type_signature = p.symbol_table[type(sym)]
assert callable(type_signature)
type_signature = [p._known_types[_] for _ in type_signature.__args__]
assert len(type_signature) > 0
params_signature = f"({' '.join(type_signature[:-1])})"
return_signature = type_signature[-1]
return p._s_expr('declare-fun', [type(sym), params_signature, return_signature])
else:
log_warn(f"Non-Symbol/Function `{sym}` will not be declared.")
return None
def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
if isinstance(e, Boolean) or (
e in p.symbol_table and p.symbol_table[e] == bool
) or (
e.is_Function and
type(e) in p.symbol_table and
p.symbol_table[type(e)].__args__[-1] == bool
):
return p._s_expr('assert', [e])
else:
log_warn(f"Non-Boolean expression `{e}` will not be asserted. Converting to SMTLib verbatim.")
return e
def _auto_infer_smtlib_types(
*exprs: Basic,
symbol_table: typing.Optional[dict] = None
) -> dict:
# [TYPE INFERENCE RULES]
# X is alone in an expr => X is bool
# X in BooleanFunction.args => X is bool
# X matches to a bool param of a symbol_table function => X is bool
# X matches to an int param of a symbol_table function => X is int
# X.is_integer => X is int
# X == Y, where X is T => Y is T
# [FALLBACK RULES]
# see _auto_declare_smtlib(..)
# X is not bool and X is not int and X is Symbol => X is float
# else (e.g. X is Function) => error. must be specified explicitly.
_symbols = dict(symbol_table) if symbol_table else {}
def safe_update(syms: set, inf):
for s in syms:
assert s.is_Symbol
if (old_type := _symbols.setdefault(s, inf)) != inf:
raise TypeError(f"Could not infer type of `{s}`. Apparently both `{old_type}` and `{inf}`?")
# EXPLICIT TYPES
safe_update({
e
for e in exprs
if e.is_Symbol
}, bool)
safe_update({
symbol
for e in exprs
for boolfunc in e.atoms(BooleanFunction)
for symbol in boolfunc.args
if symbol.is_Symbol
}, bool)
safe_update({
symbol
for e in exprs
for boolfunc in e.atoms(Function)
if type(boolfunc) in _symbols
for symbol, param in zip(boolfunc.args, _symbols[type(boolfunc)].__args__)
if symbol.is_Symbol and param == bool
}, bool)
safe_update({
symbol
for e in exprs
for intfunc in e.atoms(Function)
if type(intfunc) in _symbols
for symbol, param in zip(intfunc.args, _symbols[type(intfunc)].__args__)
if symbol.is_Symbol and param == int
}, int)
safe_update({
symbol
for e in exprs
for symbol in e.atoms(Symbol)
if symbol.is_integer
}, int)
safe_update({
symbol
for e in exprs
for symbol in e.atoms(Symbol)
if symbol.is_real and not symbol.is_integer
}, float)
# EQUALITY RELATION RULE
rels = [rel for expr in exprs for rel in expr.atoms(Equality)]
rels = [
(rel.lhs, rel.rhs) for rel in rels if rel.lhs.is_Symbol
] + [
(rel.rhs, rel.lhs) for rel in rels if rel.rhs.is_Symbol
]
for infer, reltd in rels:
inference = (
_symbols[infer] if infer in _symbols else
_symbols[reltd] if reltd in _symbols else
_symbols[type(reltd)].__args__[-1]
if reltd.is_Function and type(reltd) in _symbols else
bool if reltd.is_Boolean else
int if reltd.is_integer or reltd.is_Integer else
float if reltd.is_real else
None
)
if inference: safe_update({infer}, inference)
return _symbols

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,366 @@
from sympy.core.containers import Tuple
from sympy.core.singleton import S
from sympy.core.symbol import Symbol
from sympy.core.sympify import SympifyError
from types import FunctionType
class TableForm:
r"""
Create a nice table representation of data.
Examples
========
>>> from sympy import TableForm
>>> t = TableForm([[5, 7], [4, 2], [10, 3]])
>>> print(t)
5 7
4 2
10 3
You can use the SymPy's printing system to produce tables in any
format (ascii, latex, html, ...).
>>> print(t.as_latex())
\begin{tabular}{l l}
$5$ & $7$ \\
$4$ & $2$ \\
$10$ & $3$ \\
\end{tabular}
"""
def __init__(self, data, **kwarg):
"""
Creates a TableForm.
Parameters:
data ...
2D data to be put into the table; data can be
given as a Matrix
headings ...
gives the labels for rows and columns:
Can be a single argument that applies to both
dimensions:
- None ... no labels
- "automatic" ... labels are 1, 2, 3, ...
Can be a list of labels for rows and columns:
The labels for each dimension can be given
as None, "automatic", or [l1, l2, ...] e.g.
["automatic", None] will number the rows
[default: None]
alignments ...
alignment of the columns with:
- "left" or "<"
- "center" or "^"
- "right" or ">"
When given as a single value, the value is used for
all columns. The row headings (if given) will be
right justified unless an explicit alignment is
given for it and all other columns.
[default: "left"]
formats ...
a list of format strings or functions that accept
3 arguments (entry, row number, col number) and
return a string for the table entry. (If a function
returns None then the _print method will be used.)
wipe_zeros ...
Do not show zeros in the table.
[default: True]
pad ...
the string to use to indicate a missing value (e.g.
elements that are None or those that are missing
from the end of a row (i.e. any row that is shorter
than the rest is assumed to have missing values).
When None, nothing will be shown for values that
are missing from the end of a row; values that are
None, however, will be shown.
[default: None]
Examples
========
>>> from sympy import TableForm, Symbol
>>> TableForm([[5, 7], [4, 2], [10, 3]])
5 7
4 2
10 3
>>> TableForm([list('.'*i) for i in range(1, 4)], headings='automatic')
| 1 2 3
---------
1 | .
2 | . .
3 | . . .
>>> TableForm([[Symbol('.'*(j if not i%2 else 1)) for i in range(3)]
... for j in range(4)], alignments='rcl')
.
. . .
.. . ..
... . ...
"""
from sympy.matrices.dense import Matrix
# We only support 2D data. Check the consistency:
if isinstance(data, Matrix):
data = data.tolist()
_h = len(data)
# fill out any short lines
pad = kwarg.get('pad', None)
ok_None = False
if pad is None:
pad = " "
ok_None = True
pad = Symbol(pad)
_w = max(len(line) for line in data)
for i, line in enumerate(data):
if len(line) != _w:
line.extend([pad]*(_w - len(line)))
for j, lj in enumerate(line):
if lj is None:
if not ok_None:
lj = pad
else:
try:
lj = S(lj)
except SympifyError:
lj = Symbol(str(lj))
line[j] = lj
data[i] = line
_lines = Tuple(*[Tuple(*d) for d in data])
headings = kwarg.get("headings", [None, None])
if headings == "automatic":
_headings = [range(1, _h + 1), range(1, _w + 1)]
else:
h1, h2 = headings
if h1 == "automatic":
h1 = range(1, _h + 1)
if h2 == "automatic":
h2 = range(1, _w + 1)
_headings = [h1, h2]
allow = ('l', 'r', 'c')
alignments = kwarg.get("alignments", "l")
def _std_align(a):
a = a.strip().lower()
if len(a) > 1:
return {'left': 'l', 'right': 'r', 'center': 'c'}.get(a, a)
else:
return {'<': 'l', '>': 'r', '^': 'c'}.get(a, a)
std_align = _std_align(alignments)
if std_align in allow:
_alignments = [std_align]*_w
else:
_alignments = []
for a in alignments:
std_align = _std_align(a)
_alignments.append(std_align)
if std_align not in ('l', 'r', 'c'):
raise ValueError('alignment "%s" unrecognized' %
alignments)
if _headings[0] and len(_alignments) == _w + 1:
_head_align = _alignments[0]
_alignments = _alignments[1:]
else:
_head_align = 'r'
if len(_alignments) != _w:
raise ValueError(
'wrong number of alignments: expected %s but got %s' %
(_w, len(_alignments)))
_column_formats = kwarg.get("formats", [None]*_w)
_wipe_zeros = kwarg.get("wipe_zeros", True)
self._w = _w
self._h = _h
self._lines = _lines
self._headings = _headings
self._head_align = _head_align
self._alignments = _alignments
self._column_formats = _column_formats
self._wipe_zeros = _wipe_zeros
def __repr__(self):
from .str import sstr
return sstr(self, order=None)
def __str__(self):
from .str import sstr
return sstr(self, order=None)
def as_matrix(self):
"""Returns the data of the table in Matrix form.
Examples
========
>>> from sympy import TableForm
>>> t = TableForm([[5, 7], [4, 2], [10, 3]], headings='automatic')
>>> t
| 1 2
--------
1 | 5 7
2 | 4 2
3 | 10 3
>>> t.as_matrix()
Matrix([
[ 5, 7],
[ 4, 2],
[10, 3]])
"""
from sympy.matrices.dense import Matrix
return Matrix(self._lines)
def as_str(self):
# XXX obsolete ?
return str(self)
def as_latex(self):
from .latex import latex
return latex(self)
def _sympystr(self, p):
"""
Returns the string representation of 'self'.
Examples
========
>>> from sympy import TableForm
>>> t = TableForm([[5, 7], [4, 2], [10, 3]])
>>> s = t.as_str()
"""
column_widths = [0] * self._w
lines = []
for line in self._lines:
new_line = []
for i in range(self._w):
# Format the item somehow if needed:
s = str(line[i])
if self._wipe_zeros and (s == "0"):
s = " "
w = len(s)
if w > column_widths[i]:
column_widths[i] = w
new_line.append(s)
lines.append(new_line)
# Check heading:
if self._headings[0]:
self._headings[0] = [str(x) for x in self._headings[0]]
_head_width = max(len(x) for x in self._headings[0])
if self._headings[1]:
new_line = []
for i in range(self._w):
# Format the item somehow if needed:
s = str(self._headings[1][i])
w = len(s)
if w > column_widths[i]:
column_widths[i] = w
new_line.append(s)
self._headings[1] = new_line
format_str = []
def _align(align, w):
return '%%%s%ss' % (
("-" if align == "l" else ""),
str(w))
format_str = [_align(align, w) for align, w in
zip(self._alignments, column_widths)]
if self._headings[0]:
format_str.insert(0, _align(self._head_align, _head_width))
format_str.insert(1, '|')
format_str = ' '.join(format_str) + '\n'
s = []
if self._headings[1]:
d = self._headings[1]
if self._headings[0]:
d = [""] + d
first_line = format_str % tuple(d)
s.append(first_line)
s.append("-" * (len(first_line) - 1) + "\n")
for i, line in enumerate(lines):
d = [l if self._alignments[j] != 'c' else
l.center(column_widths[j]) for j, l in enumerate(line)]
if self._headings[0]:
l = self._headings[0][i]
l = (l if self._head_align != 'c' else
l.center(_head_width))
d = [l] + d
s.append(format_str % tuple(d))
return ''.join(s)[:-1] # don't include trailing newline
def _latex(self, printer):
"""
Returns the string representation of 'self'.
"""
# Check heading:
if self._headings[1]:
new_line = []
for i in range(self._w):
# Format the item somehow if needed:
new_line.append(str(self._headings[1][i]))
self._headings[1] = new_line
alignments = []
if self._headings[0]:
self._headings[0] = [str(x) for x in self._headings[0]]
alignments = [self._head_align]
alignments.extend(self._alignments)
s = r"\begin{tabular}{" + " ".join(alignments) + "}\n"
if self._headings[1]:
d = self._headings[1]
if self._headings[0]:
d = [""] + d
first_line = " & ".join(d) + r" \\" + "\n"
s += first_line
s += r"\hline" + "\n"
for i, line in enumerate(self._lines):
d = []
for j, x in enumerate(line):
if self._wipe_zeros and (x in (0, "0")):
d.append(" ")
continue
f = self._column_formats[j]
if f:
if isinstance(f, FunctionType):
v = f(x, i, j)
if v is None:
v = printer._print(x)
else:
v = f % x
d.append(v)
else:
v = printer._print(x)
d.append("$%s$" % v)
if self._headings[0]:
d = [self._headings[0][i]] + d
s += " & ".join(d) + r" \\" + "\n"
s += r"\end{tabular}"
return s

View File

@ -0,0 +1,216 @@
from sympy.external.importtools import version_tuple
from collections.abc import Iterable
from sympy.core.mul import Mul
from sympy.core.singleton import S
from sympy.codegen.cfunctions import Sqrt
from sympy.external import import_module
from sympy.printing.precedence import PRECEDENCE
from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter
import sympy
tensorflow = import_module('tensorflow')
class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter):
"""
Tensorflow printer which handles vectorized piecewise functions,
logical operators, max/min, and relational operators.
"""
printmethod = "_tensorflowcode"
mapping = {
sympy.Abs: "tensorflow.math.abs",
sympy.sign: "tensorflow.math.sign",
# XXX May raise error for ints.
sympy.ceiling: "tensorflow.math.ceil",
sympy.floor: "tensorflow.math.floor",
sympy.log: "tensorflow.math.log",
sympy.exp: "tensorflow.math.exp",
Sqrt: "tensorflow.math.sqrt",
sympy.cos: "tensorflow.math.cos",
sympy.acos: "tensorflow.math.acos",
sympy.sin: "tensorflow.math.sin",
sympy.asin: "tensorflow.math.asin",
sympy.tan: "tensorflow.math.tan",
sympy.atan: "tensorflow.math.atan",
sympy.atan2: "tensorflow.math.atan2",
# XXX Also may give NaN for complex results.
sympy.cosh: "tensorflow.math.cosh",
sympy.acosh: "tensorflow.math.acosh",
sympy.sinh: "tensorflow.math.sinh",
sympy.asinh: "tensorflow.math.asinh",
sympy.tanh: "tensorflow.math.tanh",
sympy.atanh: "tensorflow.math.atanh",
sympy.re: "tensorflow.math.real",
sympy.im: "tensorflow.math.imag",
sympy.arg: "tensorflow.math.angle",
# XXX May raise error for ints and complexes
sympy.erf: "tensorflow.math.erf",
sympy.loggamma: "tensorflow.math.lgamma",
sympy.Eq: "tensorflow.math.equal",
sympy.Ne: "tensorflow.math.not_equal",
sympy.StrictGreaterThan: "tensorflow.math.greater",
sympy.StrictLessThan: "tensorflow.math.less",
sympy.LessThan: "tensorflow.math.less_equal",
sympy.GreaterThan: "tensorflow.math.greater_equal",
sympy.And: "tensorflow.math.logical_and",
sympy.Or: "tensorflow.math.logical_or",
sympy.Not: "tensorflow.math.logical_not",
sympy.Max: "tensorflow.math.maximum",
sympy.Min: "tensorflow.math.minimum",
# Matrices
sympy.MatAdd: "tensorflow.math.add",
sympy.HadamardProduct: "tensorflow.math.multiply",
sympy.Trace: "tensorflow.linalg.trace",
# XXX May raise error for integer matrices.
sympy.Determinant : "tensorflow.linalg.det",
}
_default_settings = dict(
AbstractPythonCodePrinter._default_settings,
tensorflow_version=None
)
def __init__(self, settings=None):
super().__init__(settings)
version = self._settings['tensorflow_version']
if version is None and tensorflow:
version = tensorflow.__version__
self.tensorflow_version = version
def _print_Function(self, expr):
op = self.mapping.get(type(expr), None)
if op is None:
return super()._print_Basic(expr)
children = [self._print(arg) for arg in expr.args]
if len(children) == 1:
return "%s(%s)" % (
self._module_format(op),
children[0]
)
else:
return self._expand_fold_binary_op(op, children)
_print_Expr = _print_Function
_print_Application = _print_Function
_print_MatrixExpr = _print_Function
# TODO: a better class structure would avoid this mess:
_print_Relational = _print_Function
_print_Not = _print_Function
_print_And = _print_Function
_print_Or = _print_Function
_print_HadamardProduct = _print_Function
_print_Trace = _print_Function
_print_Determinant = _print_Function
def _print_Inverse(self, expr):
op = self._module_format('tensorflow.linalg.inv')
return "{}({})".format(op, self._print(expr.arg))
def _print_Transpose(self, expr):
version = self.tensorflow_version
if version and version_tuple(version) < version_tuple('1.14'):
op = self._module_format('tensorflow.matrix_transpose')
else:
op = self._module_format('tensorflow.linalg.matrix_transpose')
return "{}({})".format(op, self._print(expr.arg))
def _print_Derivative(self, expr):
variables = expr.variables
if any(isinstance(i, Iterable) for i in variables):
raise NotImplementedError("derivation by multiple variables is not supported")
def unfold(expr, args):
if not args:
return self._print(expr)
return "%s(%s, %s)[0]" % (
self._module_format("tensorflow.gradients"),
unfold(expr, args[:-1]),
self._print(args[-1]),
)
return unfold(expr.expr, variables)
def _print_Piecewise(self, expr):
version = self.tensorflow_version
if version and version_tuple(version) < version_tuple('1.0'):
tensorflow_piecewise = "tensorflow.select"
else:
tensorflow_piecewise = "tensorflow.where"
from sympy.functions.elementary.piecewise import Piecewise
e, cond = expr.args[0].args
if len(expr.args) == 1:
return '{}({}, {}, {})'.format(
self._module_format(tensorflow_piecewise),
self._print(cond),
self._print(e),
0)
return '{}({}, {}, {})'.format(
self._module_format(tensorflow_piecewise),
self._print(cond),
self._print(e),
self._print(Piecewise(*expr.args[1:])))
def _print_Pow(self, expr):
# XXX May raise error for
# int**float or int**complex or float**complex
base, exp = expr.args
if expr.exp == S.Half:
return "{}({})".format(
self._module_format("tensorflow.math.sqrt"), self._print(base))
return "{}({}, {})".format(
self._module_format("tensorflow.math.pow"),
self._print(base), self._print(exp))
def _print_MatrixBase(self, expr):
tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]"
return "%s(%s)" % (
self._module_format(tensorflow_f),
data,
)
def _print_MatMul(self, expr):
from sympy.matrices.expressions import MatrixExpr
mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
args = [arg for arg in expr.args if arg not in mat_args]
if args:
return "%s*%s" % (
self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
self._expand_fold_binary_op(
"tensorflow.linalg.matmul", mat_args)
)
else:
return self._expand_fold_binary_op(
"tensorflow.linalg.matmul", mat_args)
def _print_MatPow(self, expr):
return self._expand_fold_binary_op(
"tensorflow.linalg.matmul", [expr.base]*expr.exp)
def _print_CodeBlock(self, expr):
# TODO: is this necessary?
ret = []
for subexpr in expr.args:
ret.append(self._print(subexpr))
return "\n".join(ret)
_module = "tensorflow"
_einsum = "linalg.einsum"
_add = "math.add"
_transpose = "transpose"
_ones = "ones"
_zeros = "zeros"
def tensorflow_code(expr, **settings):
printer = TensorflowPrinter(settings)
return printer.doprint(expr)

Some files were not shown because too many files have changed in this diff Show More