Skip to content

Commit 4f65459

Browse files
committed
Ensure timezone is set when deserializing timestamps
This is needed to ensure we get a UTC datetime when reading from SQLite or other engines that don't support storing timestamps with a timezone.
1 parent f11149b commit 4f65459

File tree

2 files changed

+113
-4
lines changed

2 files changed

+113
-4
lines changed

src/ell/types.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from datetime import datetime, timezone
1010
from typing import Any, List, Optional
11-
from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float
11+
from sqlmodel import Field, SQLModel, Relationship, JSON, Column
12+
from sqlalchemy import TIMESTAMP, func
13+
import sqlalchemy.types as types
1214

1315
_lstr_generic = Union[lstr, str]
1416

@@ -42,6 +44,10 @@ class Message(dict, metaclass=DictSyncMeta):
4244
LMP = Union[OneTurn, MultiTurnLMP, ChatLMP]
4345
InvocableLM = Callable[..., _lstr_generic]
4446

47+
from datetime import timezone
48+
from sqlmodel import Field
49+
from typing import Optional
50+
4551

4652
def utc_now() -> datetime:
4753
"""
@@ -62,6 +68,16 @@ class SerializedLMPUses(SQLModel, table=True):
6268
lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP
6369

6470

71+
class UTCTimestamp(types.TypeDecorator[datetime]):
    """TIMESTAMP column type whose loaded values are always UTC-aware datetimes.

    Some engines (SQLite among them) cannot store a timezone alongside a
    timestamp and hand back naive ``datetime`` objects. This decorator
    normalizes every value read from the database so callers always receive
    a timezone-aware datetime expressed in UTC.
    """

    impl = types.TIMESTAMP
    # The decorator adds no constructor state of its own, so instances are
    # safe to use as part of SQLAlchemy's statement cache key.
    cache_ok = True

    def process_result_value(
        self, value: Optional[datetime], dialect: Any
    ) -> Optional[datetime]:
        """Return ``value`` as a UTC-aware datetime.

        Args:
            value: The datetime produced by the underlying TIMESTAMP type,
                or ``None`` when the column is NULL.
            dialect: The SQLAlchemy dialect in use (unused).

        Returns:
            ``None`` for NULL columns; otherwise a timezone-aware datetime
            in UTC.
        """
        if value is None:
            # NULL columns must pass through untouched instead of raising
            # AttributeError on ``None.replace``.
            return None
        if value.tzinfo is None:
            # Naive value: the stored instant is taken to be UTC, so only
            # the label needs attaching.
            return value.replace(tzinfo=timezone.utc)
        # Already-aware value (tz-capable backend): convert rather than
        # relabel so the instant in time is preserved.
        return value.astimezone(timezone.utc)
75+
76+
def UTCTimestampField(index: bool = False, **kwargs: Any):
    """Build a SQLModel ``Field`` backed by a ``UTCTimestamp`` column.

    Args:
        index: Whether the underlying column should be indexed.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            SQLAlchemy ``Column`` constructor (e.g. ``default``, ``nullable``).

    Returns:
        The value returned by ``Field`` for the configured column.
    """
    timestamp_type = UTCTimestamp(timezone=True)
    timestamp_column = Column(timestamp_type, index=index, **kwargs)
    return Field(sa_column=timestamp_column)
79+
80+
6581

6682
class SerializedLMP(SQLModel, table=True):
6783
"""
@@ -73,7 +89,12 @@ class SerializedLMP(SQLModel, table=True):
7389
name: str = Field(index=True) # Name of the LMP
7490
source: str # Source code or reference for the LMP
7591
dependencies: str # List of dependencies for the LMP, stored as a string
76-
created_at: datetime = Field(default_factory=utc_now, index=True) # Timestamp of when the LMP was created
92+
# Timestamp of when the LMP was created
93+
created_at: datetime = UTCTimestampField(
94+
index=True,
95+
default=func.now(),
96+
nullable=False
97+
)
7798
is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP
7899
lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP
79100

