diff --git a/requirements.txt b/requirements.txt index ad9b2027e..43efdec59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,5 +38,8 @@ jsonschema>=2.0.0,<3.0.0,!=2.5.0 oslo.utils>=1.4.0 # Apache-2.0 oslo.serialization>=1.4.0 # Apache-2.0 +# For lru caches and such +cachetools>=1.0.0 # MIT License + # For deprecation of things debtcollector>=0.3.0 # Apache-2.0 diff --git a/taskflow/persistence/backends/impl_dir.py b/taskflow/persistence/backends/impl_dir.py index 940b9c418..b6d1a27bf 100644 --- a/taskflow/persistence/backends/impl_dir.py +++ b/taskflow/persistence/backends/impl_dir.py @@ -20,6 +20,7 @@ import os import shutil +import cachetools import fasteners from oslo_serialization import jsonutils @@ -54,12 +55,22 @@ class DirBackend(path_based.PathBasedBackend): Example configuration:: conf = { - "path": "/tmp/taskflow", + "path": "/tmp/taskflow", # save data to this root directory + "max_cache_size": 1024, # keep up-to 1024 entries in memory } """ + def __init__(self, conf): super(DirBackend, self).__init__(conf) - self.file_cache = {} + max_cache_size = self._conf.get('max_cache_size') + if max_cache_size is not None: + max_cache_size = int(max_cache_size) + if max_cache_size < 1: + raise ValueError("Maximum cache size must be greater than" + " or equal to one") + self.file_cache = cachetools.LRUCache(max_cache_size) + else: + self.file_cache = {} self.encoding = self._conf.get('encoding', 'utf-8') if not self._path: raise ValueError("Empty path is disallowed") diff --git a/taskflow/tests/unit/persistence/test_dir_persistence.py b/taskflow/tests/unit/persistence/test_dir_persistence.py index 8c1171cf8..7445145a1 100644 --- a/taskflow/tests/unit/persistence/test_dir_persistence.py +++ b/taskflow/tests/unit/persistence/test_dir_persistence.py @@ -19,37 +19,80 @@ import shutil import tempfile +from oslo_utils import uuidutils +import testscenarios + +from taskflow import exceptions as exc from taskflow.persistence import backends from taskflow.persistence.backends import impl_dir +from taskflow.persistence import logbook from taskflow import test from taskflow.tests.unit.persistence import base -class DirPersistenceTest(test.TestCase, base.PersistenceTestMixin): +class DirPersistenceTest(testscenarios.TestWithScenarios, + test.TestCase, base.PersistenceTestMixin): + + scenarios = [ + ('no_cache', {'max_cache_size': None}), + ('one', {'max_cache_size': 1}), + ('tiny', {'max_cache_size': 256}), + ('medimum', {'max_cache_size': 512}), + ('large', {'max_cache_size': 1024}), + ] + def _get_connection(self): - conf = { - 'path': self.path, - } - return impl_dir.DirBackend(conf).get_connection() + return self.backend.get_connection() def setUp(self): super(DirPersistenceTest, self).setUp() self.path = tempfile.mkdtemp() - conn = self._get_connection() - conn.upgrade() + self.backend = impl_dir.DirBackend({ + 'path': self.path, + 'max_cache_size': self.max_cache_size, + }) + with contextlib.closing(self._get_connection()) as conn: + conn.upgrade() def tearDown(self): super(DirPersistenceTest, self).tearDown() - conn = self._get_connection() - conn.clear_all() if self.path and os.path.isdir(self.path): shutil.rmtree(self.path) self.path = None + self.backend = None def _check_backend(self, conf): with contextlib.closing(backends.fetch(conf)) as be: self.assertIsInstance(be, impl_dir.DirBackend) + def test_dir_backend_invalid_cache_size(self): + for invalid_size in [-1024, 0, -1]: + conf = { + 'path': self.path, + 'max_cache_size': invalid_size, + } + self.assertRaises(ValueError, impl_dir.DirBackend, conf) + + def test_dir_backend_cache_overfill(self): + if self.max_cache_size is not None: + # Ensure cache never goes past the desired max size... + books_ids_made = [] + with contextlib.closing(self._get_connection()) as conn: + for i in range(0, int(1.5 * self.max_cache_size)): + lb_name = 'book-%s' % (i) + lb_id = uuidutils.generate_uuid() + lb = logbook.LogBook(name=lb_name, uuid=lb_id) + self.assertRaises(exc.NotFound, conn.get_logbook, lb_id) + conn.save_logbook(lb) + books_ids_made.append(lb_id) + self.assertLessEqual(self.backend.file_cache.currsize, + self.max_cache_size) + # Also ensure that we can still read all created books... + with contextlib.closing(self._get_connection()) as conn: + for lb_id in books_ids_made: + lb = conn.get_logbook(lb_id) + self.assertIsNotNone(lb) + def test_dir_backend_entry_point(self): self._check_backend(dict(connection='dir:', path=self.path))