diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index 0408a6d2a6..02649b57b9 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -268,6 +268,8 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): coarse._fine = context context._coarse = coarse + solutiondm = context._problem.u.function_space().dm + parentdm = get_parent(solutiondm) # Now that we have the coarse snescontext, push it to the coarsened DMs # Otherwise they won't have the right transfer manager when they are # coarsened in turn @@ -275,13 +277,12 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): if isinstance(val, (firedrake.Function, firedrake.Cofunction)): V = val.function_space() coarseneddm = V.dm - parentdm = get_parent(context._problem.u.function_space().dm) # Now attach the hook to the parent DM if get_appctx(coarseneddm) is None: push_appctx(coarseneddm, coarse) - teardown = partial(pop_appctx, coarseneddm, coarse) - add_hook(parentdm, teardown=teardown) + if parentdm.getAttr("__setup_hooks__"): + add_hook(parentdm, teardown=partial(pop_appctx, coarseneddm, coarse)) ises = problem.J.arguments()[0].function_space()._ises coarse._nullspace = self(context._nullspace, self, coefficient_mapping=coefficient_mapping) @@ -384,7 +385,8 @@ def create_interpolation(dmc, dmf): mat.setType(mat.Type.PYTHON) mat.setPythonContext(ctx) mat.setUp() - return mat, None + rscale = mat.createVecLeft() if row_size == col_size else None + return mat, rscale def create_injection(dmc, dmf): diff --git a/tests/multigrid/test_transfer_manager.py b/tests/multigrid/test_transfer_manager.py index d11bceb41f..a3049af242 100644 --- a/tests/multigrid/test_transfer_manager.py +++ b/tests/multigrid/test_transfer_manager.py @@ -1,5 +1,6 @@ import pytest import numpy +import warnings from firedrake import * from firedrake.mg.ufl_utils import coarsen from firedrake.utils import complex_mode @@ -131,3 +132,25 @@ def test_transfer_manager_dat_version_cache(action, transfer_op, spaces): else: raise ValueError(f"Unrecognized action {action}") + + +@pytest.mark.parametrize("family, degree", [("CG", 1), ("R", 0)]) +def test_cached_transfer(family, degree): + # Test that we can properly reuse transfers within solve + sp = {"mat_type": "matfree", + "pc_type": "mg", + "mg_coarse_pc_type": "none", + "mg_levels_pc_type": "none"} + + base = UnitSquareMesh(1, 1) + hierarchy = MeshHierarchy(base, 3) + mesh = hierarchy[-1] + + V = FunctionSpace(mesh, family, degree) + u = Function(V) + F = inner(u - 1, TestFunction(V)) * dx + + # This test will fail if we raise this warning + with warnings.catch_warnings(): + warnings.filterwarnings("error", "Creating new TransferManager", RuntimeWarning) + solve(F == 0, u, solver_parameters=sp)