This package provides a set of benchmark scripts that can be used to profile JAX performance on a varying number of CPU cores. JAX does not provide control over the number of cores it uses, so a common workaround is to restrict the available cores with taskset.
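The taskset workaround can be sketched in plain shell as follows. This is a minimal illustration, not how this package implements it. It assumes a Linux machine with util-linux's taskset, GNU date (for %N), awk, and nproc. The helper names cpu_range, run_on_cores, and sweep_cores are hypothetical.

```shell
#!/usr/bin/env bash
# Hypothetical sketch (not part of benchmark-jax): pin a command to the first
# N logical CPUs with taskset and time it for each N. Linux only.

# Print a taskset CPU list covering cores 0..N-1, e.g. "0-3" for N=4.
cpu_range() {
  local n="$1"
  echo "0-$((n - 1))"
}

# Run the remaining arguments restricted to the first N cores.
# Usage: run_on_cores N command [args...]
run_on_cores() {
  local n="$1"
  shift
  taskset --cpu-list "$(cpu_range "$n")" "$@"
}

# Time a command once for every core count from 1 up to nproc,
# printing one "cores seconds" line per run (command stdout is discarded).
sweep_cores() {
  local max n start end
  max="$(nproc)"
  for ((n = 1; n <= max; n++)); do
    start="$(date +%s.%N)"
    run_on_cores "$n" "$@" > /dev/null
    end="$(date +%s.%N)"
    awk -v n="$n" -v s="$start" -v e="$end" 'BEGIN { printf "%d %.3f\n", n, e - s }'
  done
}
```

For example, sweep_cores python3 -c "import jax.numpy as jnp; jnp.ones((2000, 2000)).dot(jnp.ones((2000, 2000))).block_until_ready()" prints one timing line per core count. These timings include Python and JAX start-up, so they are only a rough indication.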
The benchmarks can be run by installing the package with pip and then running the benchmark-jax command:
python3 -m pip install git+https://github.com/ComPWA/jax-mini-benchmark@main
benchmark-jax
The resulting benchmark plot is written to jax-benchmark-$HOSTNAME.svg. If you do not want the plot to be shown directly, for instance when you run this command in a script, add the --no-show flag:
benchmark-jax --no-show
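In a script, a non-interactive run could be wrapped as sketched below. This assumes the plot is written to the current directory and that its name uses the machine's host name ($HOSTNAME, falling back to uname -n). Both assumptions may not match the package exactly. The function names are hypothetical.

```shell
#!/usr/bin/env bash
# Hypothetical sketch: run the benchmark headlessly and check for the plot.
# Assumes the SVG lands in the current directory, named after the host.

# Print the expected output file name, e.g. "jax-benchmark-mybox.svg".
benchmark_svg_path() {
  echo "jax-benchmark-${HOSTNAME:-$(uname -n)}.svg"
}

# Run benchmark-jax without opening the plot, then verify the SVG exists.
run_benchmark_headless() {
  benchmark-jax --no-show || return 1
  local svg
  svg="$(benchmark_svg_path)"
  if [[ -f "$svg" ]]; then
    echo "Benchmark plot written to $svg"
  else
    echo "Expected $svg, but it was not found" >&2
    return 1
  fi
}
```

Calling run_benchmark_headless returns a non-zero status if either the benchmark or the file check fails, so it can be used with set -e in CI jobs.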
We recommend working in a virtual environment. If you have installed Miniconda, the project can easily be set up as follows:
git clone https://github.com/ComPWA/jax-mini-benchmark
cd jax-mini-benchmark
conda env create
conda activate jax-mini-benchmark
pre-commit install # optional, but recommended
See ComPWA's "Help developing" page for more info.