@@ -350,6 +350,7 @@ def __init__(
350
350
predictor_ref : str ,
351
351
events : Connection ,
352
352
tee_output : bool = True ,
353
+ is_async : bool = False ,
353
354
) -> None :
354
355
self ._predictor_ref = predictor_ref
355
356
self ._predictor : Optional [BasePredictor ] = None
@@ -361,6 +362,7 @@ def __init__(
361
362
362
363
# for synchronous predictors only! async predictors use _tag_var instead
363
364
self ._sync_tag : Optional [str ] = None
365
+ self ._is_async = is_async
364
366
365
367
super ().__init__ ()
366
368
@@ -373,30 +375,34 @@ def run(self) -> None:
373
375
# Initially, we ignore SIGUSR1.
374
376
signal .signal (signal .SIGUSR1 , signal .SIG_IGN )
375
377
376
- self ._predictor = self ._load_predictor ()
377
-
378
- # If setup didn't set the predictor, we're done here.
379
- if not self ._predictor :
380
- return
381
-
382
- predict = get_predict (self ._predictor )
383
- if inspect .iscoroutinefunction (predict ) or inspect .isasyncgenfunction (predict ):
384
- async_redirector = AsyncStreamRedirector (
378
+ if self ._is_async :
379
+ redirector = AsyncStreamRedirector (
385
380
callback = self ._stream_write_hook ,
386
381
tee = self ._tee_output ,
387
382
)
388
- with async_redirector :
389
- self ._setup (async_redirector )
390
- asyncio .run (self ._aloop (predict , async_redirector ))
391
383
else :
392
- # We use SIGUSR1 to signal an interrupt for cancelation.
393
- signal .signal (signal .SIGUSR1 , self ._signal_handler )
394
-
395
384
redirector = StreamRedirector (
396
385
callback = self ._stream_write_hook ,
397
386
tee = self ._tee_output ,
398
387
)
399
- with redirector :
388
+
389
+ with redirector :
390
+ self ._predictor = self ._load_predictor ()
391
+
392
+ # If setup didn't set the predictor, we're done here.
393
+ if not self ._predictor :
394
+ return
395
+
396
+ predict = get_predict (self ._predictor )
397
+ if self ._is_async :
398
+ assert isinstance (redirector , AsyncStreamRedirector )
399
+ self ._setup (redirector )
400
+ asyncio .run (self ._aloop (predict , redirector ))
401
+ else :
402
+ # We use SIGUSR1 to signal an interrupt for cancelation.
403
+ signal .signal (signal .SIGUSR1 , self ._signal_handler )
404
+
405
+ assert isinstance (redirector , StreamRedirector )
400
406
self ._setup (redirector )
401
407
self ._loop (
402
408
predict ,
@@ -706,10 +712,15 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None:
706
712
707
713
708
714
def make_worker (
709
- predictor_ref : str , tee_output : bool = True , max_concurrency : int = 1
715
+ predictor_ref : str ,
716
+ tee_output : bool = True ,
717
+ max_concurrency : int = 1 ,
718
+ is_async : bool = False ,
710
719
) -> Worker :
711
720
parent_conn , child_conn = _spawn .Pipe ()
712
- child = _ChildWorker (predictor_ref , events = child_conn , tee_output = tee_output )
721
+ child = _ChildWorker (
722
+ predictor_ref , events = child_conn , tee_output = tee_output , is_async = is_async
723
+ )
713
724
parent = Worker (child = child , events = parent_conn , max_concurrency = max_concurrency )
714
725
return parent
715
726
0 commit comments