Skip to content

Commit

Permalink
handle dataclasses (#7)
Browse files Browse the repository at this point in the history
* handle dataclasses

* bump version, add test
  • Loading branch information
kavigupta authored Dec 26, 2024
1 parent 08a1eb4 commit bc9ad3f
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 7 deletions.
2 changes: 1 addition & 1 deletion permacache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .cache import permacache
from .cache_miss_error import CacheMissError, error_on_miss_global
from .dict_function import drop_if, drop_if_equal
from .hash import stable_hash, stringify
from .hash import migrated_attrs, stable_hash, stringify
from .swap_unpickler import renamed_symbol_unpickler, swap_unpickler_context_manager
17 changes: 17 additions & 0 deletions permacache/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def default(self, o):
typename = type(o).__name__
o = {a.name: getattr(o, a.name) for a in o.__attrs_attrs__}
o[".attr.__name__"] = typename
if hasattr(o, "__dataclass_fields__"):
typename = type(o).__name__
is_migrated_attrs = getattr(o, "_migrated_attrs", False)
o = {a: getattr(o, a) for a in o.__dataclass_fields__}
if is_migrated_attrs:
o[".attr.__name__"] = typename
else:
o[".dataclass.__name__"] = typename
if isinstance(o, SimpleNamespace):
o = o.__dict__
o[".builtin.__name__"] = "types.SimpleNamespace"
Expand Down Expand Up @@ -120,3 +128,12 @@ def stable_hash(obj, *, fast_bytes=True):
return hashlib.sha256(
stringify(obj, fast_bytes=fast_bytes).encode("utf-8")
).hexdigest()


def migrated_attrs(cls):
"""
Decorator to mark a class as having been migrated from attrs to dataclasses.
"""
# pylint: disable=protected-access
cls._migrated_attrs = True
return cls
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="permacache",
version="3.8.0",
version="3.9.0",
author="Kavi Gupta",
author_email="[email protected]",
description="Permanant cache.",
Expand Down
24 changes: 19 additions & 5 deletions tests/stringify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from permacache import stable_hash, stringify
from tests.test_module.c import A, B, C, D

NUMPY_VERSION = np.version.version.split(".", maxsplit=1)[0]

Expand Down Expand Up @@ -36,14 +37,27 @@ def test_stringify_json(self):
self.assertEqual('{"a": 3}', stringify({"a": 3}))

def test_stringify_attrs(self):
import attr
self.assertEqual(
json.dumps({".attr.__name__": "A", "x": 1, "y": "hello", "z": 3.2}),
stringify(A(1, "hello", 3.2)),
)

@attr.s
class X:
y = attr.ib()
def test_stringify_attr_dataclass(self):
self.assertEqual(
json.dumps({".attr.__name__": "B", "x": 1, "y": "hello", "z": 3.2}),
stringify(B(1, "hello", 3.2)),
)

def test_stringify_dataclass(self):
self.assertEqual(
json.dumps({".dataclass.__name__": "C", "x": 1, "y": "hello", "z": 3.2}),
stringify(C(1, "hello", 3.2)),
)

def test_migrated_dataclass(self):
self.assertEqual(
json.dumps({".attr.__name__": "X", "y": "hello"}), stringify(X("hello"))
json.dumps({".attr.__name__": "D", "x": 1, "y": "hello", "z": 3.2}),
stringify(D(1, "hello", 3.2)),
)

def test_stringify_numpy(self):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_module/c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass

import attr

from permacache import migrated_attrs


@attr.s
class A:
x = attr.ib()
y = attr.ib()
z = attr.ib()


@attr.dataclass
class B:
x: int
y: str
z: float


@dataclass
class C:
x: int
y: str
z: float


@dataclass
@migrated_attrs
class D:
x: int
y: str
z: float

0 comments on commit bc9ad3f

Please sign in to comment.