Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cfunc API updates #168

Merged
merged 14 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/api_exsys.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ Functions
make_vars
diff_tensors

Attributes
----------

.. autosummary::
:toctree: autosummary_generated

par

Enums
-----

Expand Down
31 changes: 24 additions & 7 deletions doc/breaking_changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,26 @@ Breaking changes

heyoka.py 4 includes several backwards-incompatible changes.

API/behaviour changes
~~~~~~~~~~~~~~~~~~~~~
Changes to compiled functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The function to create :ref:`compiled functions <cfunc_tut>` has been renamed from
``make_cfunc()`` to simply ``cfunc()``.

Compiled functions have also gained the ability to use multiple
threads of execution during batched evaluations. As a consequence, compiled functions
now require contiguous NumPy arrays to be passed as input/output arguments (whereas
in previous versions compiled functions would work also with non-contiguous
arrays). The NumPy function :py:func:`numpy.ascontiguousarray()` can be used to turn
non-contiguous arrays into contiguous arrays.

Finally, compiled functions are now stricter with respect to type conversions: if a NumPy
array with the wrong datatype is passed as an input/output argument, an error will be raised
(whereas previously heyoka.py would convert the array to the correct datatype on-the-fly).
The NumPy method :py:meth:`numpy.ndarray.astype()` can be used for datatype conversions.

A more explicit API
^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~

Several functions and classes have been changed to explicitly require
the user to pass a list of variables in input. The previous behaviour, where
Expand All @@ -31,23 +46,25 @@ The affected APIs include:
The tutorials and the documentation have been updated accordingly.

Changes to :py:func:`~heyoka.make_vars()`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :py:func:`~heyoka.make_vars()` function now returns a single expression (rather than a list of expressions)
if a single argument is passed in input. This means that code such as

.. code-block:: python

x, = make_vars("x")
y = make_vars("y")[0]

needs to be rewritten like this:

.. code-block:: python

x = make_vars("x")
y = make_vars("y")

Terminal events callbacks
^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~

The second argument in the signature of callbacks for terminal events, a ``bool`` conventionally
called ``mr``, has been removed. This flag was meant to signal the possibility of multiple roots
Expand All @@ -58,7 +75,7 @@ Adapting existing code for this API change is straightforward: you just have to
from the signature of a terminal event callback.

Step callbacks and ``propagate_*()``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The step callbacks that can (optionally) be passed to the ``propagate_*()`` methods of the
adaptive integrators are now part of the return value. Specifically:
Expand All @@ -82,7 +99,7 @@ a matter of:
rather than a single value.

Changes to ``propagate_grid()``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``propagate_grid()`` methods of the adaptive integrators now require the first element of the
time grid to be equal to the current integrator time. Previously, in case of a difference between the
Expand Down
12 changes: 12 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Changelog
New
~~~

- Compiled functions now support multithreaded parallelisation
for batched evaluations
(`#168 <https://github.com/bluescarni/heyoka.py/pull/168>`__).
- Add new example on gravity-gradient stabilisation
(`#159 <https://github.com/bluescarni/heyoka.py/pull/159>`__).
- Add support for Lagrangian and Hamiltonian mechanics
Expand All @@ -25,6 +28,15 @@ New
Changes
~~~~~~~

