Skip to content

Commit e6ce85f

Browse files
committed
stores
1 parent 6156aec commit e6ce85f

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

src/ell/store.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from contextlib import contextmanager
3-
from typing import Any, Optional, Dict, List, Set
3+
from typing import Any, Optional, Dict, List, Set, Union
44
from ell.lstr import lstr
55
from ell.types import InvocableLM
66

@@ -32,7 +32,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
3232
pass
3333

3434
@abstractmethod
35-
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: lstr | List[lstr], invocation_kwargs: Dict[str, Any],
35+
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
3636
created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
3737
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
3838
state_cache_key: Optional[str] = None,
@@ -118,7 +118,7 @@ def get_latest_lmps(self) -> List[Dict[str, Any]]:
118118

119119

120120
@contextmanager
121-
def freeze(self, *lmps : InvocableLM):
121+
def freeze(self, *lmps: InvocableLM):
122122
"""
123123
A context manager for caching operations using a particular store.
124124
@@ -138,6 +138,7 @@ def freeze(self, *lmps : InvocableLM):
138138
finally:
139139
# TODO: Implement cache storage logic here
140140
for lmp in lmps:
141-
lmp.__ell_use_cache__ = old_cache_values.get(lmp, None)
142-
143-
141+
if lmp in old_cache_values:
142+
setattr(lmp, '__ell_use_cache__', old_cache_values[lmp])
143+
else:
144+
delattr(lmp, '__ell_use_cache__')

src/ell/stores/sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import json
33
import os
4-
from typing import Any, Optional, Dict, List, Set
4+
from typing import Any, Optional, Dict, List, Set, Union
55
from sqlmodel import Session, SQLModel, create_engine, select
66
import ell.store
77
import cattrs
@@ -57,7 +57,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
5757
session.commit()
5858
return None
5959

60-
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: lstr | List[lstr], invocation_kwargs: Dict[str, Any],
60+
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
6161
global_vars: Dict[str, Any],
6262
free_vars: Dict[str, Any], created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
6363
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,

0 commit comments

Comments
 (0)