Merge pull request #168 from bluescarni/pr/cfunc_updates
cfunc API updates
bluescarni committed Feb 19, 2024
2 parents d11689e + 240c6be commit 95c26d3
Showing 19 changed files with 972 additions and 1,059 deletions.
8 changes: 8 additions & 0 deletions doc/api_exsys.rst
@@ -24,6 +24,14 @@ Functions
   make_vars
   diff_tensors

Attributes
----------

.. autosummary::
   :toctree: autosummary_generated

   par

Enums
-----

31 changes: 24 additions & 7 deletions doc/breaking_changes.rst
@@ -10,11 +10,26 @@ Breaking changes

heyoka.py 4 includes several backwards-incompatible changes.

API/behaviour changes
~~~~~~~~~~~~~~~~~~~~~
Changes to compiled functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The function to create :ref:`compiled functions <cfunc_tut>` has been renamed from
``make_cfunc()`` to simply ``cfunc()``.

Compiled functions have also gained the ability to use multiple
threads of execution during batched evaluations. As a consequence, compiled functions
now require contiguous NumPy arrays to be passed as input/output arguments (whereas
in previous versions compiled functions would work also with non-contiguous
arrays). The NumPy function :py:func:`numpy.ascontiguousarray()` can be used to turn
non-contiguous arrays into contiguous arrays.

Finally, compiled functions are now stricter with respect to type conversions: if a NumPy
array with the wrong datatype is passed as an input/output argument, an error will be raised
(whereas previously heyoka.py would convert the array to the correct datatype on-the-fly).
The NumPy method :py:meth:`numpy.ndarray.astype()` can be used for datatype conversions.
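
The two adjustments described above can be sketched with NumPy alone. This is a minimal illustration, not part of the commit: it assumes a compiled function operating in double precision (so that ``numpy.float64`` is the expected datatype), and the array names are hypothetical.

```python
import numpy as np

# A strided view of a single-precision array: it is non-contiguous and
# (assuming a double-precision compiled function) has the wrong datatype.
base = np.arange(12, dtype=np.float32).reshape(3, 4)
inputs = base[:, ::2]

# Convert the datatype explicitly, then force a C-contiguous layout.
prepared = np.ascontiguousarray(inputs.astype(np.float64))

# 'prepared' is now a contiguous float64 array holding the same values.
print(prepared.flags["C_CONTIGUOUS"], prepared.dtype)
```

An array prepared this way can then be passed to the call operator of a compiled function in place of the original view.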

A more explicit API
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~

Several functions and classes have been changed to explicitly require
the user to pass a list of variables in input. The previous behaviour, where
@@ -31,23 +46,25 @@ The affected APIs include:
The tutorials and the documentation have been updated accordingly.

Changes to :py:func:`~heyoka.make_vars()`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :py:func:`~heyoka.make_vars()` function now returns a single expression (rather than a list of expressions)
if a single argument is passed in input. This means that code such as

.. code-block:: python

   x, = make_vars("x")
   y = make_vars("y")[0]

needs to be rewritten like this:

.. code-block:: python

   x = make_vars("x")
   y = make_vars("y")

Terminal events callbacks
^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~

The second argument in the signature of callbacks for terminal events, a ``bool`` conventionally
called ``mr``, has been removed. This flag was meant to signal the possibility of multiple roots
@@ -58,7 +75,7 @@ Adapting existing code for this API change is straightforward: you just have to
from the signature of a terminal event callback.
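
As a rough before/after sketch of this change: the argument names ``ta`` and ``d_sgn`` and the placeholder bodies below are assumptions made for illustration; only the removal of ``mr`` is taken from the text above.

```python
# Hypothetical terminal-event callbacks. The names 'ta' and 'd_sgn' are
# assumptions; the bodies are placeholders.

# heyoka.py < 4: the second argument was the 'mr' flag.
def old_callback(ta, mr, d_sgn):
    return True


# heyoka.py 4: the same callback with 'mr' dropped from the signature.
def new_callback(ta, d_sgn):
    return True
```

Any logic in the callback body that inspected ``mr`` would also need to be removed.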

Step callbacks and ``propagate_*()``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The step callbacks that can (optionally) be passed to the ``propagate_*()`` methods of the
adaptive integrators are now part of the return value. Specifically:
@@ -82,7 +99,7 @@ a matter of:
rather than a single value.

Changes to ``propagate_grid()``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``propagate_grid()`` methods of the adaptive integrators now require the first element of the
time grid to be equal to the current integrator time. Previously, in case of a difference between the
12 changes: 12 additions & 0 deletions doc/changelog.rst
@@ -9,6 +9,9 @@ Changelog
New
~~~

- Compiled functions now support multithreaded parallelisation
for batched evaluations
(`#168 <https://github.com/bluescarni/heyoka.py/pull/168>`__).
- Add new example on gravity-gradient stabilisation
(`#159 <https://github.com/bluescarni/heyoka.py/pull/159>`__).
- Add support for Lagrangian and Hamiltonian mechanics
@@ -25,6 +28,15 @@ New
Changes
~~~~~~~

- **BREAKING**: the function to construct compiled functions
has been renamed from ``make_cfunc()`` to ``cfunc()``
(`#168 <https://github.com/bluescarni/heyoka.py/pull/168>`__).
This is a :ref:`breaking change <bchanges_4_0_0>`.
- **BREAKING**: compiled functions now require contiguous arrays
as input/output arguments. The compiled functions API is also now
more restrictive with respect to on-the-fly type conversions
(`#168 <https://github.com/bluescarni/heyoka.py/pull/168>`__).
These are :ref:`breaking changes <bchanges_4_0_0>`.
- **BREAKING**: it is now mandatory to supply a list of differentiation
arguments to :func:`~heyoka.diff_tensors()`
(`#164 <https://github.com/bluescarni/heyoka.py/pull/164>`__).
4 changes: 3 additions & 1 deletion doc/notebooks/ODEs with parameters.ipynb
@@ -4,8 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(runtime_param)=\n",
"\n",
"ODEs with parameters\n",
"==================\n",
"====================\n",
"\n",
"The values of numerical constants in heyoka.py can either be specified when constructing an ODE system, or they can be loaded at a later stage when the ODE system is being integrated. The latter type of numerical constant is known as a *parameter*.\n",
"\n",
2 changes: 2 additions & 0 deletions doc/notebooks/The expression system.ipynb
@@ -5,6 +5,8 @@
"id": "unable-biodiversity",
"metadata": {},
"source": [
"(ex_system)=\n",
"\n",
"The expression system\n",
"===================\n",
"\n",
74 changes: 49 additions & 25 deletions doc/notebooks/compiled_functions.ipynb
@@ -95,7 +95,7 @@
"array([-24.])"
]
},
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -157,7 +157,7 @@
"array([-24.])"
]
},
"execution_count": 10,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -246,9 +246,20 @@
"id": "4bea5b7b-ab72-47cf-a2dd-7f944248992b",
"metadata": {},
"source": [
"```{warning}\n",
"\n",
"The ability to pass lists and other iterables as input/output arguments\n",
"to a compiled function is offered as a convenience, but it incurs a runtime\n",
"cost as compiled functions need to convert the iterables into NumPy arrays\n",
"on-the-fly before performing the evaluation.\n",
"\n",
"For optimal performance, consider passing NumPy arrays to the call operator\n",
"of compiled functions in order to avoid this hidden cost.\n",
"```\n",
"\n",
"## Functions with parameters\n",
"\n",
"If the compiled function references external parameters, the parameters array will have to be supplied during evaluation via the ``pars`` keyword argument:"
"If the compiled function references [external parameters](<./ODEs with parameters.ipynb>), the parameters array will have to be supplied during evaluation via the ``pars`` keyword argument:"
]
},
{
@@ -318,7 +329,7 @@
"cf_tm = hy.cfunc([sym_func_tm], [x, y])\n",
"\n",
"# Evaluate for x=1, y=5 and time=6.\n",
"cf_tm([1,5], time=6)"
"cf_tm([1,5], time=6.)"
]
},
{
@@ -473,7 +484,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2.12 s ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"2.16 s ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
@@ -533,14 +544,35 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 20,
"id": "558ee9ab-d6e8-4d48-8556-46966eef9d30",
"metadata": {},
"outputs": [],
"source": [
"Ham_cf = hy.cfunc([Ham_sym], vars=[px,py,pz,x,y,z])"
]
},
{
"cell_type": "markdown",
"id": "dde49007-40f5-4623-8d09-4f77e1baf5e2",
"metadata": {},
"source": [
"heyoka.py's compiled functions support multithreaded parallelisation for batched evaluations. However, for this simple test, we will be disabling multithreaded parallelisation:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "ccb9d544-2b16-492b-be62-969b0f79aaad",
"metadata": {},
"outputs": [],
"source": [
"# Disable parallel batched evaluations by\n",
"# setting the number of threads in use by\n",
"# heyoka.py to 1.\n",
"hy.set_nthreads(1)"
]
},
{
"cell_type": "markdown",
"id": "216ded35-32ae-4a42-9c7a-a5f06b00a952",
@@ -551,15 +583,15 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 22,
"id": "32131124-52a4-4b92-97fb-3fccd9a77157",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"209 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"213 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
@@ -577,21 +609,21 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 27,
"id": "59999b6a-e45a-4673-b239-e8b76a6c9be0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"178 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"179 ms ± 3.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"# Pre-allocate the outputs array.\n",
"outputs = np.zeros((1, nevals))\n",
"outputs = np.empty((1, nevals))\n",
"%timeit Ham_cf(inputs,outputs=outputs)"
]
},
@@ -604,25 +636,17 @@
"\n",
"### JAX\n",
"\n",
"As a last benchmark, we will be performing the same evaluation with [JAX](https://jax.readthedocs.io/en/latest/index.html). Similarly to heyoka.py, JAX offers the possibility to [JIT compile Python functions](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions), so we expect similar performance to heyoka.py. Note that, in order to perform a fair comparison, for the execution of this notebook we [enabled 64-bit floats in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) and we used JAX's CPU backend [forcing a single thread of execution](https://github.com/google/jax/issues/1539) (JAX by default uses multiple threads of execution, but heyoka.py's compiled functions do not yet support multithreaded execution).\n",
"As a last benchmark, we will be performing the same evaluation with [JAX](https://jax.readthedocs.io/en/latest/index.html). Similarly to heyoka.py, JAX offers the possibility to [JIT compile Python functions](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions), so we expect similar performance to heyoka.py. Note that, in order to perform a fair comparison, for the execution of this notebook we [enabled 64-bit floats in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) and we used JAX's CPU backend [forcing a single thread of execution](https://github.com/google/jax/issues/1539).\n",
"\n",
"Let us see the JAX code:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 28,
"id": "e5e3a887-fc7a-4df4-8081-bce7db5db23f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
@@ -662,15 +686,15 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 30,
"id": "9b9af294-27a3-44f9-8678-208b30f405b3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"304 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"313 ms ± 1.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
@@ -703,7 +727,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.13"
}
},
"nbformat": 4,