diff --git a/outlines/caching.py b/outlines/caching.py index 6fdda6214..3e58c7c16 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -126,7 +126,7 @@ def wrapper(*args, **kwargs): def __cache_key__(*args, **kwargs): """Make key for cache given function arguments.""" - return args_to_key(base, args, kwargs, typed, ignore) + return str(args_to_key(base, args, kwargs, typed, ignore)) wrapper.__cache_key__ = __cache_key__ # type: ignore wrapper.__memory__ = memory # type: ignore diff --git a/tests/test_cache.py b/tests/test_cache.py index eb4ec406e..4492ed58a 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import os import tempfile import unittest @@ -68,6 +69,33 @@ def f(x): assert len(store) == store_size + 1 +def test_get_cache_from_class_method(test_cache): + store = list() + + class DummyObject: + @classmethod + @test_cache + def dummy_function(cls, a): + store.append(a) + return a + + @dataclass + class DummyArg: + a: int + + dummy_object = DummyObject() + + a_1 = DummyArg(1) + + dummy_object.dummy_function(a_1) + assert len(store) == 1 + store_size = len(store) + + a_2 = DummyArg(1) + dummy_object.dummy_function(a_2) + assert len(store) == store_size + + def test_disable_cache(test_cache): """Make sure that we can disable the cache.""" import outlines