diff --git a/firedrake/external_operators/abstract_external_operators.py b/firedrake/external_operators/abstract_external_operators.py
index f4e5ecbc21..789e1fccb9 100644
--- a/firedrake/external_operators/abstract_external_operators.py
+++ b/firedrake/external_operators/abstract_external_operators.py
@@ -47,21 +47,22 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
         Parameters
         ----------
         *operands : ufl.core.expr.Expr or ufl.form.BaseForm
-            Operands of the external operator.
+            Operands of the external operator.
         function_space : firedrake.functionspaceimpl.WithGeometryBase
-            The function space the external operator is mapping to.
+            The function space the external operator is mapping to.
         derivatives : tuple
-            Tuple specifiying the derivative multiindex.
+            Tuple specifying the derivative multi-index.
         *argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
-            Tuple containing the arguments of the linear form associated with the external operator,
-            i.e. the arguments with respect to which the external operator is linear. Those arguments
-            can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-            as a result of taking the action on a given function.
+            Tuple containing the arguments of the linear form associated with the external operator,
+            i.e. the arguments with respect to which the external operator is linear. Those arguments can
+            be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
+            ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
+            of taking the action on a given function.
         operator_data : dict
-            Dictionary containing the data of the external operator, i.e. the external data
-            specific to the external operator subclass considered. This dictionary will be passed on
-            over the UFL symbolic reconstructions making the operator data accessible to the external operators
-            arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
+            Dictionary containing the data of the external operator, i.e. the external data
+            specific to the external operator subclass considered. This dictionary is passed
+            through the UFL symbolic reconstructions, making the operator data accessible to the external operators
+            arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
         """
         from firedrake_citations import Citations
         Citations().register("Bouziani2021")
diff --git a/firedrake/external_operators/ml_operator.py b/firedrake/external_operators/ml_operator.py
index ed95b2a6d7..ba35366cbc 100644
--- a/firedrake/external_operators/ml_operator.py
+++ b/firedrake/external_operators/ml_operator.py
@@ -15,22 +15,25 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
         Parameters
         ----------
         *operands : ufl.core.expr.Expr or ufl.form.BaseForm
-            Operands of the ML operator.
+            Operands of the ML operator.
         function_space : firedrake.functionspaceimpl.WithGeometryBase
-            The function space the ML operator is mapping to.
+            The function space the ML operator is mapping to.
         derivatives : tuple
-            Tuple specifiying the derivative multiindex.
+            Tuple specifying the derivative multi-index.
         *argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
-            Tuple containing the arguments of the linear form associated with the ML operator,
-            i.e. the arguments with respect to which the ML operator is linear. Those arguments
-            can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-            as a result of taking the action on a given function.
+            Tuple containing the arguments of the linear form associated with the ML operator,
+            i.e. the arguments with respect to which the ML operator is linear. Those arguments can
+            be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
+            ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
+            of taking the action on a given function. If argument slots are not provided, they will
+            be generated in the :class:`.AbstractExternalOperator` constructor.
         operator_data : dict
-            Dictionary to stash external data specific to the ML operator. This dictionary must
-            at least contain the following:
-            (i) 'model': The machine learning model implemented in the ML framework considered.
-            (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
-            Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
+            Dictionary to stash external data specific to the ML operator. This dictionary must
+            at least contain the following:
+            (i) ``'model'``: The machine learning model implemented in the ML framework considered.
+            (ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting globally
+            on the inputs, ``1`` when acting locally/pointwise on the inputs.
+            Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
         """
         AbstractExternalOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                                           argument_slots=argument_slots, operator_data=operator_data)
diff --git a/firedrake/ml/jax/fem_operator.py b/firedrake/ml/jax/fem_operator.py
index 352d837938..b9e1914846 100644
--- a/firedrake/ml/jax/fem_operator.py
+++ b/firedrake/ml/jax/fem_operator.py
@@ -92,7 +92,7 @@ def bwd(self, _, grad_output: "jax.Array") -> "jax.Array":
             adj_input = float(adj_input)
 
         # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
-        adj_output = self.F.derivative(adj_input=adj_input)
+        adj_output = self.F.derivative(adj_input=adj_input, options={'riesz_representation': None})
 
         # Tuplify adjoint output
         adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output
diff --git a/firedrake/ml/jax/ml_operator.py b/firedrake/ml/jax/ml_operator.py
index 6a15cb18f7..9f9d2fb178 100644
--- a/firedrake/ml/jax/ml_operator.py
+++ b/firedrake/ml/jax/ml_operator.py
@@ -39,36 +39,40 @@ def __init__(
         *operands: Union[ufl.core.expr.Expr, ufl.form.BaseForm],
         function_space: WithGeometryBase,
         derivatives: Optional[tuple] = None,
-        argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]],
+        argument_slots: tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]] = (),
         operator_data: Optional[dict] = {}
     ):
-        """External operator class representing machine learning models implemented in JAX.
+        """
+        External operator class representing machine learning models implemented in JAX.
 
         The :class:`.JaxOperator` allows users to embed machine learning models implemented in JAX
