Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download inputs off the event consumer loop #2092

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 85 additions & 62 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import inspect
import multiprocessing
import os
import select
import signal
import sys
import threading
Expand All @@ -13,6 +12,7 @@
import uuid
import warnings
import weakref
from concurrent import futures
from concurrent.futures import Future, ThreadPoolExecutor
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
Expand Down Expand Up @@ -126,18 +126,16 @@ def __init__(
self._predictions_lock = threading.Lock()
self._predictions_in_flight: Dict[Optional[str], PredictionState] = {}

recv_conn, send_conn = _spawn.Pipe(duplex=False)
self._request_send_conn = send_conn
self._request_recv_conn = recv_conn

self._pool = ThreadPoolExecutor(max_workers=1)
self._event_consumer_pool = ThreadPoolExecutor(max_workers=1)
self._prediction_start_pool = ThreadPoolExecutor(max_workers=max_concurrency)
self._input_download_pool = ThreadPoolExecutor(max_workers=8)
self._event_consumer = None

def setup(self) -> "Future[Done]":
self._assert_state(WorkerState.NEW)
self._state = WorkerState.STARTING
self._child.start()
self._event_consumer = self._pool.submit(self._consume_events)
self._event_consumer = self._event_consumer_pool.submit(self._consume_events)
return self._setup_result

def predict(
Expand All @@ -163,9 +161,73 @@ def predict(
self._assert_state(WorkerState.READY)
result = Future()
self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
self._request_send_conn.send(PredictionRequest(tag))

self._prediction_start_pool.submit(self._start_prediction(tag, payload))
return result

def _start_prediction(
self, tag: Optional[str], payload: Dict[str, Any]
) -> Callable[[], None]:
def start_prediction() -> None:
try:
to_await = []
futs = {}
# Prepare payload asynchronously (download URLPath objects)
for k, v in payload.items():
# Check if v is an instance of URLPath
if isinstance(v, URLPath):
futs[k] = self._input_download_pool.submit(v.convert)
to_await.append(futs[k])
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(
isinstance(item, URLPath) for item in v
):
futs[k] = [
self._input_download_pool.submit(item.convert) for item in v
]
to_await += futs[k]
done, not_done = futures.wait(
to_await, return_when=futures.FIRST_EXCEPTION
)

if len(not_done) > 0:
# if any future isn't done, this is because one of the
# futures raised an exception. first we cancel outstanding
# work
for fut in not_done:
fut.cancel()
# then we find an exception to raise
for fut in done:
fut.result() # raises if the future finished with an exception
# we should never get here
raise Exception(
"Internal error: lost track of exception while downloading input files"
)

# all futures are done. some might still have raised an
# exception, but when we call fut.result() that will re-raise
# and do the right thing
for k, v in futs.items():
if isinstance(v, list):
payload[k] = []
for fut in v:
payload[k].append(fut.result())
elif isinstance(v, Future):
payload[k] = v.result()
# send the prediction to the child to start
self._events.send(
Envelope(
event=PredictionInput(payload=payload),
tag=tag,
)
)
except Exception as e:
done = Done(error=True, error_detail=str(e))
self._publish(Envelope(done, tag))
self._complete_prediction(done, tag)

return start_prediction

def subscribe(
self,
subscriber: Callable[[_PublicEventType], None],
Expand Down Expand Up @@ -195,7 +257,7 @@ def shutdown(self, timeout: Optional[float] = None) -> None:
if self._event_consumer:
self._event_consumer.result(timeout=timeout)

self._pool.shutdown()
self._event_consumer_pool.shutdown()

def terminate(self) -> None:
"""
Expand All @@ -209,10 +271,15 @@ def terminate(self) -> None:
self._child.terminate()
self._child.join()

self._pool.shutdown(wait=False)
self._event_consumer_pool.shutdown(wait=False)

def cancel(self, tag: Optional[str] = None) -> None:
self._request_send_conn.send(CancelRequest(tag))
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel_signal()
self._events.send(Envelope(event=Cancel(), tag=tag))
predict_state.cancel_sent = True

def _assert_state(self, state: WorkerState) -> None:
if self._state != state:
Expand Down Expand Up @@ -268,48 +335,14 @@ def _consume_events_inner(self) -> None:

# Main event loop
while self._child.is_alive():
# see if we have any new prediction requests

read_socks, _, _ = select.select(
[self._request_recv_conn, self._events], [], [], 0.1
)
if self._request_recv_conn in read_socks:
ev = self._request_recv_conn.recv()
if isinstance(ev, PredictionRequest):
with self._predictions_lock:
state = self._predictions_in_flight[ev.tag]

# Prepare payload (download URLPath objects)
# FIXME this blocks the event loop, which is bad in concurrent mode
try:
_prepare_payload(state.payload)
except Exception as e:
done = Done(error=True, error_detail=str(e))
self._publish(Envelope(done, state.tag))
self._complete_prediction(done, state.tag)
else:
# Start the prediction
self._events.send(
Envelope(
event=PredictionInput(payload=state.payload),
tag=state.tag,
)
)
elif isinstance(ev, CancelRequest):
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(ev.tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel_signal()
self._events.send(Envelope(event=Cancel(), tag=ev.tag))
predict_state.cancel_sent = True
else:
log.warn("unrecognized request event: {ev}")
# wait for events from the child worker
if not self._events.poll(0.1):
continue

if self._events in read_socks:
ev = self._events.recv()
self._publish(ev)
if isinstance(ev.event, Done):
self._complete_prediction(ev.event, ev.tag)
ev = self._events.recv()
self._publish(ev)
if isinstance(ev.event, Done):
self._complete_prediction(ev.event, ev.tag)

# If we dropped off the end off the end of the loop, it's because the
# child process died. First, process any remaining messages on the connection
Expand Down Expand Up @@ -844,13 +877,3 @@ def make_worker(
)
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
return parent


def _prepare_payload(payload: Dict[str, Any]) -> None:
for k, v in payload.items():
# Check if v is an instance of URLPath
if isinstance(v, URLPath):
payload[k] = v.convert()
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(isinstance(item, URLPath) for item in v):
payload[k] = [item.convert() for item in v]
Loading