Skip to content

Commit 758bf76

Browse files
committed
Inspect predictor to determine if predict/train is async
1 parent 1722d8c commit 758bf76

File tree

4 files changed

+51
-23
lines changed

4 files changed

+51
-23
lines changed

python/cog/config.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import os
23
import sys
34
import uuid
@@ -16,7 +17,9 @@
1617
from .predictor import (
1718
get_input_type,
1819
get_output_type,
20+
get_predict,
1921
get_predictor,
22+
get_train,
2023
get_training_input_type,
2124
get_training_output_type,
2225
load_full_predictor_from_file,
@@ -152,16 +155,26 @@ def get_predictor_ref(self, mode: Mode) -> str:
152155

153156
def get_predictor_types(
154157
self, mode: Mode
155-
) -> Tuple[Type[BaseInput], Type[BaseModel]]:
158+
) -> Tuple[Type[BaseInput], Type[BaseModel], bool]:
156159
"""Find the input and output types of a predictor."""
157160
predictor_ref = self.get_predictor_ref(mode=mode)
158161
predictor = self._load_predictor_for_types(
159162
predictor_ref, _method_name_from_mode(mode=mode), mode
160163
)
164+
165+
def is_async(fn) -> bool:
166+
return inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
167+
161168
if mode == Mode.PREDICT:
162-
return get_input_type(predictor), get_output_type(predictor)
169+
return (
170+
get_input_type(predictor),
171+
get_output_type(predictor),
172+
is_async(get_predict(predictor)),
173+
)
163174
elif mode == Mode.TRAIN:
164-
return get_training_input_type(predictor), get_training_output_type(
165-
predictor
175+
return (
176+
get_training_input_type(predictor),
177+
get_training_output_type(predictor),
178+
is_async(get_train(predictor)),
166179
)
167180
raise ValueError(f"Mode {mode} not found for generating input/output types.")

python/cog/server/http.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ async def start_shutdown() -> Any:
156156
return JSONResponse({}, status_code=200)
157157

158158
try:
159-
InputType, OutputType = cog_config.get_predictor_types(mode=Mode.PREDICT)
159+
InputType, OutputType, is_async = cog_config.get_predictor_types(
160+
mode=Mode.PREDICT
161+
)
160162
except Exception: # pylint: disable=broad-exception-caught
161163
msg = "Error while loading predictor:\n\n" + traceback.format_exc()
162164
add_setup_failed_routes(app, started_at, msg)

python/cog/server/worker.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def __init__(
350350
predictor_ref: str,
351351
events: Connection,
352352
tee_output: bool = True,
353+
is_async: bool = False,
353354
) -> None:
354355
self._predictor_ref = predictor_ref
355356
self._predictor: Optional[BasePredictor] = None
@@ -361,6 +362,7 @@ def __init__(
361362

362363
# for synchronous predictors only! async predictors use _tag_var instead
363364
self._sync_tag: Optional[str] = None
365+
self._is_async = is_async
364366

365367
super().__init__()
366368

@@ -373,30 +375,34 @@ def run(self) -> None:
373375
# Initially, we ignore SIGUSR1.
374376
signal.signal(signal.SIGUSR1, signal.SIG_IGN)
375377

376-
self._predictor = self._load_predictor()
377-
378-
# If setup didn't set the predictor, we're done here.
379-
if not self._predictor:
380-
return
381-
382-
predict = get_predict(self._predictor)
383-
if inspect.iscoroutinefunction(predict) or inspect.isasyncgenfunction(predict):
384-
async_redirector = AsyncStreamRedirector(
378+
if self._is_async:
379+
redirector = AsyncStreamRedirector(
385380
callback=self._stream_write_hook,
386381
tee=self._tee_output,
387382
)
388-
with async_redirector:
389-
self._setup(async_redirector)
390-
asyncio.run(self._aloop(predict, async_redirector))
391383
else:
392-
# We use SIGUSR1 to signal an interrupt for cancelation.
393-
signal.signal(signal.SIGUSR1, self._signal_handler)
394-
395384
redirector = StreamRedirector(
396385
callback=self._stream_write_hook,
397386
tee=self._tee_output,
398387
)
399-
with redirector:
388+
389+
with redirector:
390+
self._predictor = self._load_predictor()
391+
392+
# If setup didn't set the predictor, we're done here.
393+
if not self._predictor:
394+
return
395+
396+
predict = get_predict(self._predictor)
397+
if self._is_async:
398+
assert isinstance(redirector, AsyncStreamRedirector)
399+
self._setup(redirector)
400+
asyncio.run(self._aloop(predict, redirector))
401+
else:
402+
# We use SIGUSR1 to signal an interrupt for cancelation.
403+
signal.signal(signal.SIGUSR1, self._signal_handler)
404+
405+
assert isinstance(redirector, StreamRedirector)
400406
self._setup(redirector)
401407
self._loop(
402408
predict,
@@ -706,10 +712,15 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None:
706712

707713

708714
def make_worker(
709-
predictor_ref: str, tee_output: bool = True, max_concurrency: int = 1
715+
predictor_ref: str,
716+
tee_output: bool = True,
717+
max_concurrency: int = 1,
718+
is_async: bool = False,
710719
) -> Worker:
711720
parent_conn, child_conn = _spawn.Pipe()
712-
child = _ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output)
721+
child = _ChildWorker(
722+
predictor_ref, events=child_conn, tee_output=tee_output, is_async=is_async
723+
)
713724
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
714725
return parent
715726

python/tests/server/fixtures/logging.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
# a bunch of stuff at import time.
1010
libc.puts(b"writing some stuff from C at import time")
1111
libc.fflush(None)
12+
1213
sys.stdout.write("writing to stdout at import time\n")
1314
sys.stderr.write("writing to stderr at import time\n")
1415

1516

1617
class Predictor:
1718
def setup(self):
1819
print("setting up predictor")
20+
print("writing to stderr at setup time", file=sys.stderr)
1921
self.foo = "foo"
2022

2123
def predict(self) -> str:

0 commit comments

Comments
 (0)