Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,6 @@ def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme):
return make_quadrature(integration_cell, quadrature_degree, scheme=scheme)


def make_basis_evaluation_key(ctx, finat_element, mt, entity_id):
ufl_element = mt.terminal.ufl_element()
domain = extract_unique_domain(mt.terminal)
coordinate_element = domain.ufl_coordinate_element()
# This way of caching is fragile.
# Should Implement _hash_key_() for ModifiedTerminal and use the entire mt as key.
return (ufl_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_())


class PointSetContext(ContextBase):
"""Context for compile-time known evaluation points."""

Expand Down Expand Up @@ -323,12 +314,32 @@ def point_expr(self):
def weight_expr(self):
return self.quadrature_rule.weight_expression

@serial_cache(hashkey=make_basis_evaluation_key)
@staticmethod
def _make_basis_evaluation_key(finat_element, mt, entity_id):
ufl_element = mt.terminal.ufl_element()
domain = extract_unique_domain(mt.terminal)
coordinate_element = domain.ufl_coordinate_element()
# This way of caching is fragile.
# Should implement _hash_key_() in ModifiedTerminal and include the entire mt in the key,
# or only pass necessary bits in mt to basis_evaluation.
return (ufl_element, mt.local_derivatives, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_())

@cached_property
def _basis_evaluation_cache(self):
return {}

def basis_evaluation(self, finat_element, mt, entity_id):
return finat_element.basis_evaluation(mt.local_derivatives,
self.point_set,
(self.integration_dim, entity_id),
coordinate_mapping=CoordinateMapping(mt, self))
key = PointSetContext._make_basis_evaluation_key(finat_element, mt, entity_id)
try:
return self._basis_evaluation_cache[key]
except KeyError:
val = finat_element.basis_evaluation(
mt.local_derivatives,
self.point_set,
(self.integration_dim, entity_id),
coordinate_mapping=CoordinateMapping(mt, self),
)
return self._basis_evaluation_cache.setdefault(key, val)


class GemPointContext(ContextBase):
Expand Down
Loading