Skip to content

Commit

Permalink
Implemented conditional statement.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerrit Ansmann committed Jul 23, 2019
1 parent 1e28fb3 commit 82ccf11
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

MOCK_MODULES = [
'numpy', 'numpy.testing', 'numpy.random',
'symengine', 'symengine.printing',
'symengine', 'symengine.printing', 'symengine.lib.symengine_wrapper',
'.'
]
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
Expand Down
23 changes: 18 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,30 @@ Here is an example for imports that make use of this:
Note that while SymEngine’s Python wrapper is sparsely documented, almost everything that is relevant to JiTC*DE behaves analogously to SymPy and the latter’s documentation serves as a documentation for SymEngine as well.
For this reason, JiTC*DE’s documentation also often links to SymPy’s documentation when talking about SymEngine features.

Conditionals
------------

Many dynamics contain a step function, Heaviside function, conditional, or whatever you like to call it.
In the vast majority of cases you cannot naïvely implement this, because discontinuities can lead to all sorts of problems with the integrators.
Most importantly, error estimation and step-size adaption requires a continuous derivative.
Moreover, any Python conditionals will be evaluated during the code generation and not at runtime, which not what you want in this case.

There are two general ways to solve this:

* If your step-wise behaviour depends on time (e.g., an external pulse that is limited in time), integrate up to the point of the step, change `f` or a control parameter, and continue.
Note that for DDEs this may introduce a discontinuity that needs to be dealt with like an initial discontinuity.

* Use a sharp sigmoid instead of the step function.
`jitcxde_common` provides a service function `conditional` which can be used for this purpose and is documented below.

.. autofunction:: symbolic.conditional

Common Mistakes and Questions
-----------------------------

* If you want to use mathematical functions like `sin`, `exp` or `sqrt` you have to use the SymEngine variants.
For example, instead of `math.sin` or `numpy.sin`, you have to use `symengine.sin`.

* If you wish to use step functions to drive the system or similar, the best alternative is usually to use a sharp sigmoid instead.
SymEngine has not implemented SymPy’s `Piecewise` yet, but more importantly discontinuities can cause all sorts of problems with the integrators.
If your step-wise behaviour depends on time (e.g., an external pulse that is limited in time), you can also integrate up to the point of the step, change `f` or a control parameter, and continue.
Note that for DDEs this may introduce a discontinuity that needs to be dealt with like an initial discontinuity.

* If you get unexpected or cryptic errors, please run the respective class’s `check` function and also check that all input has the right format and functions have the right signature.

* If JiTC*DE’s code generation and compilation is too slow or bursts your memory, check:
Expand Down
1 change: 1 addition & 0 deletions jitcxde_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
raise NotImplementedError("Python versions below 3.3 are not supported anymore (or never were). Please upgrade to a newer Python version.")

from ._jitcxde import jitcxde, DEFAULT_COMPILE_ARGS, DEFAULT_LINK_ARGS, MSVC_COMPILE_ARGS, MSVC_LINK_ARGS
from .symbolic import conditional
from .check import checker

try:
Expand Down
25 changes: 25 additions & 0 deletions jitcxde_common/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from symengine.lib.symengine_wrapper import FunctionSymbol
from symengine import sympify, tanh

def is_call(expression,function):
"""
Expand Down Expand Up @@ -80,3 +81,27 @@ def replace_function(expression,function,new_function):
else:
return expression.func(*replaced_args)

def conditional(observable,threshold,value_if,value_else,width=None):
"""
Provides an smoothed and thus integrator-friendly version of a conditional statement. For most purposes, you can imagine this being equivalent to:
.. code-block:: Python
def conditional(observable,threshold,value_if,value_else):
if observable>threshold:
return value_if
else:
return value_else
The import difference is that this is smooth and evaluated at runtime.
`width` controls the steepness of the sigmoidal used to implement this. If not specified, this will be guessed – from the threshold if possible.
"""
if width is None:
if sympify(threshold).is_number and threshold!=0:
width = threshold/100000
else:
width = 1e-5

return value_if+(1+tanh((observable-threshold)/width))/2*(value_else-value_if)

41 changes: 40 additions & 1 deletion tests/test_symbolic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from symengine import Function, Symbol, sin, Integer
from jitcxde_common.symbolic import collect_arguments, count_calls, has_function, replace_function
from jitcxde_common.symbolic import collect_arguments, count_calls, has_function, replace_function, conditional

f = Function("f")
g = Function("g")
Expand Down Expand Up @@ -47,6 +47,45 @@ def test_no_function(self):
self.assertFalse(has_function(expression,f))
self.assertEqual( replace_function(expression,f,g), g(a)+42 )

ε = 1e-2
conditional_test_cases = [
( 41 , 42, 7, 23, 7 ),
( 42-ε, 42, 7, 23, 7 ),
( 42 , 42, 7, 23, 15 ),
( 42+ε, 42, 7, 23, 23 ),
( 43 , 42, 7, 23, 23 ),
( 41 , 42, 23, 7, 23 ),
( 42-ε, 42, 23, 7, 23 ),
( 42 , 42, 23, 7, 15 ),
( 42+ε, 42, 23, 7, 7 ),
( 43 , 42, 23, 7, 7 ),
( -1 , 0, 7, 23, 7 ),
( -ε, 0, 7, 23, 7 ),
( 0 , 0, 7, 23, 15 ),
( +ε, 0, 7, 23, 23 ),
( 1 , 0, 7, 23, 23 ),
]

class TestConditional(unittest.TestCase):
def test_number_input(self):
for obs,thr,v_if,v_else,result in conditional_test_cases:
self.assertAlmostEqual(
float(conditional(obs,thr,v_if,v_else)),
result,
)

def test_symbolic_threshold(self):
for obs,thr,v_if,v_else,result in conditional_test_cases:
self.assertAlmostEqual(
float(conditional(obs,a,v_if,v_else).subs({a:thr})),
result,
)

def test_wide_width(self):
self.assertNotAlmostEqual(
float(conditional(1,0,-1,1,width=1e5)),
1,
)

if __name__ == "__main__":
unittest.main(buffer=True)

0 comments on commit 82ccf11

Please sign in to comment.