From af330388d301b3a88718cbfd5eb0c0b4dff72d53 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Sat, 18 Oct 2025 14:12:36 +0100 Subject: [PATCH] fem: cache on instances Co-authored-by: Leo Collins Co-authored-by: Pablo Brubeck --- tsfc/fem.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/tsfc/fem.py b/tsfc/fem.py index 0399c5a421..7c11096062 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -284,15 +284,6 @@ def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme): return make_quadrature(integration_cell, quadrature_degree, scheme=scheme) -def make_basis_evaluation_key(ctx, finat_element, mt, entity_id): - ufl_element = mt.terminal.ufl_element() - domain = extract_unique_domain(mt.terminal) - coordinate_element = domain.ufl_coordinate_element() - # This way of caching is fragile. - # Should Implement _hash_key_() for ModifiedTerminal and use the entire mt as key. - return (ufl_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_()) - - class PointSetContext(ContextBase): """Context for compile-time known evaluation points.""" @@ -323,12 +314,32 @@ def point_expr(self): def weight_expr(self): return self.quadrature_rule.weight_expression - @serial_cache(hashkey=make_basis_evaluation_key) + @staticmethod + def _make_basis_evaluation_key(finat_element, mt, entity_id): + ufl_element = mt.terminal.ufl_element() + domain = extract_unique_domain(mt.terminal) + coordinate_element = domain.ufl_coordinate_element() + # This way of caching is fragile. + # Should implement _hash_key_() in ModifiedTerminal and include the entire mt in the key, + # or only pass necessary bits in mt to basis_evaluation. + return (ufl_element, mt.local_derivatives, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_()) + + @cached_property + def _basis_evaluation_cache(self): + return {} + def basis_evaluation(self, finat_element, mt, entity_id): - return finat_element.basis_evaluation(mt.local_derivatives, - self.point_set, - (self.integration_dim, entity_id), - coordinate_mapping=CoordinateMapping(mt, self)) + key = PointSetContext._make_basis_evaluation_key(finat_element, mt, entity_id) + try: + return self._basis_evaluation_cache[key] + except KeyError: + val = finat_element.basis_evaluation( + mt.local_derivatives, + self.point_set, + (self.integration_dim, entity_id), + coordinate_mapping=CoordinateMapping(mt, self), + ) + return self._basis_evaluation_cache.setdefault(key, val) class GemPointContext(ContextBase):