-        into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator` is
-        delegated to the specified JAX model. Similarly, differentiation through the :class:`.JaxOperator`
-        class is achieved using JAX differentiation on the JAX model associated with the :class:`.JaxOperator` object.
+        into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator`
+        is delegated to the specified JAX model. Similarly, differentiation through the
+        :class:`.JaxOperator` is achieved using JAX differentiation on the associated JAX model.
 
         Parameters
         ----------
         *operands
-            Operands of the :class:`.JaxOperator`.
+            Operands of the :class:`.JaxOperator`.
         function_space
-            The function space the ML operator is mapping to.
+            The function space the ML operator is mapping to.
         derivatives
-            Tuple specifiying the derivative multiindex.
-        *argument_slots
-            Tuple containing the arguments of the linear form associated with the ML operator,
-            i.e. the arguments with respect to which the ML operator is linear. Those arguments
-            can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
-            as a result of taking the action on a given function.
+            Tuple specifying the derivative multi-index.
+        argument_slots
+            Tuple containing the arguments of the linear form associated with the ML operator,
+            i.e., the arguments with respect to which the ML operator is linear. These arguments
+            can be ``ufl.argument.BaseArgument`` objects, as a result of differentiation,
+            or both ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects,
+            as a result of taking the action on a given function.
         operator_data
