Skip to content

Commit ee720d6

Browse files
authored
Merge pull request #2523 from firedrakeproject/connorjward/fix-assign-hashing
Fix hashing for Assign rvalues
2 parents 8b15438 + b75c1a6 commit ee720d6

File tree

3 files changed

+29
-17
lines changed

3 files changed

+29
-17
lines changed

firedrake/assemble_expressions.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import itertools
2-
import os
3-
import tempfile
42
import weakref
53
from collections import OrderedDict, defaultdict
64
from functools import singledispatch
@@ -13,7 +11,7 @@
1311
from gem.node import MemoizerArg
1412
from gem.node import traversal as gem_traversal
1513
from pyop2 import op2
16-
from pyop2.caching import disk_cached
14+
from pyop2.caching import cached
1715
from pyop2.parloop import GlobalLegacyArg, DatLegacyArg
1816
from tsfc import ufl2gem
1917
from tsfc.loopy import generate
@@ -416,25 +414,19 @@ def compile_to_gem(expr, translator):
416414
return preprocess_gem([lvalue, rvalue])
417415

418416

419-
try:
420-
_cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"]
421-
except KeyError:
422-
_cachedir = os.path.join(tempfile.gettempdir(),
423-
f"firedrake-pointwise-expression-kernel-cache-uid{os.getuid()}")
424-
"""Storage location for the kernel cache."""
417+
_pointwise_expression_cache = {}
418+
"""In-memory cache for pointwise expression kernels."""
425419

426420

427421
def _pointwise_expression_key(exprs, scalar_type, is_logging):
428422
"""Return a cache key for use with :func:`pointwise_expression_kernel`."""
429-
# Since this cache is collective this function must return a 2-tuple of
430-
# communicator and cache key.
431-
comm = exprs[0].lvalue.node_set.comm
432-
key = tuple(e.slow_key for e in exprs) + (scalar_type, is_logging)
433-
return comm, key
423+
from firedrake.interpolation import hash_expr
424+
return (tuple((e.__class__, hash(e.lvalue), hash_expr(e.rvalue)) for e in exprs)
425+
+ (scalar_type, is_logging))
434426

435427

436428
@PETSc.Log.EventDecorator()
437-
@disk_cached({}, _cachedir, key=_pointwise_expression_key, collective=True)
429+
@cached(_pointwise_expression_cache, key=_pointwise_expression_key)
438430
def pointwise_expression_kernel(exprs, scalar_type, is_logging):
439431
"""Compile a kernel for pointwise expressions.
440432

firedrake/interpolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def _compile_expression_key(comm, expr, to_element, ufl_element, domain, paramet
397397
# Since the caching is collective, this function must return a 2-tuple of
398398
# the form (comm, key) where comm is the communicator the cache is collective over.
399399
# FIXME FInAT elements are not safely hashable so we ignore them here
400-
key = _hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log
400+
key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log
401401
return comm, key
402402

403403

@@ -517,7 +517,7 @@ def __init__(self, glob):
517517
self.ufl_domain = lambda: None
518518

519519

520-
def _hash_expr(expr):
520+
def hash_expr(expr):
521521
"""Return a numbering-invariant hash of a UFL expression.
522522
523523
:arg expr: A UFL expression.

tests/regression/test_expressions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,26 @@ def test_expression_cache():
499499
assert len(u._expression_cache) == 5
500500

501501

502+
def test_global_expression_cache():
503+
from firedrake.assemble_expressions import _pointwise_expression_cache
504+
505+
mesh = UnitSquareMesh(1, 1)
506+
V = VectorFunctionSpace(mesh, "CG", 1)
507+
u = Function(V)
508+
509+
_pointwise_expression_cache.clear()
510+
assert len(_pointwise_expression_cache) == 0
511+
512+
u.assign(Constant(1))
513+
assert len(_pointwise_expression_cache) == 1
514+
515+
u.assign(Constant(2))
516+
assert len(_pointwise_expression_cache) == 1
517+
518+
u.assign(1)
519+
assert len(_pointwise_expression_cache) == 2
520+
521+
502522
def test_augmented_assignment_broadcast():
503523
mesh = UnitSquareMesh(1, 1)
504524
V = FunctionSpace(mesh, "BDM", 1)

0 commit comments

Comments
 (0)