1
1
from abc import ABC , abstractmethod
2
2
from contextlib import contextmanager
3
- from typing import Any , Optional , Dict , List , Set
3
+ from typing import Any , Optional , Dict , List , Set , Union
4
4
from ell .lstr import lstr
5
5
from ell .types import InvocableLM
6
6
@@ -32,7 +32,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
32
32
pass
33
33
34
34
@abstractmethod
35
- def write_invocation (self , id : str , lmp_id : str , args : str , kwargs : str , result : lstr | List [lstr ], invocation_kwargs : Dict [str , Any ],
35
+ def write_invocation (self , id : str , lmp_id : str , args : str , kwargs : str , result : Union [ lstr , List [lstr ] ], invocation_kwargs : Dict [str , Any ],
36
36
created_at : Optional [float ], consumes : Set [str ], prompt_tokens : Optional [int ] = None ,
37
37
completion_tokens : Optional [int ] = None , latency_ms : Optional [float ] = None ,
38
38
state_cache_key : Optional [str ] = None ,
@@ -118,7 +118,7 @@ def get_latest_lmps(self) -> List[Dict[str, Any]]:
118
118
119
119
120
120
@contextmanager
121
- def freeze (self , * lmps : InvocableLM ):
121
+ def freeze (self , * lmps : InvocableLM ):
122
122
"""
123
123
A context manager for caching operations using a particular store.
124
124
@@ -138,6 +138,7 @@ def freeze(self, *lmps : InvocableLM):
138
138
finally :
139
139
# TODO: Implement cache storage logic here
140
140
for lmp in lmps :
141
- lmp .__ell_use_cache__ = old_cache_values .get (lmp , None )
142
-
143
-
141
+ if lmp in old_cache_values :
142
+ setattr (lmp , '__ell_use_cache__' , old_cache_values [lmp ])
143
+ else :
144
+ delattr (lmp , '__ell_use_cache__' )
0 commit comments