Merged

Changes from all commits (31 commits)
dc14436
Should this be optional?
Ig-dolci Feb 14, 2025
c838418
A hack before understanding argument_slots
Ig-dolci Feb 14, 2025
e6d87f6
Checking argument_slots; Adding docs; testing actions results for pyt…
Ig-dolci Feb 17, 2025
ee001ce
More docs
Ig-dolci Feb 17, 2025
7278949
Fix docs
Ig-dolci Feb 17, 2025
8249bf4
Enable CI tests for optional installed packages
Ig-dolci Feb 18, 2025
96fc8e5
docs
Ig-dolci Feb 18, 2025
e48501c
docs
Ig-dolci Feb 18, 2025
09003f6
Test Pablo's ufl fix
Ig-dolci Feb 19, 2025
df1ee6b
merge master
Ig-dolci Feb 19, 2025
3aa5939
More asserts; increase tolerance to reduce test runtime
Ig-dolci Feb 19, 2025
1469403
just a test
Ig-dolci Feb 19, 2025
aca9860
install jaxlib
Ig-dolci Feb 19, 2025
ef6e4d2
Test another jax install requirement for linux
Ig-dolci Feb 19, 2025
d52431c
wip
Ig-dolci Feb 19, 2025
a50cea8
wip
Ig-dolci Feb 19, 2025
a094d88
wip
Ig-dolci Feb 19, 2025
22be7c4
wip
Ig-dolci Feb 19, 2025
6d3b726
será? ("will it work?")
Ig-dolci Feb 19, 2025
f11a99e
será? ("will it work?")
Ig-dolci Feb 19, 2025
b5f7080
test
Ig-dolci Feb 20, 2025
24c757c
test
Ig-dolci Feb 20, 2025
fe649de
Use options={riesz_representation: None}
Ig-dolci Feb 20, 2025
65289a2
Jax install fixing
Ig-dolci Feb 20, 2025
2c4f227
Merge branch 'master' into dolci/jax_ml_operator
Ig-dolci Feb 20, 2025
fa4c380
Update scripts/firedrake-install
Ig-dolci Feb 20, 2025
8b50efd
Switch ufl branch
pbrubeck Feb 21, 2025
0e00379
review
Ig-dolci Feb 21, 2025
6c477d8
review
Ig-dolci Feb 21, 2025
d613eb2
review
Ig-dolci Feb 21, 2025
6762c85
Update .github/workflows/build.yml
pbrubeck Feb 25, 2025
23 changes: 12 additions & 11 deletions firedrake/external_operators/abstract_external_operators.py
@@ -47,21 +47,22 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
     Parameters
     ----------
     *operands : ufl.core.expr.Expr or ufl.form.BaseForm
         Operands of the external operator.
     function_space : firedrake.functionspaceimpl.WithGeometryBase
         The function space the external operator is mapping to.
     derivatives : tuple
         Tuple specifiying the derivative multiindex.
     *argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
-        Tuple containing the arguments of the linear form associated with the external operator,
-        i.e. the arguments with respect to which the external operator is linear. Those arguments
-        can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-        as a result of taking the action on a given function.
+        Tuple containing the arguments of the linear form associated with the external operator,
+        i.e. the arguments with respect to which the external operator is linear. Those arguments can
+        be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
+        ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
+        of taking the action on a given function.
     operator_data : dict
         Dictionary containing the data of the external operator, i.e. the external data
         specific to the external operator subclass considered. This dictionary will be passed on
         over the UFL symbolic reconstructions making the operator data accessible to the external operators
         arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
     """
     from firedrake_citations import Citations
     Citations().register("Bouziani2021")
27 changes: 15 additions & 12 deletions firedrake/external_operators/ml_operator.py
@@ -15,22 +15,25 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
     Parameters
     ----------
     *operands : ufl.core.expr.Expr or ufl.form.BaseForm
         Operands of the ML operator.
     function_space : firedrake.functionspaceimpl.WithGeometryBase
         The function space the ML operator is mapping to.
     derivatives : tuple
         Tuple specifiying the derivative multiindex.
     *argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
-        Tuple containing the arguments of the linear form associated with the ML operator,
-        i.e. the arguments with respect to which the ML operator is linear. Those arguments
-        can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-        as a result of taking the action on a given function.
+        Tuple containing the arguments of the linear form associated with the ML operator,
+        i.e. the arguments with respect to which the ML operator is linear. Those arguments can
+        be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
+        ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
+        of taking the action on a given function. If argument slots are not provided, then they will
+        be generated in the :class:`.AbstractExternalOperator` constructor.
     operator_data : dict
-        Dictionary to stash external data specific to the ML operator. This dictionary must
-        at least contain the following:
-        (i) 'model': The machine learning model implemented in the ML framework considered.
-        (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
-        Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
+        Dictionary to stash external data specific to the ML operator. This dictionary must
+        at least contain the following:
+        (i) 'model': The machine learning model implemented in the ML framework considered.
+        (ii) 'inputs_format': The format of the inputs to the ML model: ``0`` for models acting globally
+        on the inputs, ``1`` when acting locally/pointwise on the inputs.
+        Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
     """
     AbstractExternalOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                                       argument_slots=argument_slots, operator_data=operator_data)
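For context on the argument-slots mechanics described above, here is a minimal sketch using the PyTorch backend. Hedged assumptions: a Firedrake build with PyTorch support, the `ml_operator` convenience wrapper from `firedrake.ml.pytorch`, and the directional-derivative idiom used by the tests further down in this PR.

# A minimal sketch, not authoritative: `ml_operator` and the differentiation
# idiom are assumed to behave as in this PR's tests.
import torch
from firedrake import UnitSquareMesh, FunctionSpace, Function, assemble, derivative
from firedrake.ml.pytorch import ml_operator

mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V).assign(0.5)

# inputs_format=1: the model acts pointwise on the DOF values.
model = torch.nn.Linear(1, 1)
N = ml_operator(model, function_space=V, inputs_format=1)(u)

# Differentiating w.r.t. u in a direction delta_u fills an argument slot with
# a Coefficient (the direction) rather than leaving a free ufl Argument.
delta_u = Function(V).assign(1.0)
dN = derivative(N, u, delta_u)
assemble(dN)  # the Jacobian-vector product: a Function in V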
2 changes: 1 addition & 1 deletion firedrake/ml/jax/fem_operator.py
@@ -92,7 +92,7 @@ def bwd(self, _, grad_output: "jax.Array") -> "jax.Array":
             adj_input = float(adj_input)

         # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
-        adj_output = self.F.derivative(adj_input=adj_input)
+        adj_output = self.F.derivative(adj_input=adj_input, options={'riesz_representation': None})

         # Tuplify adjoint output
         adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output
43 changes: 23 additions & 20 deletions firedrake/ml/jax/ml_operator.py
@@ -39,36 +39,40 @@ def __init__(
         *operands: Union[ufl.core.expr.Expr, ufl.form.BaseForm],
         function_space: WithGeometryBase,
         derivatives: Optional[tuple] = None,
-        argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]],
+        argument_slots: tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]] = (),
         operator_data: Optional[dict] = {}
     ):
-    """External operator class representing machine learning models implemented in JAX.
+    """
+    External operator class representing machine learning models implemented in JAX.

     The :class:`.JaxOperator` allows users to embed machine learning models implemented in JAX
-    into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator` is
-    delegated to the specified JAX model. Similarly, differentiation through the :class:`.JaxOperator`
-    class is achieved using JAX differentiation on the JAX model associated with the :class:`.JaxOperator` object.
+    into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator`
+    is delegated to the specified JAX model. Similarly, differentiation through the
+    :class:`.JaxOperator` is achieved using JAX differentiation on the associated JAX model.

     Parameters
     ----------
     *operands
         Operands of the :class:`.JaxOperator`.
     function_space
         The function space the ML operator is mapping to.
     derivatives
-        Tuple specifiying the derivative multiindex.
-    *argument_slots
-        Tuple containing the arguments of the linear form associated with the ML operator,
-        i.e. the arguments with respect to which the ML operator is linear. Those arguments
-        can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-        as a result of taking the action on a given function.
+        Tuple specifying the derivative multi-index.
+    argument_slots
+        Tuple containing the arguments of the linear form associated with the ML operator,
+        i.e., the arguments with respect to which the ML operator is linear. These arguments
+        can be ``ufl.argument.BaseArgument`` objects, as a result of differentiation,
+        or both ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects,
+        as a result of taking the action on a given function.
     operator_data
-        Dictionary to stash external data specific to the ML operator. This dictionary must
-        at least contain the following:
-        (i) 'model': The machine learning model implemented in JaX
-        (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
-        Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
+        Dictionary to stash external data specific to the ML operator. This dictionary must
+        at least contain the following:
+        (i) ``'model'``: The machine learning model implemented in JAX.
+        (ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting
+        globally on the inputs, ``1`` for models acting locally/pointwise on the inputs.
+        Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
     """
     MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                         argument_slots=argument_slots, operator_data=operator_data)

@@ -90,8 +94,7 @@ def _pre_forward_callback(self, *operands: Union[Function, Cofunction], unsqueez

     def _post_forward_callback(self, y_P: "jax.Array") -> Union[Function, Cofunction]:
         """Callback function to convert the JAX output of the ML model to a Firedrake function."""
-        space = self.ufl_function_space()
-        return from_jax(y_P, space)
+        return from_jax(y_P, self.ufl_function_space())

     # -- JAX routines for computing AD-based quantities -- #
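As a usage sketch for the new operator (hedged: it assumes this PR exposes an `ml_operator` convenience wrapper in `firedrake.ml.jax` mirroring the PyTorch one, with the model being a plain JAX callable):

# A sketch under the stated assumptions, not a definitive API reference.
import jax.numpy as jnp
from firedrake import UnitSquareMesh, FunctionSpace, Function, assemble
from firedrake.ml.jax import ml_operator  # assumed to mirror firedrake.ml.pytorch

mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V).assign(0.5)

def model(x):
    # Acts pointwise on the DOF values (inputs_format=1).
    return jnp.tanh(x)

N = ml_operator(model, function_space=V, inputs_format=1)(u)
assembled_N = assemble(N)  # evaluation is delegated to the JAX model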
2 changes: 1 addition & 1 deletion firedrake/ml/pytorch/fem_operator.py
@@ -83,7 +83,7 @@ def backward(ctx, grad_output):
             adj_input = float(adj_input)

         # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
-        adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"})
+        adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": None})

         # Tuplify adjoint output
         adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output
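Both backward passes (JAX above, PyTorch here) now request the derivative with `riesz_representation: None`: the vector-Jacobian product handed back to the ML framework should be the raw dual-space object, not its Riesz representative in the primal space. A sketch of the distinction outside the operator machinery, assuming the standard `firedrake.adjoint` taping API:

# Sketch only: illustrates what the option changes, not this PR's code path.
from firedrake import UnitSquareMesh, FunctionSpace, Function, assemble, dx
from firedrake.adjoint import continue_annotation, Control, ReducedFunctional

continue_annotation()
mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V).assign(1.0)
J = assemble(u * u * dx)
rf = ReducedFunctional(J, Control(u))

# With riesz_representation=None the gradient is returned as a Cofunction in
# the dual space V* (what an ML framework expects as a VJP); "l2" or "L2"
# would instead Riesz-map it back into V as a Function.
g = rf.derivative(options={"riesz_representation": None})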
29 changes: 15 additions & 14 deletions firedrake/ml/pytorch/ml_operator.py
@@ -40,22 +40,24 @@ class is achieved via the `torch.autograd` module, which provides automatic diff
     Parameters
     ----------
     *operands : ufl.core.expr.Expr or ufl.form.BaseForm
         Operands of the :class:`.PytorchOperator`.
     function_space : firedrake.functionspaceimpl.WithGeometryBase
         The function space the ML operator is mapping to.
     derivatives : tuple
         Tuple specifiying the derivative multiindex.
     *argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
-        Tuple containing the arguments of the linear form associated with the ML operator,
-        i.e. the arguments with respect to which the ML operator is linear. Those arguments
-        can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-        as a result of taking the action on a given function.
+        Tuple containing the arguments of the linear form associated with the ML operator, i.e. the
+        arguments with respect to which the ML operator is linear. Those arguments can be
+        ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
+        ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
+        of taking the action on a given function.
     operator_data : dict
-        Dictionary to stash external data specific to the ML operator. This dictionary must
-        at least contain the following:
-        (i) 'model': The machine learning model implemented in PyTorch.
-        (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
-        Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
+        Dictionary to stash external data specific to the ML operator. This dictionary must
+        at least contain the following:
+        (i) ``'model'``: The machine learning model implemented in PyTorch.
+        (ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting globally
+        on the inputs, ``1`` when acting locally/pointwise on the inputs.
+        Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
     """
     MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                         argument_slots=argument_slots, operator_data=operator_data)

@@ -98,8 +100,7 @@ def _pre_forward_callback(self, *operands, unsqueeze=False):

     def _post_forward_callback(self, y_P):
         """Callback function to convert the PyTorch output of the ML model to a Firedrake function."""
-        space = self.ufl_function_space()
-        return from_torch(y_P, space)
+        return from_torch(y_P, self.ufl_function_space())

     # -- PyTorch routines for computing AD based quantities via `torch.autograd.functional` -- #
2 changes: 1 addition & 1 deletion scripts/firedrake-install
@@ -1303,7 +1303,7 @@ def build_and_install_jax():
     """Install JAX for a CPU or CUDA backend."""
     log.info("Installing JAX (backend: %s)" % args.jax)
     version_name = "jax" if args.jax == "cpu" else "jax[cuda12]"
-    run_pip_install([version_name])
+    run_pip_install([version_name] + ["jaxlib"] + ["ml_dtypes"] + ["opt_einsum"])


 def build_and_install_slepc():
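Outside the installer, the change amounts to roughly the following (a sketch: `run_pip_install` is the script's own wrapper around `pip install`, and for the CUDA backend the first package becomes `jax[cuda12]`):

# Approximate manual equivalent for the CPU backend.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install",
                       "jax", "jaxlib", "ml_dtypes", "opt_einsum"])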
40 changes: 21 additions & 19 deletions tests/firedrake/conftest.py
@@ -98,25 +98,27 @@ def pytest_collection_modifyitems(session, config, items):
         if item.get_closest_marker("skipmumps") is not None:
             item.add_marker(pytest.mark.skip("MUMPS not installed with PETSc"))

-        if not torch_backend:
-            if item.get_closest_marker("skiptorch") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
-
-        if not jax_backend:
-            if item.get_closest_marker("skipjax") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test makes no sense if JAX is not installed"))
-
-        if not matplotlib_installed:
-            if item.get_closest_marker("skipplot") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Matplotlib is installed"))
-
-        if not netgen_installed:
-            if item.get_closest_marker("skipnetgen") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Netgen and ngsPETSc are installed"))
-
-        if not vtk_installed:
-            if item.get_closest_marker("skipvtk") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test cannot be run unless VTK is installed"))
+        import os
+        if os.getenv("FIREDRAKE_CI_TESTS") != "1":
+            if not torch_backend:
+                if item.get_closest_marker("skiptorch") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
+
+            if not jax_backend:
+                if item.get_closest_marker("skipjax") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test makes no sense if JAX is not installed"))
+
+            if not matplotlib_installed:
+                if item.get_closest_marker("skipplot") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Matplotlib is installed"))
+
+            if not netgen_installed:
+                if item.get_closest_marker("skipnetgen") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Netgen and ngsPETSc are installed"))
+
+            if not vtk_installed:
+                if item.get_closest_marker("skipvtk") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test cannot be run unless VTK is installed"))


 @pytest.fixture(scope="module", autouse=True)
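The effect is that the optional-backend skip markers are honoured only outside CI: with `FIREDRAKE_CI_TESTS=1`, a missing optional package now makes the corresponding tests fail instead of being silently skipped. The gating pattern in isolation (a sketch; `torch_backend` stands in for the availability flag computed earlier in the real conftest):

# Sketch of the pattern, reduced to a single marker.
import os

import pytest

torch_backend = False  # stand-in for the real availability check


def pytest_collection_modifyitems(session, config, items):
    if os.getenv("FIREDRAKE_CI_TESTS") == "1":
        return  # in CI the optional packages must be present, so never skip
    for item in items:
        if not torch_backend and item.get_closest_marker("skiptorch") is not None:
            item.add_marker(pytest.mark.skip(reason="PyTorch is not installed"))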
1 addition & 1 deletion (file name not shown in the rendered diff)
@@ -77,6 +77,6 @@ def J(f):
     c = Control(f)
     Jhat = ReducedFunctional(J(f), c)

-    f_opt = minimize(Jhat, tol=1e-6, method="BFGS")
+    f_opt = minimize(Jhat, tol=1e-4, method="BFGS")

     assert assemble((f_exact - f_opt)**2 * dx) / assemble(f_exact**2 * dx) < 1e-5
5 changes: 5 additions & 0 deletions tests/firedrake/external_operators/test_jax_operator.py
@@ -96,6 +96,7 @@ def test_forward(u, nn):

     # Assemble NeuralNet operator
     assembled_N = assemble(N)
+    assert isinstance(assembled_N, Function)

     # Convert from Firedrake to JAX
     x_P = to_jax(u)

@@ -125,6 +126,8 @@ def test_jvp(u, nn):
     # Assemble
     dN = assemble(dN)

+    assert isinstance(dN, Function)
+
     # Convert from Firedrake to JAX
     delta_u_P = to_jax(delta_u)
     u_P = to_jax(u)

@@ -153,6 +156,8 @@ def test_vjp(u, nn):
     # Assemble
     dN_adj = assemble(dNdu)

+    assert isinstance(dN_adj, Cofunction)
+
     # Convert from Firedrake to JAX
     delta_N_P = to_jax(delta_N)
     u_P = to_jax(u)
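These tests rely on the conversion helpers added by this PR: `from_jax(y_P, space)` is visible in `ml_operator.py` above, and `to_jax` is its counterpart in the test code. A minimal round trip (assuming both are exported from `firedrake.ml.jax`):

# Round-trip sketch between Firedrake and JAX representations.
from firedrake import UnitSquareMesh, FunctionSpace, Function
from firedrake.ml.jax import to_jax, from_jax  # assumed export path

mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V).assign(1.0)

x_P = to_jax(u)        # Firedrake Function -> jax.Array of DOF values
u2 = from_jax(x_P, V)  # jax.Array -> Firedrake Function in V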
6 changes: 6 additions & 0 deletions tests/firedrake/external_operators/test_pytorch_operator.py
@@ -90,6 +90,8 @@ def test_forward(u, nn):
     # Assemble NeuralNet operator
     assembled_N = assemble(N)

+    assert isinstance(assembled_N, Function)
+
     # Convert from Firedrake to PyTorch
     x_P = to_torch(u)
     # Forward pass

@@ -118,6 +120,8 @@ def test_jvp(u, nn):
     # Assemble
     dN = assemble(dN)

+    assert isinstance(dN, Function)
+
     # Convert from Firedrake to PyTorch
     delta_u_P = to_torch(delta_u)
     u_P = to_torch(u)

@@ -146,6 +150,8 @@ def test_vjp(u, nn):
     # Assemble
     dN_adj = assemble(dNdu)

+    assert isinstance(dN_adj, Cofunction)
+
     # Convert from Firedrake to PyTorch
     delta_N_P = to_torch(delta_N)
     u_P = to_torch(u)
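The new asserts pin down the spaces involved: assembling the operator or a directional derivative (JVP) yields a primal `Function`, while the assembled adjoint derivative (VJP) yields a `Cofunction`, consistent with the `riesz_representation: None` change above. A small sketch of that primal/dual distinction:

# Sketch relating the asserted types to primal and dual spaces.
from firedrake import UnitSquareMesh, FunctionSpace, Function, Cofunction

mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)

f = Function(V)           # primal object: what assemble(N) and the JVP return
c = Cofunction(V.dual())  # dual object: what the assembled VJP returns
assert f.function_space() == V
assert c.function_space() == V.dual()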