diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index 189d810..f1e5881 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.16.1
+    jupytext_version: 1.16.6
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -348,7 +348,15 @@ The compiled versions for the previous array size are still available in memory
 too, and the following call is dispatched to the correct compiled code.
 
 ```{code-cell} ipython3
-%time f(x).block_until_ready()
+%timeit f(x).block_until_ready()
+```
+
+```{note}
+Note that we use the [`%timeit` magic](https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-timeit)
+here to get more reliable timing statistics.
+
+It runs the code multiple times to calculate both mean and standard deviation,
+which takes longer but provides more accurate measurements than a single run.
 ```
 
 ### Compiling the outer function
@@ -368,14 +376,13 @@ f_jit(x)
 And now let's time it.
 
 ```{code-cell} ipython3
-%time f_jit(x).block_until_ready()
+%timeit f_jit(x).block_until_ready()
 ```
 
 Note the speed gain.
 
 This is because the array operations are fused and no intermediate arrays are created.
 
-
 Incidentally, a more common syntax when targetting a function for the JIT compiler is
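
The note added in this patch says that `%timeit` runs the statement repeatedly and reports both a mean and a standard deviation. Outside IPython, similar statistics can be gathered with the standard-library `timeit` module; the sketch below uses a toy pure-Python workload as a stand-in for the lecture's JAX computation (the function `f` here is illustrative, not the one from the lecture):

```python
import timeit
import statistics

def f(n=1000):
    # toy workload standing in for the JAX computation timed in the lecture
    return sum(i * i for i in range(n))

# 7 rounds of 100 calls each, roughly mirroring %timeit's repeat/loop scheme
raw = timeit.repeat(f, number=100, repeat=7)
per_call = [t / 100 for t in raw]  # convert each round to seconds per call

mean = statistics.mean(per_call)
std = statistics.stdev(per_call)
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop (7 runs, 100 loops each)")
```

Reporting the spread across rounds, not just a single run, is what makes the `%timeit` numbers in the patched code cells more trustworthy than the old `%time` output.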