@@ -139,8 +160,8 @@ class Invocation(SQLModel, table=True):
139160
completion_tokens: Optional[int] = Field(default=None)
140161
state_cache_key: Optional[str] = Field(default=None)
141162

142-
143-
created_at: datetime = Field(default_factory=utc_now) # Timestamp of when the invocation was created
163+
# Timestamp of when the invocation was created
164+
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
144165
invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation
145166

146167
# Relationships

tests/test_sql_store.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import pytest
2+
from datetime import datetime, timezone
3+
from sqlmodel import Session, select
4+
from ell.stores.sql import SQLStore, SerializedLMP
5+
from sqlalchemy import Engine, create_engine
6+
7+
from ell.types import utc_now
8+
9+
@pytest.fixture
def in_memory_db():
    """Provide a SQLAlchemy engine pointing at an in-memory SQLite database."""
    engine = create_engine("sqlite:///:memory:")
    return engine
12+
13+
@pytest.fixture
def sql_store(in_memory_db: Engine) -> SQLStore:
    """Provide a ``SQLStore`` wired to the ``in_memory_db`` engine.

    The store is constructed normally, then its engine is replaced with the
    fixture engine and the table metadata is created on that same engine.
    """
    sqlite_store = SQLStore("sqlite:///:memory:")
    # Point the store at the fixture engine, which is also where the schema
    # is created just below.
    sqlite_store.engine = in_memory_db
    SerializedLMP.metadata.create_all(bind=in_memory_db)
    return sqlite_store
19+
20+
def test_write_lmp(sql_store: SQLStore):
    """Write an LMP, read it back, and check the stored row.

    Verifies that every written field round-trips, that ``created_at`` comes
    back timezone-aware *and in UTC*, and that writing the same ``lmp_id``
    twice leaves exactly one row.
    """
    # Arrange
    lmp_id = "test_lmp_1"
    created_at = utc_now()
    assert created_at.tzinfo is not None

    # One dict feeds both write_lmp calls, so the second write is guaranteed
    # to be identical to the first.
    lmp_fields = dict(
        lmp_id=lmp_id,
        name="Test LMP",
        source="def test_function(): pass",
        dependencies=str(["dep1", "dep2"]),
        is_lmp=True,
        lm_kwargs='{"param1": "value1"}',
        version_number=1,
        uses={"used_lmp_1": {}, "used_lmp_2": {}},
        global_vars={"global_var1": "value1"},
        free_vars={"free_var1": "value2"},
        commit_message="Initial commit",
        created_at=created_at,
    )

    # Act
    sql_store.write_lmp(**lmp_fields)

    # Assert
    with Session(sql_store.engine) as session:
        result = session.exec(
            select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)
        ).first()

        assert result is not None
        assert result.lmp_id == lmp_id
        assert result.name == lmp_fields["name"]
        assert result.source == lmp_fields["source"]
        assert result.dependencies == lmp_fields["dependencies"]
        assert result.is_lm == lmp_fields["is_lmp"]
        assert result.lm_kwargs == lmp_fields["lm_kwargs"]
        assert result.version_number == lmp_fields["version_number"]
        assert result.initial_global_vars == lmp_fields["global_vars"]
        assert result.initial_free_vars == lmp_fields["free_vars"]
        assert result.commit_message == lmp_fields["commit_message"]
        # created_at must carry timezone information ...
        assert result.created_at.tzinfo is not None
        # ... and that timezone must be UTC, not merely "some" timezone.
        assert result.created_at.utcoffset().total_seconds() == 0

    # Test that writing the same LMP again doesn't create a duplicate
    sql_store.write_lmp(**lmp_fields)

    with Session(sql_store.engine) as session:
        matching_rows = session.exec(
            select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)
        ).all()
        assert len(matching_rows) == 1

0 commit comments

Comments
 (0)