Skip to content

Commit

Permalink
Add @use_worker_configs and update test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
aron committed Nov 29, 2024
1 parent f69d6dd commit bb9bbcc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 25 deletions.
4 changes: 3 additions & 1 deletion python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ async def start_shutdown() -> Any:
add_setup_failed_routes(app, started_at, msg)
return app

worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode))
worker = make_worker(
predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async
)
runner = PredictionRunner(worker=worker)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
Expand Down
6 changes: 4 additions & 2 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,10 @@ class _ChildWorker(_spawn.Process): # type: ignore
def __init__(
self,
predictor_ref: str,
*,
is_async: bool,
events: Connection,
tee_output: bool = True,
is_async: bool = False,
) -> None:
self._predictor_ref = predictor_ref
self._predictor: Optional[BasePredictor] = None
Expand Down Expand Up @@ -708,9 +709,10 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None:

def make_worker(
predictor_ref: str,
*,
is_async: bool,
tee_output: bool = True,
max_concurrency: int = 1,
is_async: bool = False,
) -> Worker:
parent_conn, child_conn = _spawn.Pipe()
child = _ChildWorker(
Expand Down
27 changes: 18 additions & 9 deletions python/tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
import time
from contextlib import ExitStack
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence
from unittest import mock

import pytest
Expand All @@ -24,6 +24,7 @@ class AppConfig:
@define
class WorkerConfig:
fixture_name: str
is_async: bool = False
setup: bool = True
max_concurrency: int = 1

Expand Down Expand Up @@ -71,7 +72,7 @@ def uses_predictor_with_client_options(name, **options):
)


def uses_worker(name_or_names, setup=True, max_concurrency=1):
def uses_worker(name_or_names, setup=True, max_concurrency=1, is_async=False):
    """
    Decorator for tests that require a Worker instance. `name_or_names` can be
    a single fixture name, or a sequence (list, tuple) of fixture names. If
    a sequence is given, the test is parametrized and runs once per name.

    If `setup` is True (the default) setup will be run before the test runs.
    `max_concurrency` and `is_async` are forwarded to each WorkerConfig.
    """
    # NOTE(review): the scraped diff interleaved old and new hunk lines here
    # (duplicate `values = (` / `values = [` assignments and a bare `is_async`
    # positional argument after keyword arguments, which is a SyntaxError).
    # Reconstructed below as the post-commit version: build WorkerConfig
    # instances and delegate to uses_worker_configs.
    if isinstance(name_or_names, (tuple, list)):
        values = [
            WorkerConfig(
                fixture_name=n,
                setup=setup,
                max_concurrency=max_concurrency,
                is_async=is_async,
            )
            for n in name_or_names
        ]
    else:
        values = [
            WorkerConfig(
                fixture_name=name_or_names,
                setup=setup,
                max_concurrency=max_concurrency,
                is_async=is_async,
            )
        ]
    return uses_worker_configs(values)


def uses_worker_configs(values: Sequence[WorkerConfig]):
    """
    Decorator for tests that require a Worker instance.

    `values` is a sequence of `WorkerConfig` instances; the decorated test is
    parametrized via the indirect `worker` fixture, running once per config.
    """
    # indirect=True routes each WorkerConfig through the `worker` fixture,
    # which constructs the actual Worker from request.param.
    return pytest.mark.parametrize("worker", values, indirect=True)

