Skip to content

Commit

Permalink
prevent clearing uncommitted data by default
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Sep 3, 2024
1 parent e6422b1 commit 1311267
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
7 changes: 7 additions & 0 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def vacuum(self):
############################################################################
### managing the caches
############################################################################
def clear_cache(self, allow_uncommitted: False):
self.atoms.clear(allow_uncommited=allow_uncommitted)
self.shapes.clear(allow_uncommited=allow_uncommitted)
self.ops.clear(allow_uncommited=allow_uncommitted)
self.calls.clear(allow_uncommited=allow_uncommitted)
print("Cleared all caches.")

def cache_info(self) -> str:
"""
Display information about the contents of the cache in a pretty table.
Expand Down
16 changes: 14 additions & 2 deletions mandala/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def values(self, conn: Optional[sqlite3.Connection] = None) -> List[Any]:
class CachedDictStorage(DictStorage):
def __init__(self, persistent: DictStorage):
self.persistent = persistent
# keep a cache of the values for faster lookups
self.cache: Dict[str, Any] = {}
# keep track of keys that have been added but not yet persisted
self.dirty_keys: Set[str] = set()

def load_all(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -271,7 +273,13 @@ def commit(self, conn: Optional[sqlite3.Connection] = None) -> None:
self.persistent.set(key, self.cache[key], conn=conn)
self.dirty_keys.clear()

def clear(self) -> None:
def clear(self, allow_uncommited: bool = False) -> None:
if len(self.dirty_keys) > 0 and not allow_uncommited:
# we add this as a precaution to avoid data loss. Otherwise, it's
# easy to shoot yourself in the foot by calling `clear()` before
# `commit()`
msg = "Cannot clear cache with uncommitted changes; call `commit()` first, or use `allow_uncommited=True`"
raise ValueError(msg)
self.cache.clear()
self.dirty_keys.clear()

Expand Down Expand Up @@ -739,6 +747,10 @@ def commit(self, conn: Optional[sqlite3.Connection] = None):
self.persistent.save(self.cache.get_data(hid), conn=conn)
self.dirty_hids.clear()

def clear(self):
def clear(self, allow_uncommited: bool = False):
if len(self.dirty_hids) > 0 and not allow_uncommited:
# see `CachedDictStorage.clear` for an explanation
msg = "Cannot clear cache with uncommitted changes; call `commit()` first, or use `allow_uncommited=True`"
raise ValueError(msg)
self.cache = InMemCallStorage()
self.dirty_hids.clear()
25 changes: 24 additions & 1 deletion mandala/tests/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,27 @@ def inc(x, irrelevant):
inc(23, 1)

df = storage.cf(inc).df()
assert len(df) == 1
assert len(df) == 1


def test_clear_uncommitted():
storage = Storage()

@op
def inc(x):
return x + 1

with storage:
for i in range(10):
inc(i)
# attempt to clear the atoms cache without having committed; this should
# fail by default
try:
storage.atoms.clear()
assert False
except ValueError:
pass

# now clear the atoms cache after committing
storage.commit()
storage.atoms.clear()

0 comments on commit 1311267

Please sign in to comment.