Skip to content

Commit

Permalink
Improved cache hit/miss counting and added method to get cache keys (…
Browse files Browse the repository at this point in the history
…entries)

Cache miss counting was wrong because the code was checking if a value is in cache
before getting it, which indeed never increase diskcache's cache misses counter.
Also fixed an old warning from numpy where comparing two collection with np.all needs
them to be the same size.

Signed-off-by: Alexis Jeandet <[email protected]>
  • Loading branch information
jeandet committed Oct 22, 2020
1 parent 34bf49d commit c5446d9
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
9 changes: 8 additions & 1 deletion spwc/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,22 @@ def stats():
return _cache.stats()


def entries():
return _cache.keys()


class Cacheable(object):
def __init__(self, prefix, cache_instance=_cache, start_time_arg='start_time', stop_time_arg='stop_time',
version=None,
fragment_hours=lambda x: 1, cache_margins=1.2):
fragment_hours=lambda x: 1, cache_margins=1.2, leak_cache=False):
self.start_time_arg = start_time_arg
self.stop_time_arg = stop_time_arg
self.version = version
self.fragment_hours = fragment_hours
self.cache_margins = cache_margins
self.cache = cache_instance
self.prefix = prefix
self.leak_cache = leak_cache

def add_to_cache(self, variable: SpwcVariable, fragments, product, fragment_duration_hours, version):
if variable is not None:
Expand Down Expand Up @@ -122,4 +127,6 @@ def wrapped(wrapped_self, product, start_time, stop_time, **kwargs):
return result[dt_range.start_time:dt_range.stop_time]
return None

if self.leak_cache:
wrapped.cache = self.cache
return wrapped
21 changes: 15 additions & 6 deletions spwc/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ def __init__(self, data, version):


class Cache:
__slots__ = ['cache_file', '_data']
__slots__ = ['cache_file', '_data', '_hit', '_miss']

def __init__(self, cache_path: str = ""):
self._data = dc.FanoutCache(cache_path, shards=8, size_limit=int(float(cache_size.get())))
self._hit = 0
self._miss = 0
if self.version < cache_version:
self._data.clear()
self.version = cache_version
Expand All @@ -40,10 +42,9 @@ def disk_size(self):
return self._data.volume()

def stats(self):
s = self._data.stats()
return {
"hit": s[0],
"misses": s[1]
"hit": self._hit,
"misses": self._miss
}

def __len__(self):
Expand All @@ -53,12 +54,20 @@ def __del__(self):
pass

def keys(self):
return [item for item in self._data]
return list(self._data)

def __contains__(self, item):
return item in self._data
if item in self._data:
self._hit += 1
return True
self._miss += 1
return False

def __getitem__(self, key):
if key in self._data:
self._hit += 1
else:
self._miss += 1
return self._data[key]

def __setitem__(self, key, value):
Expand Down
1 change: 1 addition & 0 deletions spwc/common/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def view(self, range):
def __eq__(self, other: 'SpwcVariable') -> bool:
return self.meta == other.meta and \
self.columns == other.columns and \
len(self.time) == len(other.time) and \
np.all(self.time == other.time) and \
np.all(self.data == other.data)

Expand Down
20 changes: 15 additions & 5 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):
def version(self, product):
return self._version

@Cacheable(prefix="", cache_instance=cache, version=version)
@Cacheable(prefix="", cache_instance=cache, version=version, leak_cache=True)
def _make_data(self, product, start_time, stop_time):
index = np.array(
[(start_time + timedelta(minutes=delta)).timestamp() for delta in
Expand Down Expand Up @@ -56,21 +56,29 @@ def test_get_data_more_than_once(self):
tstart = datetime(2010, 6, 1, 12, 0, tzinfo=timezone.utc)
tend = datetime(2010, 6, 1, 15, 30, tzinfo=timezone.utc)
self.assertEqual(self._make_data_cntr, 0)
stats = self._make_data.cache.stats()
for _ in range(10):
var = self._make_data("test_get_data_more_than_once", tstart,
tend) # self.cache.get_data("test_get_data_more_than_once", DateTimeRange(tstart, tend), self._make_data)
self.assertEqual(self._make_data_cntr, 1)
new_stats = self._make_data.cache.stats()
self.assertGreater(new_stats["hit"], stats["hit"])
self.assertGreater(new_stats["misses"], stats["misses"])

def test_get_newer_version_data(self):
tstart = datetime(2010, 6, 1, 12, 0, tzinfo=timezone.utc)
tend = datetime(2010, 6, 1, 15, 30, tzinfo=timezone.utc)
self.assertEqual(self._make_data_cntr, 0)
stats = self._make_data.cache.stats()
for i in range(10):
self._version = f"{i}"
var = self._make_data("test_get_newer_version_data", tstart, tend)
# var = self.cache.get_data("test_get_newer_version_data", DateTimeRange(tstart, tend), self._make_data,
# version=f"{i}")
self.assertEqual(self._make_data_cntr, i + 1)
new_stats = self._make_data.cache.stats()
self.assertGreater(new_stats["hit"], stats["hit"])
self.assertGreater(new_stats["misses"], stats["misses"])

def test_get_same_version_data(self):
tstart = datetime(2010, 6, 1, 12, 0, tzinfo=timezone.utc)
Expand All @@ -83,6 +91,12 @@ def test_get_same_version_data(self):
# version="1.1.1")
self.assertEqual(self._make_data_cntr, 1)

def test_list_keys(self):
keys = self._make_data.cache.keys()
types = [type(key) for key in keys]
self.assertGreater(len(keys), 0)
self.assertListEqual(types, [str] * len(types))

def tearDown(self):
pass

Expand Down Expand Up @@ -230,11 +244,7 @@ def test_compare_version(self, lhs, rhs, op):


if __name__ == '__main__':
stats = cache.stats()
unittest.main()
new_stats = cache.stats()
assert (new_stats["hit"] > stats["hit"])
assert (new_stats["misses"] > stats["misses"])

del cache
shutil.rmtree(dirpath)

0 comments on commit c5446d9

Please sign in to comment.