-            Dictionary to stash external data specific to the ML operator. This dictionary must
-            at least contain the following:
-            (i) 'model': The machine learning model implemented in JaX
-            (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
-            Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
+            Dictionary to stash external data specific to the ML operator. This dictionary must
+            contain the following:
+            (i) ``'model'``: The machine learning model implemented in JAX.
+            (ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting
+            globally on the inputs, ``1`` for models acting locally/pointwise on the inputs.
+            Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
""" + MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives, argument_slots=argument_slots, operator_data=operator_data) @@ -90,8 +94,7 @@ def _pre_forward_callback(self, *operands: Union[Function, Cofunction], unsqueez def _post_forward_callback(self, y_P: "jax.Array") -> Union[Function, Cofunction]: """Callback function to convert the JAX output of the ML model to a Firedrake function.""" - space = self.ufl_function_space() - return from_jax(y_P, space) + return from_jax(y_P, self.ufl_function_space()) # -- JAX routines for computing AD-based quantities -- # diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index 7eda1d07ed..f7956265d0 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -83,7 +83,7 @@ def backward(ctx, grad_output): adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"}) + adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": None}) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output diff --git a/firedrake/ml/pytorch/ml_operator.py b/firedrake/ml/pytorch/ml_operator.py index 84dbdd2f45..2fecf152f1 100644 --- a/firedrake/ml/pytorch/ml_operator.py +++ b/firedrake/ml/pytorch/ml_operator.py @@ -40,22 +40,24 @@ class is achieved via the `torch.autograd` module, which provides automatic diff Parameters ---------- *operands : ufl.core.expr.Expr or ufl.form.BaseForm - Operands of the :class:`.PytorchOperator`. + Operands of the :class:`.PytorchOperator`. function_space : firedrake.functionspaceimpl.WithGeometryBase - The function space the ML operator is mapping to. + The function space the ML operator is mapping to. derivatives : tuple - Tuple specifiying the derivative multiindex. + Tuple specifiying the derivative multiindex. *argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument - Tuple containing the arguments of the linear form associated with the ML operator, - i.e. the arguments with respect to which the ML operator is linear. Those arguments - can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects, - as a result of taking the action on a given function. + Tuple containing the arguments of the linear form associated with the ML operator, i.e. the + arguments with respect to which the ML operator is linear. Those arguments can be + ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both + ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` object, as a result + of taking the action on a given function. operator_data : dict - Dictionary to stash external data specific to the ML operator. This dictionary must - at least contain the following: - (i) 'model': The machine learning model implemented in PyTorch. - (ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs. - Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class. + Dictionary to stash external data specific to the ML operator. This dictionary must + at least contain the following: + (i) ``'model'``: The machine learning model implemented in PyTorch. 
+            (ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting globally
+            on the inputs, ``1`` when acting locally/pointwise on the inputs.
+            Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
         """
         MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
                             argument_slots=argument_slots, operator_data=operator_data)
@@ -98,8 +100,7 @@ def _pre_forward_callback(self, *operands, unsqueeze=False):
 
     def _post_forward_callback(self, y_P):
         """Callback function to convert the PyTorch output of the ML model to a Firedrake function."""
-        space = self.ufl_function_space()
-        return from_torch(y_P, space)
+        return from_torch(y_P, self.ufl_function_space())
 
 
     # -- PyTorch routines for computing AD based quantities via `torch.autograd.functional` -- #
diff --git a/scripts/firedrake-install b/scripts/firedrake-install
index aaf3f293f9..a0e7343c2c 100755
--- a/scripts/firedrake-install
+++ b/scripts/firedrake-install
@@ -1303,7 +1303,7 @@ def build_and_install_jax():
     """Install JAX for a CPU or CUDA backend."""
    log.info("Installing JAX (backend: %s)" % args.jax)
     version_name = "jax" if args.jax == "cpu" else "jax[cuda12]"
-    run_pip_install([version_name])
+    run_pip_install([version_name, "jaxlib", "ml_dtypes", "opt_einsum"])
 
 
 def build_and_install_slepc():
diff --git a/tests/firedrake/conftest.py b/tests/firedrake/conftest.py
index f42b911b91..66e155a823 100644
--- a/tests/firedrake/conftest.py
+++ b/tests/firedrake/conftest.py
@@ -98,25 +98,27 @@ def pytest_collection_modifyitems(session, config, items):
         if item.get_closest_marker("skipmumps") is not None:
             item.add_marker(pytest.mark.skip("MUMPS not installed with PETSc"))
 
-        if not torch_backend:
-            if item.get_closest_marker("skiptorch") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
-
-        if not jax_backend:
-            if item.get_closest_marker("skipjax") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test makes no sense if JAX is not installed"))
-
-        if not matplotlib_installed:
-            if item.get_closest_marker("skipplot") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Matplotlib is installed"))
-
-        if not netgen_installed:
-            if item.get_closest_marker("skipnetgen") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Netgen and ngsPETSc are installed"))
-
-        if not vtk_installed:
-            if item.get_closest_marker("skipvtk") is not None:
-                item.add_marker(pytest.mark.skip(reason="Test cannot be run unless VTK is installed"))
+        import os
+        if os.getenv("FIREDRAKE_CI_TESTS") != "1":
+            if not torch_backend:
+                if item.get_closest_marker("skiptorch") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
+
+            if not jax_backend:
+                if item.get_closest_marker("skipjax") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test makes no sense if JAX is not installed"))
+
+            if not matplotlib_installed:
+                if item.get_closest_marker("skipplot") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Matplotlib is installed"))
+
+            if not netgen_installed:
+                if item.get_closest_marker("skipnetgen") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Netgen and ngsPETSc are installed"))
+
+            if not vtk_installed:
+                if item.get_closest_marker("skipvtk") is not None:
+                    item.add_marker(pytest.mark.skip(reason="Test cannot be run unless VTK is installed"))
installed")) @pytest.fixture(scope="module", autouse=True) diff --git a/tests/firedrake/external_operators/test_external_operators_adjoint.py b/tests/firedrake/external_operators/test_external_operators_adjoint.py index b7572f3f7e..a563e55115 100644 --- a/tests/firedrake/external_operators/test_external_operators_adjoint.py +++ b/tests/firedrake/external_operators/test_external_operators_adjoint.py @@ -77,6 +77,6 @@ def J(f): c = Control(f) Jhat = ReducedFunctional(J(f), c) - f_opt = minimize(Jhat, tol=1e-6, method="BFGS") + f_opt = minimize(Jhat, tol=1e-4, method="BFGS") assert assemble((f_exact - f_opt)**2 * dx) / assemble(f_exact**2 * dx) < 1e-5 diff --git a/tests/firedrake/external_operators/test_jax_operator.py b/tests/firedrake/external_operators/test_jax_operator.py index 940b84712b..ec9b6f11d6 100644 --- a/tests/firedrake/external_operators/test_jax_operator.py +++ b/tests/firedrake/external_operators/test_jax_operator.py @@ -96,6 +96,7 @@ def test_forward(u, nn): # Assemble NeuralNet operator assembled_N = assemble(N) + assert isinstance(assembled_N, Function) # Convert from Firedrake to JAX x_P = to_jax(u) @@ -125,6 +126,8 @@ def test_jvp(u, nn): # Assemble dN = assemble(dN) + assert isinstance(dN, Function) + # Convert from Firedrake to JAX delta_u_P = to_jax(delta_u) u_P = to_jax(u) @@ -153,6 +156,8 @@ def test_vjp(u, nn): # Assemble dN_adj = assemble(dNdu) + assert isinstance(dN_adj, Cofunction) + # Convert from Firedrake to JAX delta_N_P = to_jax(delta_N) u_P = to_jax(u) diff --git a/tests/firedrake/external_operators/test_pytorch_operator.py b/tests/firedrake/external_operators/test_pytorch_operator.py index 99f11a6990..69e0e1632b 100644 --- a/tests/firedrake/external_operators/test_pytorch_operator.py +++ b/tests/firedrake/external_operators/test_pytorch_operator.py @@ -90,6 +90,8 @@ def test_forward(u, nn): # Assemble NeuralNet operator assembled_N = assemble(N) + assert isinstance(assembled_N, Function) + # Convert from Firedrake to PyTorch x_P = to_torch(u) # Forward pass @@ -118,6 +120,8 @@ def test_jvp(u, nn): # Assemble dN = assemble(dN) + assert isinstance(dN, Function) + # Convert from Firedrake to PyTorch delta_u_P = to_torch(delta_u) u_P = to_torch(u) @@ -146,6 +150,8 @@ def test_vjp(u, nn): # Assemble dN_adj = assemble(dNdu) + assert isinstance(dN_adj, Cofunction) + # Convert from Firedrake to PyTorch delta_N_P = to_torch(delta_N) u_P = to_torch(u)