Skip to content

Commit

Permalink
WIP: download input events off the event consumer loop
Browse files Browse the repository at this point in the history
  • Loading branch information
philandstuff committed Dec 20, 2024
1 parent 392819d commit c1b7f72
Showing 1 changed file with 65 additions and 36 deletions.
101 changes: 65 additions & 36 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import uuid
import warnings
import weakref
from concurrent import futures
from concurrent.futures import Future, ThreadPoolExecutor
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
Expand Down Expand Up @@ -130,14 +131,16 @@ def __init__(
self._request_send_conn = send_conn
self._request_recv_conn = recv_conn

self._pool = ThreadPoolExecutor(max_workers=1)
self._event_consumer_pool = ThreadPoolExecutor(max_workers=1)
self._prediction_prepare_pool = ThreadPoolExecutor(max_workers=max_concurrency)
self._input_download_pool = ThreadPoolExecutor(max_workers=8)
self._event_consumer = None

def setup(self) -> "Future[Done]":
self._assert_state(WorkerState.NEW)
self._state = WorkerState.STARTING
self._child.start()
self._event_consumer = self._pool.submit(self._consume_events)
self._event_consumer = self._event_consumer_pool.submit(self._consume_events)
return self._setup_result

def predict(
Expand All @@ -163,9 +166,65 @@ def predict(
self._assert_state(WorkerState.READY)
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
self._request_send_conn.send(PredictionRequest(tag))

# Prepare payload asynchronously (download URLPath objects)
fut = self._prediction_prepare_pool.submit(self._prepare_payload(payload))
# then start the prediction
fut.add_done_callback(self._start_prediction(tag))
return result

def _prepare_payload(self, payload: Dict[str, Any]) -> Callable[[], Dict[str, Any]]:
def prepare_payload() -> Dict[str, Any]:
to_await = []
futs = {}
for k, v in payload.items():
# Check if v is an instance of URLPath
if isinstance(v, URLPath):
futs[k] = self._input_download_pool.submit(v.convert)
to_await.append(futs[k])
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(
isinstance(item, URLPath) for item in v
):
futs[k] = [
self._input_download_pool.submit(item.convert) for item in v
]
to_await += futs[k]
futures.wait(to_await, return_when=futures.FIRST_EXCEPTION)
for k, v in futs.items():
if isinstance(v, list):
payload[k] = []
for fut in v:
# the future may not be done if and only if another
# future finished with an exception
if fut.done():
payload[k].append(fut.result())
elif isinstance(v, Future):
if v.done():
payload[k] = v.result()
return payload

return prepare_payload

def _start_prediction(
self, tag: Optional[str]
) -> Callable[["Future[Dict[str,Any]]"], None]:
def payload_callback(fut: "Future[Dict[str,Any]]") -> None:
if fut.exception():
done = Done(error=True, error_detail=str(fut.exception()))
self._publish(Envelope(done, tag))
self._complete_prediction(done, tag)
return
payload = fut.result()
self._events.send(
Envelope(
event=PredictionInput(payload=payload),
tag=tag,
)
)

return payload_callback

def subscribe(
self,
subscriber: Callable[[_PublicEventType], None],
Expand Down Expand Up @@ -195,7 +254,7 @@ def shutdown(self, timeout: Optional[float] = None) -> None:
if self._event_consumer:
self._event_consumer.result(timeout=timeout)

self._pool.shutdown()
self._event_consumer_pool.shutdown()

def terminate(self) -> None:
"""
Expand All @@ -209,7 +268,7 @@ def terminate(self) -> None:
self._child.terminate()
self._child.join()

self._pool.shutdown(wait=False)
self._event_consumer_pool.shutdown(wait=False)

def cancel(self, tag: Optional[str] = None) -> None:
self._request_send_conn.send(CancelRequest(tag))
Expand Down Expand Up @@ -275,27 +334,7 @@ def _consume_events_inner(self) -> None:
)
if self._request_recv_conn in read_socks:
ev = self._request_recv_conn.recv()
if isinstance(ev, PredictionRequest):
with self._predictions_lock:
state = self._predictions_in_flight[ev.tag]

# Prepare payload (download URLPath objects)
# FIXME this blocks the event loop, which is bad in concurrent mode
try:
_prepare_payload(state.payload)
except Exception as e:
done = Done(error=True, error_detail=str(e))
self._publish(Envelope(done, state.tag))
self._complete_prediction(done, state.tag)
else:
# Start the prediction
self._events.send(
Envelope(
event=PredictionInput(payload=state.payload),
tag=state.tag,
)
)
elif isinstance(ev, CancelRequest):
if isinstance(ev, CancelRequest):
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(ev.tag)
if predict_state and not predict_state.cancel_sent:
Expand Down Expand Up @@ -844,13 +883,3 @@ def make_worker(
)
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
return parent


def _prepare_payload(payload: Dict[str, Any]) -> None:
for k, v in payload.items():
# Check if v is an instance of URLPath
if isinstance(v, URLPath):
payload[k] = v.convert()
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(isinstance(item, URLPath) for item in v):
payload[k] = [item.convert() for item in v]

0 comments on commit c1b7f72

Please sign in to comment.