Skip to content

Commit

Permalink
Lift scope up to top level run
Browse files Browse the repository at this point in the history
  • Loading branch information
aron committed Nov 29, 2024
1 parent 23954bc commit f69d6dd
Showing 1 changed file with 25 additions and 30 deletions.
55 changes: 25 additions & 30 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def run(self) -> None:
tee=self._tee_output,
)

with redirector:
with scope(Scope(record_metric=self.record_metric)), redirector:
self._predictor = self._load_predictor()

# If setup didn't set the predictor, we're done here.
Expand Down Expand Up @@ -488,17 +488,16 @@ def _loop(
predict: Callable[..., Any],
redirector: StreamRedirector,
) -> None:
with scope(self._loop_scope()):
while True:
e = cast(Envelope, self._events.recv())
if isinstance(e.event, Cancel):
continue # Ignored in sync predictors.
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
self._predict(e.tag, e.event.payload, predict, redirector)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
while True:
e = cast(Envelope, self._events.recv())
if isinstance(e.event, Cancel):
continue # Ignored in sync predictors.
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
self._predict(e.tag, e.event.payload, predict, redirector)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)

async def _aloop(
self,
Expand All @@ -511,24 +510,20 @@ async def _aloop(

task = None

with scope(self._loop_scope()):
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = asyncio.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
if task:
await task

def _loop_scope(self) -> Scope:
return Scope(record_metric=self.record_metric)
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = asyncio.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
if task:
await task

def _predict(
self,
Expand Down

0 comments on commit f69d6dd

Please sign in to comment.