Add processes pre-loading
PawelPeczek-Roboflow committed Nov 25, 2024
1 parent 18ce9db commit 73da119
Showing 11 changed files with 145 additions and 64 deletions.
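In short, this commit lets the stream manager keep a pool of pre-spawned, idle InferencePipeline processes. A new STREAM_API_PRELOADED_PROCESSES environment variable (default 0; set to 2 in the GPU images) is forwarded to start() as expected_warmed_up_pipelines, a daemon thread keeps that many idle processes alive, and pipeline-initialisation and WebRTC requests reuse an idle process instead of paying process start-up cost each time. A rough sketch of the knobs involved (illustrative values; in practice these are exported in the Dockerfiles or container environment rather than set from Python):

import os

# Illustrative only; the commit's GPU Dockerfiles export STREAM_API_PRELOADED_PROCESSES=2.
os.environ["ENABLE_STREAM_API"] = "True"            # run the stream manager process
os.environ["STREAM_API_PRELOADED_PROCESSES"] = "2"  # keep two warm pipeline processes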
5 changes: 3 additions & 2 deletions docker/config/cpu_http.py
@@ -1,3 +1,4 @@
from functools import partial
from multiprocessing import Process

from inference.core.cache import cache
@@ -17,14 +18,14 @@
MAX_ACTIVE_MODELS,
ACTIVE_LEARNING_ENABLED,
LAMBDA,
ENABLE_STREAM_API,
ENABLE_STREAM_API, STREAM_API_PRELOADED_PROCESSES,
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES


if ENABLE_STREAM_API:
stream_manager_process = Process(
target=start,
target=partial(start, expected_warmed_up_pipelines=STREAM_API_PRELOADED_PROCESSES),
)
stream_manager_process.start()

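For context on the target=partial(...) change above: functools.partial binds the keyword argument up front, so the Process target stays a zero-argument callable while still receiving the configured pool size. A minimal, self-contained illustration, where this start is a stand-in for the real manager entrypoint, not the committed function:

from functools import partial
from multiprocessing import Process


def start(expected_warmed_up_pipelines: int = 0) -> None:
    # Stand-in for the stream manager entrypoint.
    print(f"warming up {expected_warmed_up_pipelines} pipeline process(es)")


if __name__ == "__main__":
    stream_manager_process = Process(
        target=partial(start, expected_warmed_up_pipelines=2),
    )
    stream_manager_process.start()
    stream_manager_process.join()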
5 changes: 3 additions & 2 deletions docker/config/gpu_http.py
@@ -1,11 +1,12 @@
import multiprocessing
from functools import partial

from inference.core.cache import cache
from inference.core.env import (
MAX_ACTIVE_MODELS,
ACTIVE_LEARNING_ENABLED,
LAMBDA,
ENABLE_STREAM_API,
ENABLE_STREAM_API, STREAM_API_PRELOADED_PROCESSES,
)
from inference.core.interfaces.http.http_api import HttpInterface
from inference.core.interfaces.stream_manager.manager_app.app import start
@@ -24,7 +25,7 @@
if ENABLE_STREAM_API:
multiprocessing_context = multiprocessing.get_context(method="spawn")
stream_manager_process = multiprocessing_context.Process(
target=start,
target=partial(start, expected_warmed_up_pipelines=STREAM_API_PRELOADED_PROCESSES),
)
stream_manager_process.start()

1 change: 1 addition & 0 deletions docker/dockerfiles/Dockerfile.onnx.gpu
@@ -84,5 +84,6 @@ ENV CORE_MODEL_OWLV2_ENABLED=True
ENV ENABLE_STREAM_API=True
ENV ENABLE_WORKFLOWS_PROFILING=True
ENV ENABLE_PROMETHEUS=True
ENV STREAM_API_PRELOADED_PROCESSES=2

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
2 changes: 2 additions & 0 deletions docker/dockerfiles/Dockerfile.onnx.gpu.dev
@@ -78,7 +78,9 @@ ENV API_LOGGING_ENABLED=True
ENV LMM_ENABLED=True
ENV CORE_MODEL_SAM2_ENABLED=True
ENV CORE_MODEL_OWLV2_ENABLED=True
ENV ENABLE_STREAM_API=True
ENV ENABLE_WORKFLOWS_PROFILING=True
ENV ENABLE_PROMETHEUS=True
ENV STREAM_API_PRELOADED_PROCESSES=2

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
1 change: 1 addition & 0 deletions docker/dockerfiles/Dockerfile.onnx.jetson.4.5.0
@@ -75,5 +75,6 @@ ENV RUNS_ON_JETSON=True
ENV ENABLE_STREAM_API=True
ENV ENABLE_WORKFLOWS_PROFILING=True
ENV ENABLE_PROMETHEUS=True
ENV ENABLE_STREAM_API=True

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
1 change: 1 addition & 0 deletions docker/dockerfiles/Dockerfile.onnx.jetson.4.6.1
@@ -90,5 +90,6 @@ ENV RUNS_ON_JETSON=True
ENV ENABLE_STREAM_API=True
ENV ENABLE_WORKFLOWS_PROFILING=True
ENV ENABLE_PROMETHEUS=True
ENV ENABLE_STREAM_API=True

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
1 change: 1 addition & 0 deletions docker/dockerfiles/Dockerfile.onnx.jetson.5.1.1
@@ -86,5 +86,6 @@ ENV RUNS_ON_JETSON=True
ENV ENABLE_STREAM_API=True
ENV ENABLE_WORKFLOWS_PROFILING=True
ENV ENABLE_PROMETHEUS=True
ENV ENABLE_STREAM_API=True

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
3 changes: 2 additions & 1 deletion docker/dockerfiles/Dockerfile.onnx.jetson.6.0.0
@@ -72,7 +72,8 @@ ENV VERSION_CHECK_MODE=continuous \
RUNS_ON_JETSON=True \
ENABLE_STREAM_API=True \
ENABLE_WORKFLOWS_PROFILING=True \
ENABLE_PROMETHEUS=True
ENABLE_PROMETHEUS=True \
ENABLE_STREAM_API=True

# Expose the application port
EXPOSE 9001
1 change: 1 addition & 0 deletions inference/core/env.py
@@ -441,6 +441,7 @@
)

ENABLE_STREAM_API = str2bool(os.getenv("ENABLE_STREAM_API", "False"))
STREAM_API_PRELOADED_PROCESSES = int(os.getenv("STREAM_API_PRELOADED_PROCESSES", "0"))

RUNS_ON_JETSON = str2bool(os.getenv("RUNS_ON_JETSON", "False"))

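With the default of "0" nothing is pre-warmed, so existing deployments keep their current behaviour; the GPU Dockerfiles above opt into 2. A quick sketch of the intended parsing:

import os

# Unset -> 0 preloaded processes (pre-existing behaviour preserved).
os.environ.pop("STREAM_API_PRELOADED_PROCESSES", None)
assert int(os.getenv("STREAM_API_PRELOADED_PROCESSES", "0")) == 0

# The GPU images in this commit export STREAM_API_PRELOADED_PROCESSES=2.
os.environ["STREAM_API_PRELOADED_PROCESSES"] = "2"
assert int(os.getenv("STREAM_API_PRELOADED_PROCESSES", "0")) == 2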
173 changes: 114 additions & 59 deletions inference/core/interfaces/stream_manager/manager_app/app.py
@@ -9,7 +9,7 @@
from socketserver import BaseRequestHandler, BaseServer
from threading import Lock, Thread
from types import FrameType
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

from inference.core import logger
@@ -27,6 +27,7 @@
TYPE_KEY,
CommandType,
ErrorType,
ManagedInferencePipeline,
OperationStatus,
)
from inference.core.interfaces.stream_manager.manager_app.errors import (
@@ -44,7 +45,8 @@
RoboflowTCPServer,
)

PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue, Lock]] = {}
PROCESSES_TABLE: Dict[str, ManagedInferencePipeline] = {}
PROCESSES_TABLE_LOCK = Lock()
HEADER_SIZE = 4
SOCKET_BUFFER_SIZE = 16384
HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1")
@@ -58,7 +60,7 @@ def __init__(
request: socket.socket,
client_address: Any,
server: BaseServer,
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]],
processes_table: Dict[str, ManagedInferencePipeline],
):
self._processes_table = processes_table  # in this case it's required to set the state of class before superclass init - as it invokes handle()
super().__init__(request, client_address, server)
@@ -140,7 +142,9 @@ def _list_pipelines(self, request_id: str) -> None:
serialised_response = prepare_response(
request_id=request_id,
response={
"pipelines": list(self._processes_table.keys()),
"pipelines": [
k for k, v in self._processes_table.items() if not v.is_idle
],
STATUS_KEY: OperationStatus.SUCCESS,
},
pipeline_id=None,
@@ -153,65 +157,47 @@ def _list_pipelines(self, request_id: str) -> None:
)

def _initialise_pipeline(self, request_id: str, command: dict) -> None:
pipeline_id = str(uuid4())
command_queue = Queue()
responses_queue = Queue()
inference_pipeline_manager = InferencePipelineManager.init(
pipeline_id=pipeline_id,
command_queue=command_queue,
responses_queue=responses_queue,
)
inference_pipeline_manager.start()
self._processes_table[pipeline_id] = (
inference_pipeline_manager,
command_queue,
responses_queue,
Lock(),
managed_pipeline = get_or_spawn_pipeline_process(
processes_table=self._processes_table,
)
command_queue.put((request_id, command))
managed_pipeline.command_queue.put((request_id, command))
response = get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
responses_queue=managed_pipeline.responses_queue,
matching_request_id=request_id,
)
serialised_response = prepare_response(
request_id=request_id, response=response, pipeline_id=pipeline_id
request_id=request_id,
response=response,
pipeline_id=managed_pipeline.pipeline_id,
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=serialised_response,
request_id=request_id,
pipeline_id=pipeline_id,
pipeline_id=managed_pipeline.pipeline_id,
)

def _start_webrtc(self, request_id: str, command: dict):
pipeline_id = str(uuid4())
command_queue = Queue()
responses_queue = Queue()
inference_pipeline_manager = InferencePipelineManager.init(
pipeline_id=pipeline_id,
command_queue=command_queue,
responses_queue=responses_queue,
)
inference_pipeline_manager.start()
self._processes_table[pipeline_id] = (
inference_pipeline_manager,
command_queue,
responses_queue,
Lock(),
managed_pipeline = get_or_spawn_pipeline_process(
processes_table=self._processes_table,
)
command_queue.put((request_id, command))
managed_pipeline.command_queue.put((request_id, command))
response = get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
responses_queue=managed_pipeline.responses_queue,
matching_request_id=request_id,
)
serialised_response = prepare_response(
request_id=request_id, response=response, pipeline_id=pipeline_id
request_id=request_id,
response=response,
pipeline_id=managed_pipeline.pipeline_id,
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=serialised_response,
request_id=request_id,
pipeline_id=pipeline_id,
pipeline_id=managed_pipeline.pipeline_id,
)

def _terminate_pipeline(
@@ -246,7 +232,7 @@ def _terminate_pipeline(


def handle_command(
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]],
processes_table: Dict[str, ManagedInferencePipeline],
request_id: str,
pipeline_id: str,
command: dict,
@@ -257,11 +243,12 @@ def handle_command(
error_type=ErrorType.NOT_FOUND,
public_error_message=f"Could not find InferencePipeline with id={pipeline_id}.",
)
_, command_queue, responses_queue, command_lock = processes_table[pipeline_id]
with command_lock:
command_queue.put((request_id, command))
managed_pipeline = processes_table[pipeline_id]
with managed_pipeline.operation_lock:
managed_pipeline.command_queue.put((request_id, command))
return get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
responses_queue=managed_pipeline.responses_queue,
matching_request_id=request_id,
)


@@ -282,29 +269,31 @@ def execute_termination(
frame: FrameType,
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]],
) -> None:
pipeline_ids = list(processes_table.keys())
for pipeline_id in pipeline_ids:
logger.info(f"Terminating pipeline: {pipeline_id}")
processes_table[pipeline_id][0].terminate()
logger.info(f"Pipeline: {pipeline_id} terminated.")
logger.info(f"Joining pipeline: {pipeline_id}")
processes_table[pipeline_id][0].join()
logger.info(f"Pipeline: {pipeline_id} joined.")
logger.info(f"Termination handler completed.")
sys.exit(0)
with PROCESSES_TABLE_LOCK:
pipeline_ids = list(processes_table.keys())
for pipeline_id in pipeline_ids:
logger.info(f"Terminating pipeline: {pipeline_id}")
processes_table[pipeline_id].pipeline_manager.terminate()
logger.info(f"Pipeline: {pipeline_id} terminated.")
logger.info(f"Joining pipeline: {pipeline_id}")
processes_table[pipeline_id].pipeline_manager.join()
logger.info(f"Pipeline: {pipeline_id} joined.")
logger.info(f"Termination handler completed.")
sys.exit(0)


def join_inference_pipeline(
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]], pipeline_id: str
processes_table: Dict[str, ManagedInferencePipeline], pipeline_id: str
) -> None:
inference_pipeline_manager, *_ = processes_table[pipeline_id]
inference_pipeline_manager = processes_table[pipeline_id].pipeline_manager
inference_pipeline_manager.join()
del processes_table[pipeline_id]


def check_process_health() -> None:
while True:
for pipeline_id, (process, *_) in list(PROCESSES_TABLE.items()):
for pipeline_id, managed_pipeline in list(PROCESSES_TABLE.items()):
process = managed_pipeline.pipeline_manager
if not process.is_alive():
logger.warning(
"Process for pipeline_id=%s is not alive. Terminating...",
@@ -361,7 +350,70 @@ def check_process_health() -> None:
time.sleep(1)


def start() -> None:
def get_or_spawn_pipeline_process(
processes_table: Dict[str, ManagedInferencePipeline],
) -> ManagedInferencePipeline:
with PROCESSES_TABLE_LOCK:
idle_pipelines = get_idle_pipelines_id(processes_table=processes_table)
if len(idle_pipelines) > 0:
chosen_pipeline = processes_table[idle_pipelines[0]]
chosen_pipeline.is_idle = False
return chosen_pipeline
new_pipeline_id = spawn_managed_pipeline_process(
processes_table=processes_table,
mark_as_idle=False,
)
return processes_table[new_pipeline_id]


def ensure_idle_pipelines_warmed_up(expected_warmed_up_pipelines: int) -> None:
while True:
with PROCESSES_TABLE_LOCK:
idle_pipelines = len(get_idle_pipelines_id(processes_table=PROCESSES_TABLE))
if idle_pipelines < expected_warmed_up_pipelines:
_ = spawn_managed_pipeline_process(processes_table=PROCESSES_TABLE)
time.sleep(5)


def get_idle_pipelines_id(
processes_table: Dict[str, ManagedInferencePipeline]
) -> List[str]:
return [
pipeline_id
for pipeline_id, managed_pipeline in processes_table.items()
if managed_pipeline.is_idle
]


def spawn_managed_pipeline_process(
processes_table: Dict[str, ManagedInferencePipeline],
mark_as_idle: bool = True,
) -> str:
logger.info(
f"Spawning new managed InferencePipeline process. Idle flag: {mark_as_idle}"
)
pipeline_id = str(uuid4())
command_queue = Queue()
responses_queue = Queue()
inference_pipeline_manager = InferencePipelineManager.init(
pipeline_id=pipeline_id,
command_queue=command_queue,
responses_queue=responses_queue,
)
inference_pipeline_manager.start()
processes_table[pipeline_id] = ManagedInferencePipeline(
pipeline_id=pipeline_id,
pipeline_manager=inference_pipeline_manager,
command_queue=command_queue,
responses_queue=responses_queue,
operation_lock=Lock(),
is_idle=mark_as_idle,
)
logger.info(f"Spawned new InferencePipeline process with id: {pipeline_id}")
return pipeline_id


def start(expected_warmed_up_pipelines: int = 0) -> None:
signal.signal(
signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE)
)
Expand All @@ -372,6 +424,9 @@ def start() -> None:
# check process health in daemon thread
Thread(target=check_process_health, daemon=True).start()

# keep expected number of processes ready for processing
Thread(target=ensure_idle_pipelines_warmed_up, args=(expected_warmed_up_pipelines,), daemon=True).start()

with RoboflowTCPServer(
server_address=(HOST, PORT),
handler_class=partial(
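The handler and the helpers above lean on ManagedInferencePipeline, imported from the manager_app entities module, whose diff is not rendered on this page. Judging only from how its attributes are used in app.py, it is presumably a small mutable dataclass along these lines; this is a sketch of the assumed shape, not the committed definition:

from dataclasses import dataclass
from multiprocessing import Process, Queue
from threading import Lock


@dataclass
class ManagedInferencePipeline:
    pipeline_id: str
    pipeline_manager: Process  # in practice an InferencePipelineManager instance
    command_queue: Queue       # (request_id, command) tuples pushed by the handler
    responses_queue: Queue     # responses matched back by request_id
    operation_lock: Lock       # serialises commands issued to one pipeline
    is_idle: bool              # True while the process waits in the warm pool

Idle entries sit in PROCESSES_TABLE with is_idle=True until get_or_spawn_pipeline_process hands one to an initialise or WebRTC request and flips the flag; ensure_idle_pipelines_warmed_up then spawns a replacement on a later pass, and _list_pipelines skips entries that are still idle.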
