A short course on JAX, automatic differentiation, the adjoint state method and differentiable simulators.
Author: Dr Aidan Crilly ([email protected])
Course details:
Date: Wednesday 25th June
Short description: The course aims to describe how to add trainable data-driven terms to physics-based differential equation models.
Length: Half-day (4-5 hours with break), starting 10am
Pre-requisites: Some Python knowledge (Imperial Physics undergraduate course level). Bring a laptop; the course runs on Google Colab, so it is entirely browser based with cloud computing, and a powerful laptop is not required.
Learning outcomes: The course aims to teach how automatic differentiation (a tool developed for machine learning) can be used to obtain linear sensitivity information about numerical solutions of ODEs and PDEs. This information can then be used to add trainable terms to these models; for example, learning the thermal conductivity from temperature data and the heat equation. Along the way you will learn the basic use of the following libraries: JAX, diffrax, optax, equinox
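To make the idea concrete, here is a minimal sketch (not part of the course materials) of how JAX's automatic differentiation can recover the sensitivity of an ODE solution to a physical parameter. The toy model, step count and parameter values are illustrative choices: a forward-Euler integration of dT/dt = -k T, differentiated with respect to the rate k, which should match the analytic sensitivity -t1 * T0 * exp(-k * t1).

```python
import jax
import jax.numpy as jnp


def solve(k, T0=1.0, t1=1.0, n_steps=1000):
    """Forward-Euler integration of dT/dt = -k * T from t=0 to t=t1."""
    dt = t1 / n_steps

    def step(T, _):
        # One explicit Euler update: T <- T + dt * (-k * T)
        return T * (1.0 - k * dt), None

    T_final, _ = jax.lax.scan(step, T0, None, length=n_steps)
    return T_final


k = 0.5
# Differentiate the whole numerical solve with respect to k
dTdk = jax.grad(solve)(k)
# Analytic check: T(t1) = T0 * exp(-k * t1), so dT/dk = -t1 * T0 * exp(-k * t1)
```

Because `jax.grad` differentiates through the entire time-stepping loop, the same pattern extends to solvers from diffrax and to models with trainable terms.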
Course Materials
The course includes lecture material and four Google Colab notebooks covering:
- Computational graphs and implementing adjoint methods by hand
- Use of differentiable programming libraries to create differentiable simulators
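As a flavour of the first notebook topic (an illustrative sketch, not the actual course notebook), reverse-mode differentiation can be implemented by hand on a small computational graph and checked against the analytic derivatives:

```python
import math


def forward_and_adjoint(a, b):
    """Evaluate y = sin(a*b) + b and its gradients via a hand-written reverse pass."""
    # Forward pass: record intermediate values of the computational graph
    c = a * b
    d = math.sin(c)
    y = d + b

    # Reverse (adjoint) pass: propagate y_bar = dy/dy = 1 backwards
    y_bar = 1.0
    d_bar = y_bar                  # from y = d + b
    b_bar = y_bar                  # direct dependence of y on b
    c_bar = d_bar * math.cos(c)    # from d = sin(c)
    a_bar = c_bar * b              # from c = a * b
    b_bar += c_bar * a             # second contribution to b
    return y, a_bar, b_bar


y, dyda, dydb = forward_and_adjoint(0.7, 1.3)
# Analytic derivatives: dy/da = b*cos(a*b), dy/db = a*cos(a*b) + 1
```

This is exactly what `jax.grad` automates: each primitive operation contributes a local adjoint rule, and the reverse pass accumulates them through the graph.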
This is based on a similar course given to the 'AI in Sciences' master's program at the African Institute of Mathematical Sciences (AIMS).