Skip to content

Commit

Permalink
add drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
kavigupta committed Feb 25, 2025
1 parent aaa8d49 commit baaef0d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 13 deletions.
48 changes: 39 additions & 9 deletions permacache/locked_shelf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import pickle
import shelve
import time
import uuid
Expand Down Expand Up @@ -161,7 +162,11 @@ class IndividualFileLockedStore:
"""

def __init__(
self, path, read_from_shelf_context_manager=None, multiprocess_safe=False
self,
path,
read_from_shelf_context_manager=None,
multiprocess_safe=False,
driver="json",
):
try:
os.makedirs(path)
Expand All @@ -172,34 +177,59 @@ def __init__(
self.cache = None
self.read_from_shelf_context_manager = read_from_shelf_context_manager
self.multi_process_safe = multiprocess_safe
assert driver in ("json", "pickle"), "driver must be json or pickle"
self.driver = driver

def _path_for_key(self, key):
if len(key) < 40 and all(c.isalnum() or c in "-_.,[](){} " for c in key):
return os.path.join(self.path, "." + key)
key = stable_hash(key)[:20]
key = "." + key
else:
key = stable_hash(key)[:20]
key = key + {"json": ".json", "pickle": ".pkl"}[self.driver]
return os.path.join(self.path, key)

def __getitem__(self, key):
with open(self._path_for_key(key), "r") as f:
result = json.load(f)
if self.driver == "json":
with open(self._path_for_key(key), "r") as f:
result = json.load(f)
elif self.driver == "pickle":
with open(self._path_for_key(key), "rb") as f:
result = pickle.load(f)
else:
raise ValueError(f"Unknown driver {self.driver}")
return result[key]

def __contains__(self, key):
return os.path.exists(self._path_for_key(key))

def __setitem__(self, key, value):
temporary_path = self._path_for_key(key) + "." + uuid.uuid4().hex[:10]
with open(temporary_path, "w") as f:
json.dump({key: value}, f)
if self.driver == "json":
out = json.dumps({key: value})
with open(temporary_path, "w") as f:
f.write(out)
elif self.driver == "pickle":
out = pickle.dumps({key: value})
with open(temporary_path, "wb") as f:
f.write(out)
else:
raise ValueError(f"Unknown driver {self.driver}")
os.replace(temporary_path, self._path_for_key(key))

def __delitem__(self, key):
os.remove(self._path_for_key(key))

def items(self):
for filename in os.listdir(self.path):
with open(os.path.join(self.path, filename), "r") as f:
yield from json.load(f).items()
if self.driver == "json":
with open(os.path.join(self.path, filename), "r") as f:
item = json.load(f)
elif self.driver == "pickle":
with open(os.path.join(self.path, filename), "rb") as f:
item = pickle.load(f)
else:
raise ValueError(f"Unknown driver {self.driver}")
yield from item.items()

def __enter__(self):
if self.multi_process_safe:
Expand Down
58 changes: 54 additions & 4 deletions tests/locked_shelf_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import os
import pickle
import random
import shutil
import unittest

import numpy as np

from permacache.hash import stable_hash
from permacache.locked_shelf import IndividualFileLockedStore, LockedShelf

Expand Down Expand Up @@ -71,18 +74,65 @@ def setUp(self):

def test_put_and_access(self):
super().test_put_and_access()
self.assertEqual(os.listdir("temp/tempshelf"), [".a"])
with open("temp/tempshelf/.a") as f:
self.assertEqual(os.listdir("temp/tempshelf"), [".a.json"])
with open("temp/tempshelf/.a.json") as f:
self.assertEqual(json.load(f), {"a": "b"})

def test_several_accesses(self):
super().test_several_accesses()
for p in os.listdir("temp/tempshelf"):
self.assertIn(p, [f".{x}" for x in range(100)])
self.assertIn(p, [f".{x}.json" for x in range(100)])

def test_large_key(self):
super().test_large_key()
h = stable_hash("a" * 100)[:20]
h = stable_hash("a" * 100)[:20] + ".json"
self.assertEqual(os.listdir("temp/tempshelf"), [h])
with open("temp/tempshelf/" + h) as f:
self.assertEqual(json.load(f), {"a" * 100: "b"})

def test_un_jsonable(self):
with self.assertRaises(TypeError):
with self.shelf as s:
s["a"] = np.array([1, 2, 3])

with self.shelf as s:
self.assertFalse("a" in s)
self.assertEqual(list(s.items()), [])


class IndividualFileLockedStoreTestPickle(LockedShelfTest):
def setUp(self):
self.shelf = IndividualFileLockedStore("temp/tempshelf", driver="pickle")

def test_put_and_access(self):
super().test_put_and_access()
self.assertEqual(os.listdir("temp/tempshelf"), [".a.pkl"])
with open("temp/tempshelf/.a.pkl", "rb") as f:
self.assertEqual(pickle.load(f), {"a": "b"})

def test_several_accesses(self):
super().test_several_accesses()
for p in os.listdir("temp/tempshelf"):
self.assertIn(p, [f".{x}.pkl" for x in range(100)])

def test_large_key(self):
super().test_large_key()
h = stable_hash("a" * 100)[:20] + ".pkl"
self.assertEqual(os.listdir("temp/tempshelf"), [h])
with open("temp/tempshelf/" + h, "rb") as f:
self.assertEqual(pickle.load(f), {"a" * 100: "b"})

def test_un_jsonable(self):
with self.shelf as s:
s["a"] = np.array([1, 2, 3])
with self.shelf as s:
self.assertEqual(type(s["a"]), np.ndarray)
self.assertEqual(s["a"].tolist(), [1, 2, 3])

def test_un_pickleable(self):
with self.assertRaises(Exception):
with self.shelf as s:
s["a"] = lambda x: x
with self.shelf as s:
self.assertFalse("a" in s)
self.assertEqual(list(s.items()), [])

0 comments on commit baaef0d

Please sign in to comment.