- **BREAKING**: the function to construct compiled functions
has been renamed from ``make_cfunc()`` to ``cfunc()``
(`#168 <https://github.com/bluescarni/heyoka.py/pull/168>`__).
This is a :ref:`breaking change <bchanges_4_0_0>`.
- **BREAKING**: compiled functions now require contiguous arrays
as input/output arguments. The compiled functions API is also now
more restrictive with respect to on-the-fly type conversions
(`#168 <https://github.com/bluescarni/heyoka.py/pull/168>`__).
These are :ref:`breaking changes <bchanges_4_0_0>`.
- **BREAKING**: it is now mandatory to supply a list of differentiation
arguments to :func:`~heyoka.diff_tensors()`
(`#164 <https://github.com/bluescarni/heyoka.py/pull/164>`__).
Expand Down
4 changes: 3 additions & 1 deletion doc/notebooks/ODEs with parameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(runtime_param)=\n",
"\n",
"ODEs with parameters\n",
"==================\n",
"====================\n",
"\n",
"The values of numerical constants in heyoka.py can either be specified when constructing an ODE system, or they can be loaded at a later stage when the ODE system is being integrated. The latter type of numerical constant is known as a *parameter*.\n",
"\n",
Expand Down
2 changes: 2 additions & 0 deletions doc/notebooks/The expression system.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"id": "unable-biodiversity",
"metadata": {},
"source": [
"(ex_system)=\n",
"\n",
"The expression system\n",
"===================\n",
"\n",
Expand Down
74 changes: 49 additions & 25 deletions doc/notebooks/compiled_functions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
"array([-24.])"
]
},
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -157,7 +157,7 @@
"array([-24.])"
]
},
"execution_count": 10,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -246,9 +246,20 @@
"id": "4bea5b7b-ab72-47cf-a2dd-7f944248992b",
"metadata": {},
"source": [
"```{warning}\n",
"\n",
"The ability to pass lists and other iterables as input/output arguments\n",
"to a compiled function is offered as a convenience, but it incurs into a runtime\n",
"cost as compiled functions need to convert the iterables into NumPy arrays\n",
"on-the-fly before performing the evaluation.\n",
"\n",
"For optimal performance, consider passing NumPy arrays to the call operator\n",
"of compiled functions in order to avoid this hidden cost.\n",
"```\n",
"\n",
"## Functions with parameters\n",
"\n",
"It the compiled function references external parameters, the parameters array will have to be supplied during evaluation via the ``pars`` keyword argument:"
"It the compiled function references [external parameters](<./ODEs with parameters.ipynb>), the parameters array will have to be supplied during evaluation via the ``pars`` keyword argument:"
]
},
{
Expand Down Expand Up @@ -318,7 +329,7 @@
"cf_tm = hy.cfunc([sym_func_tm], [x, y])\n",
"\n",
"# Evaluate for x=1, y=5 and time=6.\n",
"cf_tm([1,5], time=6)"
"cf_tm([1,5], time=6.)"
]
},
{
Expand Down Expand Up @@ -473,7 +484,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2.12 s ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"2.16 s ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -533,14 +544,35 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 20,
"id": "558ee9ab-d6e8-4d48-8556-46966eef9d30",
"metadata": {},
"outputs": [],
"source": [
"Ham_cf = hy.cfunc([Ham_sym], vars=[px,py,pz,x,y,z])"
]
},
{
"cell_type": "markdown",
"id": "dde49007-40f5-4623-8d09-4f77e1baf5e2",
"metadata": {},
"source": [
"heyoka.py's compiled functions support multithreaded parallelisation for batched evaluations. However, for this simple test, we will be disabling multithreaded parallelisation:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "ccb9d544-2b16-492b-be62-969b0f79aaad",
"metadata": {},
"outputs": [],
"source": [
"# Disable parallel batched evaluations by\n",
"# setting the number of threads in use by\n",
"# heyoka.py to 1.\n",
"hy.set_nthreads(1)"
]
},
{
"cell_type": "markdown",
"id": "216ded35-32ae-4a42-9c7a-a5f06b00a952",
Expand All @@ -551,15 +583,15 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 22,
"id": "32131124-52a4-4b92-97fb-3fccd9a77157",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"209 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"213 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand All @@ -577,21 +609,21 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 27,
"id": "59999b6a-e45a-4673-b239-e8b76a6c9be0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"178 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"179 ms ± 3.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"# Pre-allocate the outputs array.\n",
"outputs = np.zeros((1, nevals))\n",
"outputs = np.empty((1, nevals))\n",
"%timeit Ham_cf(inputs,outputs=outputs)"
]
},
Expand All @@ -604,25 +636,17 @@
"\n",
"### JAX\n",
"\n",
"As a last benchmark, we will be performing the same evaluation with [JAX](https://jax.readthedocs.io/en/latest/index.html). Similarly to heyoka.py, JAX offers the possibility to [JIT compile Python functions](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions), so we expect similar performance to heyoka.py. Note that, in order to perform a fair comparison, for the execution of this notebook we [enabled 64-bit floats in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) and we used JAX's CPU backend [forcing a single thread of execution](https://github.com/google/jax/issues/1539) (JAX by default uses multiple threads of execution, but heyoka.py's compiled functions do not yet support multithreaded execution).\n",
"As a last benchmark, we will be performing the same evaluation with [JAX](https://jax.readthedocs.io/en/latest/index.html). Similarly to heyoka.py, JAX offers the possibility to [JIT compile Python functions](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions), so we expect similar performance to heyoka.py. Note that, in order to perform a fair comparison, for the execution of this notebook we [enabled 64-bit floats in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) and we used JAX's CPU backend [forcing a single thread of execution](https://github.com/google/jax/issues/1539).\n",
"\n",
"Let us see the jax code:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 28,
"id": "e5e3a887-fc7a-4df4-8081-bce7db5db23f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -662,15 +686,15 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 30,
"id": "9b9af294-27a3-44f9-8678-208b30f405b3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"304 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"313 ms ± 1.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -703,7 +727,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
Loading
Loading