Skip to content

Commit

Permalink
Download inputs off the event consumer loop
Browse files Browse the repository at this point in the history
Currently, we download inputs (via URLPath's convert() method) in
_prepare_payload, in the main Worker event consumer loop.  This is bad, because
downloading inputs is a blocking operation and we cannot process other events
while downloads are in progress.

This is a substantial restructuring:

- the event consumer loop now only processes events from the child; prediction
  requests and cancel requests from the http server are now processed directly.
- there are new concurrent.future.ThreadPoolExecutor instances for preparing
  predictions and downloading inputs.
  - as part of this, this commit also supports downloading inputs concurrently
    rather than serially.  I have hardcoded a maximum of 8 concurrent downloads
    based on the ThreadPoolExecutor size.  I picked this number arbitrarily.
- when the web server receives a request to start a prediction, we submit a task
  on the prediction start executor, which in turn submits downloads on the input
  download executor. Once all inputs are downloaded, we send the prediction
  request to the child.
- cancel requests are handled similarly, but as there is no input download this
  is much simpler.
- all calls to self._events.send() now happen within the http server async event
  loop, not the Worker ThreadPoolExecutor.  (This is maybe not strictly correct,
  as we're calling blocking i/o from an async event loop, but it's a local
  Connection object and it's a tiny amount of data)
  • Loading branch information
philandstuff committed Dec 20, 2024
1 parent 392819d commit f2ba468
Showing 1 changed file with 69 additions and 62 deletions.
131 changes: 69 additions & 62 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import inspect
import multiprocessing
import os
import select
import signal
import sys
import threading
Expand All @@ -13,6 +12,7 @@
import uuid
import warnings
import weakref
from concurrent import futures
from concurrent.futures import Future, ThreadPoolExecutor
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
Expand Down Expand Up @@ -126,18 +126,16 @@ def __init__(
self._predictions_lock = threading.Lock()
self._predictions_in_flight: Dict[Optional[str], PredictionState] = {}

recv_conn, send_conn = _spawn.Pipe(duplex=False)
self._request_send_conn = send_conn
self._request_recv_conn = recv_conn

self._pool = ThreadPoolExecutor(max_workers=1)
self._event_consumer_pool = ThreadPoolExecutor(max_workers=1)
self._prediction_start_pool = ThreadPoolExecutor(max_workers=max_concurrency)
self._input_download_pool = ThreadPoolExecutor(max_workers=8)
self._event_consumer = None

def setup(self) -> "Future[Done]":
self._assert_state(WorkerState.NEW)
self._state = WorkerState.STARTING
self._child.start()
self._event_consumer = self._pool.submit(self._consume_events)
self._event_consumer = self._event_consumer_pool.submit(self._consume_events)
return self._setup_result

def predict(
Expand All @@ -163,9 +161,57 @@ def predict(
self._assert_state(WorkerState.READY)
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
self._request_send_conn.send(PredictionRequest(tag))

self._prediction_start_pool.submit(self._start_prediction(tag, payload))
return result

def _start_prediction(
self, tag: Optional[str], payload: Dict[str, Any]
) -> Callable[[], None]:
def start_prediction() -> None:
try:
to_await = []
futs = {}
# Prepare payload asynchronously (download URLPath objects)
for k, v in payload.items():
# Check if v is an instance of URLPath
if isinstance(v, URLPath):
futs[k] = self._input_download_pool.submit(v.convert)
to_await.append(futs[k])
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(
isinstance(item, URLPath) for item in v
):
futs[k] = [
self._input_download_pool.submit(item.convert) for item in v
]
to_await += futs[k]
futures.wait(to_await, return_when=futures.FIRST_EXCEPTION)
for k, v in futs.items():
if isinstance(v, list):
payload[k] = []
for fut in v:
# the future may not be done if and only if another
# future finished with an exception
if fut.done():
payload[k].append(fut.result())
elif isinstance(v, Future):
if v.done():
payload[k] = v.result()
# send the prediction to the child to start
self._events.send(
Envelope(
event=PredictionInput(payload=payload),
tag=tag,
)
)
except Exception as e:
done = Done(error=True, error_detail=str(e))
self._publish(Envelope(done, tag))
self._complete_prediction(done, tag)

return start_prediction

def subscribe(
self,
subscriber: Callable[[_PublicEventType], None],
Expand Down Expand Up @@ -195,7 +241,7 @@ def shutdown(self, timeout: Optional[float] = None) -> None:
if self._event_consumer:
self._event_consumer.result(timeout=timeout)

self._pool.shutdown()
self._event_consumer_pool.shutdown()

def terminate(self) -> None:
"""
Expand All @@ -209,10 +255,15 @@ def terminate(self) -> None:
self._child.terminate()
self._child.join()

self._pool.shutdown(wait=False)
self._event_consumer_pool.shutdown(wait=False)

def cancel(self, tag: Optional[str] = None) -> None:
self._request_send_conn.send(CancelRequest(tag))
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel_signal()
self._events.send(Envelope(event=Cancel(), tag=tag))
predict_state.cancel_sent = True

def _assert_state(self, state: WorkerState) -> None:
if self._state != state:
Expand Down Expand Up @@ -268,48 +319,14 @@ def _consume_events_inner(self) -> None:

# Main event loop
while self._child.is_alive():
# see if we have any new prediction requests

read_socks, _, _ = select.select(
[self._request_recv_conn, self._events], [], [], 0.1
)
if self._request_recv_conn in read_socks:
ev = self._request_recv_conn.recv()
if isinstance(ev, PredictionRequest):
with self._predictions_lock:
state = self._predictions_in_flight[ev.tag]

# Prepare payload (download URLPath objects)
# FIXME this blocks the event loop, which is bad in concurrent mode
try:
_prepare_payload(state.payload)
except Exception as e:
done = Done(error=True, error_detail=str(e))
self._publish(Envelope(done, state.tag))
self._complete_prediction(done, state.tag)
else:
# Start the prediction
self._events.send(
Envelope(
event=PredictionInput(payload=state.payload),
tag=state.tag,
)
)
elif isinstance(ev, CancelRequest):
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(ev.tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel_signal()
self._events.send(Envelope(event=Cancel(), tag=ev.tag))
predict_state.cancel_sent = True
else:
log.warn("unrecognized request event: {ev}")
# wait for events from the child worker
if not self._events.poll(0.1):
continue

if self._events in read_socks:
ev = self._events.recv()
self._publish(ev)
if isinstance(ev.event, Done):
self._complete_prediction(ev.event, ev.tag)
ev = self._events.recv()
self._publish(ev)
if isinstance(ev.event, Done):
self._complete_prediction(ev.event, ev.tag)

# If we dropped off the end off the end of the loop, it's because the
# child process died. First, process any remaining messages on the connection
Expand Down Expand Up @@ -844,13 +861,3 @@ def make_worker(
)
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
return parent


def _prepare_payload(payload: Dict[str, Any]) -> None:
for k, v in payload.items():
# Check if v is an instance of URLPath
if isinstance(v, URLPath):
payload[k] = v.convert()
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(isinstance(item, URLPath) for item in v):
payload[k] = [item.convert() for item in v]

0 comments on commit f2ba468

Please sign in to comment.