Skip to content

Commit b4e9124

Browse files
chore: Fix LoopConnectorError by enforcing background loop execution for Async Classes tests (#182)
* chore: Fix LoopConnectorError by enforcing background loop execution for Async Classes tests * linter fix * fix tests * run .create on background loop * fix tests * fix linter --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
1 parent 84c6a32 commit b4e9124

8 files changed

Lines changed: 670 additions & 340 deletions

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cloud-sql-python-connector[asyncpg]==1.18.4
1+
cloud-sql-python-connector[asyncpg]==1.18.5
22
llama-index-core==0.14.4
33
pgvector==0.4.1
44
SQLAlchemy[asyncio]==2.0.43

tests/test_async_chat_store.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import asyncio
1515
import os
1616
import uuid
17-
from typing import Sequence
17+
from typing import Any, Coroutine, Sequence
1818

1919
import pytest
2020
import pytest_asyncio
@@ -28,18 +28,35 @@
2828
sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
2929

3030

31+
# Helper to bridge the Main Test Loop and the Engine Background Loop
32+
async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any:
33+
"""Runs a coroutine on the engine's background loop."""
34+
if engine._loop:
35+
return await asyncio.wrap_future(
36+
asyncio.run_coroutine_threadsafe(coro, engine._loop)
37+
)
38+
return await coro
39+
40+
3141
async def aexecute(engine: PostgresEngine, query: str) -> None:
32-
async with engine._pool.connect() as conn:
33-
await conn.execute(text(query))
34-
await conn.commit()
42+
async def _impl():
43+
async with engine._pool.connect() as conn:
44+
await conn.execute(text(query))
45+
await conn.commit()
46+
47+
await run_on_background(engine, _impl())
3548

3649

3750
async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
38-
async with engine._pool.connect() as conn:
39-
result = await conn.execute(text(query))
40-
result_map = result.mappings()
41-
result_fetch = result_map.fetchall()
42-
return result_fetch
51+
async def _impl():
52+
async with engine._pool.connect() as conn:
53+
result = await conn.execute(text(query))
54+
result_map = result.mappings()
55+
result_fetch = result_map.fetchall()
56+
return result_fetch
57+
58+
result = await run_on_background(engine, _impl())
59+
return result
4360

4461

4562
def get_env_var(key: str, desc: str) -> str:
@@ -96,10 +113,15 @@ async def async_engine(
96113

97114
@pytest_asyncio.fixture(scope="class")
98115
async def chat_store(self, async_engine):
99-
await async_engine._ainit_chat_store_table(table_name=default_table_name_async)
100-
101-
chat_store = await AsyncPostgresChatStore.create(
102-
engine=async_engine, table_name=default_table_name_async
116+
await run_on_background(
117+
async_engine,
118+
async_engine._ainit_chat_store_table(table_name=default_table_name_async),
119+
)
120+
chat_store = await run_on_background(
121+
async_engine,
122+
AsyncPostgresChatStore.create(
123+
engine=async_engine, table_name=default_table_name_async
124+
),
103125
)
104126

105127
yield chat_store
@@ -117,21 +139,23 @@ async def test_async_add_message(self, async_engine, chat_store):
117139
key = "test_add_key"
118140

119141
message = ChatMessage(content="add_message_test", role="user")
120-
await chat_store.async_add_message(key, message=message)
142+
await run_on_background(
143+
async_engine, chat_store.async_add_message(key, message=message)
144+
)
121145

122146
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
123147
results = await afetch(async_engine, query)
124148
result = results[0]
125149
assert result["message"] == message.model_dump()
126150

127-
async def test_aset_and_aget_messages(self, chat_store):
151+
async def test_aset_and_aget_messages(self, async_engine, chat_store):
128152
message_1 = ChatMessage(content="First message", role="user")
129153
message_2 = ChatMessage(content="Second message", role="user")
130154
messages = [message_1, message_2]
131155
key = "test_set_and_get_key"
132-
await chat_store.aset_messages(key, messages)
156+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
133157

134-
results = await chat_store.aget_messages(key)
158+
results = await run_on_background(async_engine, chat_store.aget_messages(key))
135159

136160
assert len(results) == 2
137161
assert results[0].content == message_1.content
@@ -140,9 +164,9 @@ async def test_aset_and_aget_messages(self, chat_store):
140164
async def test_adelete_messages(self, async_engine, chat_store):
141165
messages = [ChatMessage(content="Message to delete", role="user")]
142166
key = "test_delete_key"
143-
await chat_store.aset_messages(key, messages)
167+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
144168

145-
await chat_store.adelete_messages(key)
169+
await run_on_background(async_engine, chat_store.adelete_messages(key))
146170
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
147171
results = await afetch(async_engine, query)
148172

@@ -153,9 +177,9 @@ async def test_adelete_message(self, async_engine, chat_store):
153177
message_2 = ChatMessage(content="Delete me", role="user")
154178
messages = [message_1, message_2]
155179
key = "test_delete_message_key"
156-
await chat_store.aset_messages(key, messages)
180+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
157181

158-
await chat_store.adelete_message(key, 1)
182+
await run_on_background(async_engine, chat_store.adelete_message(key, 1))
159183
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
160184
results = await afetch(async_engine, query)
161185

@@ -168,9 +192,9 @@ async def test_adelete_last_message(self, async_engine, chat_store):
168192
message_3 = ChatMessage(content="Message 3", role="user")
169193
messages = [message_1, message_2, message_3]
170194
key = "test_delete_last_message_key"
171-
await chat_store.aset_messages(key, messages)
195+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
172196

173-
await chat_store.adelete_last_message(key)
197+
await run_on_background(async_engine, chat_store.adelete_last_message(key))
174198
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
175199
results = await afetch(async_engine, query)
176200

@@ -183,18 +207,22 @@ async def test_aget_keys(self, async_engine, chat_store):
183207
message_2 = [ChatMessage(content="Second message", role="user")]
184208
key_1 = "key1"
185209
key_2 = "key2"
186-
await chat_store.aset_messages(key_1, message_1)
187-
await chat_store.aset_messages(key_2, message_2)
210+
await run_on_background(
211+
async_engine, chat_store.aset_messages(key_1, message_1)
212+
)
213+
await run_on_background(
214+
async_engine, chat_store.aset_messages(key_2, message_2)
215+
)
188216

189-
keys = await chat_store.aget_keys()
217+
keys = await run_on_background(async_engine, chat_store.aget_keys())
190218

191219
assert key_1 in keys
192220
assert key_2 in keys
193221

194222
async def test_set_exisiting_key(self, async_engine, chat_store):
195223
message_1 = [ChatMessage(content="First message", role="user")]
196224
key = "test_set_exisiting_key"
197-
await chat_store.aset_messages(key, message_1)
225+
await run_on_background(async_engine, chat_store.aset_messages(key, message_1))
198226

199227
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
200228
results = await afetch(async_engine, query)
@@ -207,7 +235,7 @@ async def test_set_exisiting_key(self, async_engine, chat_store):
207235
message_3 = ChatMessage(content="Third message", role="user")
208236
messages = [message_2, message_3]
209237

210-
await chat_store.aset_messages(key, messages)
238+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
211239

212240
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
213241
results = await afetch(async_engine, query)

0 commit comments

Comments
 (0)