Skip to content

Commit

Permalink
Add unit tests for async setup
Browse files Browse the repository at this point in the history
  • Loading branch information
Aron Carroll authored and aron committed Dec 18, 2024
1 parent cdc082b commit e3e2ce8
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 5 deletions.
3 changes: 0 additions & 3 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,6 @@ def _validate_predictor(
with self._handle_setup_error(redirector):
assert self._predictor

predict = get_predict(self._predictor)


# Async models require python >= 3.11 so we can use asyncio.TaskGroup
# We should check for this before getting to this point
if self._has_async_predictor and sys.version_info < (3, 11):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio


class Predictor:
async def setup(self) -> None:
self.loop = asyncio.get_running_loop()
Expand Down
12 changes: 12 additions & 0 deletions python/tests/server/fixtures/setup_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class Predictor:
async def download(self) -> None:
print("download complete!")

async def setup(self) -> None:
print("setup starting...")
await self.download()
print("setup complete!")

async def predict(self) -> str:
print("running prediction")
return "output"
9 changes: 9 additions & 0 deletions python/tests/server/fixtures/setup_async_with_sync_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Predictor:
async def download(self) -> None:
print("setup used asyncio.run! it's not very effective...")

async def setup(self) -> None:
await self.download()

def predict(self) -> str:
return "output"
59 changes: 58 additions & 1 deletion python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker):
@pytest.mark.skipif(
sys.version_info >= (3, 11), reason="Testing error message on python versions <3.11"
)
@uses_worker("simple_async", setup=False)
@uses_worker("simple_async", setup=False, is_async=True)
def test_async_predictor_on_python_3_10_or_older_raises_error(worker):
fut = worker.setup()
result = Result()
Expand All @@ -351,6 +351,57 @@ def test_async_predictor_on_python_3_10_or_older_raises_error(worker):
)


@uses_worker(
"setup_async", max_concurrency=1, min_python=(3, 11), is_async=True, setup=False
)
def test_setup_async(worker: Worker):
fut = worker.setup()
setup_result = Result()
setup_sid = worker.subscribe(setup_result.handle_event)

# with pytest.raises(FatalWorkerException):
fut.result()
worker.unsubscribe(setup_sid)

assert setup_result.stdout_lines == [
"setup starting...\n",
"download complete!\n",
"setup complete!\n",
]

predict_result = Result()
predict_sid = worker.subscribe(predict_result.handle_event, tag="p1")
worker.predict({}, tag="p1").result()

assert predict_result.done
assert predict_result.output == "output"
assert predict_result.stdout_lines == ["running prediction\n"]

worker.unsubscribe(predict_sid)


@uses_worker(
"setup_async_with_sync_predict",
max_concurrency=1,
min_python=(3, 11),
is_async=False,
setup=False,
)
def test_setup_async_with_sync_predict_raises_error(worker: Worker):
fut = worker.setup()
result = Result()
worker.subscribe(result.handle_event)

with pytest.raises(FatalWorkerException):
fut.result()
assert result.done
assert result.done.error
assert (
result.done.error_detail
== "Invalid predictor: to use an async setup method you must use an async predict method"
)


@uses_worker("simple", max_concurrency=5, setup=False)
def test_concurrency_with_sync_predictor_raises_error(worker):
fut = worker.setup()
Expand Down Expand Up @@ -555,6 +606,12 @@ def test_graceful_shutdown(worker):
assert fut.result() == Done()


@uses_worker("async_setup_uses_same_loop_as_predict", min_python=(3, 11), is_async=True)
def test_async_setup_uses_same_loop_as_predict(worker: Worker):
result = _process(worker, lambda: worker.predict({}), tag=None)
assert result, "Expected worker to return True to assert same event loop"


@frozen
class SetupState:
fut: "Future[Done]"
Expand Down

0 comments on commit e3e2ce8

Please sign in to comment.