def make_client(
fixture_name: str,
Expand Down Expand Up @@ -153,6 +161,7 @@ def worker(request):
ref = _fixture_path(request.param.fixture_name)
w = make_worker(
predictor_ref=ref,
is_async=request.param.is_async,
tee_output=False,
max_concurrency=request.param.max_concurrency,
)
Expand Down
4 changes: 2 additions & 2 deletions python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_prediction_runner_predict_cancelation_multiple_predictions():


def test_prediction_runner_setup_e2e():
w = make_worker(predictor_ref=_fixture_path("sleep"))
w = make_worker(predictor_ref=_fixture_path("sleep"), is_async=False)
r = PredictionRunner(worker=w)

try:
Expand All @@ -316,7 +316,7 @@ def test_prediction_runner_setup_e2e():


def test_prediction_runner_predict_e2e():
w = make_worker(predictor_ref=_fixture_path("sleep"))
w = make_worker(predictor_ref=_fixture_path("sleep"), is_async=False)
r = PredictionRunner(worker=w)

try:
Expand Down
38 changes: 27 additions & 11 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from cog.server.exceptions import FatalWorkerException, InvalidStateException
from cog.server.worker import Worker, _PublicEventType

from .conftest import WorkerConfig, uses_worker
from .conftest import WorkerConfig, uses_worker, uses_worker_configs

if TYPE_CHECKING:
from concurrent.futures import Future
Expand Down Expand Up @@ -76,7 +76,7 @@
},
),
(
WorkerConfig("record_metric_async"),
WorkerConfig("record_metric_async", is_async=True),
{"name": ST_NAMES},
{
"foo": 123,
Expand All @@ -91,7 +91,7 @@
lambda x: f"hello, {x['name']}",
),
(
WorkerConfig("hello_world_async"),
WorkerConfig("hello_world_async", is_async=True),
{"name": ST_NAMES},
lambda x: f"hello, {x['name']}",
),
Expand All @@ -114,7 +114,7 @@
"writing to stderr at import time\n",
),
(
WorkerConfig("logging_async", setup=False),
WorkerConfig("logging_async", setup=False, is_async=True),
("writing to stdout at import time\n" "setting up predictor\n"),
"writing to stderr at import time\n",
),
Expand All @@ -127,7 +127,7 @@
("WARNING:root:writing log message\n" "writing to stderr\n"),
),
(
WorkerConfig("logging_async"),
WorkerConfig("logging_async", is_async=True),
("writing with print\n"),
("WARNING:root:writing log message\n" "writing to stderr\n"),
),
Expand Down Expand Up @@ -238,7 +238,9 @@ def test_no_exceptions_from_recoverable_failures(worker):


# TODO test this works with errors and cancelations and the like
@uses_worker(["simple", "simple_async"])
@uses_worker_configs(
[WorkerConfig("simple"), WorkerConfig("simple_async", is_async=True)]
)
def test_can_subscribe_for_a_specific_tag(worker):
tag = "123"

Expand All @@ -260,7 +262,7 @@ def test_can_subscribe_for_a_specific_tag(worker):
worker.unsubscribe(subid)


@uses_worker("sleep_async", max_concurrency=5)
@uses_worker("sleep_async", max_concurrency=5, is_async=True)
def test_can_run_predictions_concurrently_on_async_predictor(worker):
subids = []

Expand Down Expand Up @@ -383,7 +385,12 @@ def test_predict_logging(worker, expected_stdout, expected_stderr):
assert result.stderr == expected_stderr


@uses_worker(["sleep", "sleep_async"], setup=False)
@uses_worker_configs(
[
WorkerConfig("sleep", setup=False),
WorkerConfig("sleep_async", setup=False, is_async=True),
]
)
def test_cancel_is_safe(worker):
"""
Calls to cancel at any time should not result in unexpected things
Expand Down Expand Up @@ -417,7 +424,12 @@ def test_cancel_is_safe(worker):
assert result2.output == "done in 0.1 seconds"


@uses_worker(["sleep", "sleep_async"], setup=False)
@uses_worker_configs(
[
WorkerConfig("sleep", setup=False),
WorkerConfig("sleep_async", setup=False, is_async=True),
]
)
def test_cancel_idempotency(worker):
"""
Multiple calls to cancel within the same prediction, while not necessary or
Expand Down Expand Up @@ -449,7 +461,9 @@ def cancel_a_bunch(_):
assert result2.output == "done in 0.1 seconds"


@uses_worker(["sleep", "sleep_async"])
@uses_worker_configs(
[WorkerConfig("sleep"), WorkerConfig("sleep_async", is_async=True)]
)
def test_cancel_multiple_predictions(worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
Expand All @@ -467,7 +481,9 @@ def test_cancel_multiple_predictions(worker):
assert not worker.predict({"sleep": 0}).result().canceled


@uses_worker(["sleep", "sleep_async"])
@uses_worker_configs(
[WorkerConfig("sleep"), WorkerConfig("sleep_async", is_async=True)]
)
def test_graceful_shutdown(worker):
"""
On shutdown, the worker should finish running the current prediction, and
Expand Down

0 comments on commit bb9bbcc

Please sign in to comment.