|
1 | 1 | import itertools
|
2 |
| -import os |
3 |
| -import tempfile |
4 | 2 | import weakref
|
5 | 3 | from collections import OrderedDict, defaultdict
|
6 | 4 | from functools import singledispatch
|
|
13 | 11 | from gem.node import MemoizerArg
|
14 | 12 | from gem.node import traversal as gem_traversal
|
15 | 13 | from pyop2 import op2
|
16 |
| -from pyop2.caching import disk_cached |
| 14 | +from pyop2.caching import cached |
17 | 15 | from pyop2.parloop import GlobalLegacyArg, DatLegacyArg
|
18 | 16 | from tsfc import ufl2gem
|
19 | 17 | from tsfc.loopy import generate
|
@@ -416,25 +414,19 @@ def compile_to_gem(expr, translator):
|
416 | 414 | return preprocess_gem([lvalue, rvalue])
|
417 | 415 |
|
418 | 416 |
|
419 |
| -try: |
420 |
| - _cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"] |
421 |
| -except KeyError: |
422 |
| - _cachedir = os.path.join(tempfile.gettempdir(), |
423 |
| - f"firedrake-pointwise-expression-kernel-cache-uid{os.getuid()}") |
424 |
| -"""Storage location for the kernel cache.""" |
| 417 | +_pointwise_expression_cache = {} |
| 418 | +"""In-memory cache for pointwise expression kernels.""" |
425 | 419 |
|
426 | 420 |
|
427 | 421 | def _pointwise_expression_key(exprs, scalar_type, is_logging):
|
428 | 422 | """Return a cache key for use with :func:`pointwise_expression_kernel`."""
|
429 |
| - # Since this cache is collective this function must return a 2-tuple of |
430 |
| - # communicator and cache key. |
431 |
| - comm = exprs[0].lvalue.node_set.comm |
432 |
| - key = tuple(e.slow_key for e in exprs) + (scalar_type, is_logging) |
433 |
| - return comm, key |
| 423 | + from firedrake.interpolation import hash_expr |
| 424 | + return (tuple((e.__class__, hash(e.lvalue), hash_expr(e.rvalue)) for e in exprs) |
| 425 | + + (scalar_type, is_logging)) |
434 | 426 |
|
435 | 427 |
|
436 | 428 | @PETSc.Log.EventDecorator()
|
437 |
| -@disk_cached({}, _cachedir, key=_pointwise_expression_key, collective=True) |
| 429 | +@cached(_pointwise_expression_cache, key=_pointwise_expression_key) |
438 | 430 | def pointwise_expression_kernel(exprs, scalar_type, is_logging):
|
439 | 431 | """Compile a kernel for pointwise expressions.
|
440 | 432 |
|
|
0 commit comments