Skip to content

Commit

Permalink
🎨 Improve profiling middleware (#5935)
Browse files Browse the repository at this point in the history
  • Loading branch information
bisgaard-itis authored Jun 13, 2024
1 parent 641b328 commit ab8bb89
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 94 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
from aiohttp.web import HTTPInternalServerError, Request, StreamResponse, middleware
from pyinstrument import Profiler
from servicelib.mimetype_constants import (
MIMETYPE_APPLICATION_JSON,
MIMETYPE_APPLICATION_ND_JSON,
)

from .._utils_profiling_middleware import append_profile
from ..utils_profiling_middleware import _is_profiling, _profiler, append_profile


@middleware
async def profiling_middleware(request: Request, handler):
profiler: Profiler | None = None
if request.headers.get("x-profile") is not None:
profiler = Profiler(async_mode="enabled")
profiler.start()
try:
if _profiler.is_running or (_profiler.last_session is not None):
raise HTTPInternalServerError(
reason="Profiler is already running. Only a single request can be profiled at any given time.",
headers={},
)
_profiler.reset()
_is_profiling.set(True)

response = await handler(request)
with _profiler:
response = await handler(request)

if profiler is None:
return response
if response.content_type != MIMETYPE_APPLICATION_JSON:
raise HTTPInternalServerError(
reason=f"Profiling middleware is not compatible with {response.content_type=}",
headers={},
)
if response.content_type != MIMETYPE_APPLICATION_JSON:
raise HTTPInternalServerError(
reason=f"Profiling middleware is not compatible with {response.content_type=}",
headers={},
)

stream_response = StreamResponse(
status=response.status,
reason=response.reason,
headers=response.headers,
)
stream_response.content_type = MIMETYPE_APPLICATION_ND_JSON
await stream_response.prepare(request)
await stream_response.write(response.body)
profiler.stop()
await stream_response.write(
append_profile(
"\n", profiler.output_text(unicode=True, color=True, show_all=True)
).encode()
)
await stream_response.write_eof()
return stream_response
stream_response = StreamResponse(
status=response.status,
reason=response.reason,
headers=response.headers,
)
stream_response.content_type = MIMETYPE_APPLICATION_ND_JSON
await stream_response.prepare(request)
await stream_response.write(response.body)
await stream_response.write(
append_profile(
"\n", _profiler.output_text(unicode=True, color=True, show_all=True)
).encode()
)
await stream_response.write_eof()
finally:
_profiler.reset()
return stream_response

return await handler(request)
34 changes: 26 additions & 8 deletions packages/service-library/src/servicelib/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from contextlib import suppress
from dataclasses import dataclass
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Deque
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Deque

from .utils_profiling_middleware import dont_profile, is_profiling, profile

logger = logging.getLogger(__name__)

Expand All @@ -30,6 +32,13 @@ class Context:
task: asyncio.Task | None = None


@dataclass
class QueueElement:
    """Work item passed through the sequential task runner's input queue."""

    # When True, the worker executes `input` under the shared profiler
    # (the flag is captured from the caller's context at enqueue time).
    do_profile: bool = False
    # The awaitable to execute; None is the sentinel telling the worker to stop.
    input: Awaitable | None = None
    # Result slot — NOTE(review): not visibly written in this chunk; verify usage.
    output: Any | None = None


_sequential_jobs_contexts: dict[str, Context] = {}


Expand Down Expand Up @@ -138,15 +147,18 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
if not context.initialized:
context.initialized = True

async def worker(in_q: Queue, out_q: Queue) -> None:
async def worker(in_q: Queue[QueueElement], out_q: Queue) -> None:
while True:
awaitable = await in_q.get()
element = await in_q.get()
in_q.task_done()
# check if requested to shutdown
if awaitable is None:
break
try:
result = await awaitable
do_profile = element.do_profile
awaitable = element.input
if awaitable is None:
break
with profile(do_profile):
result = await awaitable
except Exception as e: # pylint: disable=broad-except
result = e
await out_q.put(result)
Expand All @@ -161,9 +173,15 @@ async def worker(in_q: Queue, out_q: Queue) -> None:
worker(context.in_queue, context.out_queue)
)

await context.in_queue.put(decorated_function(*args, **kwargs))
with dont_profile():
# ensure profiler is disabled in order to capture profile of endpoint code
queue_input = QueueElement(
input=decorated_function(*args, **kwargs),
do_profile=is_profiling(),
)
await context.in_queue.put(queue_input)
wrapped_result = await context.out_queue.get()

wrapped_result = await context.out_queue.get()
if isinstance(wrapped_result, Exception):
raise wrapped_result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Any
from typing import Any, Final

from fastapi import FastAPI
from pyinstrument import Profiler
from servicelib.aiohttp import status
from servicelib.mimetype_constants import MIMETYPE_APPLICATION_JSON
from starlette.requests import Request

from .._utils_profiling_middleware import append_profile, check_response_headers
from ..utils_profiling_middleware import (
_is_profiling,
_profiler,
append_profile,
check_response_headers,
)


def is_last_response(response_headers: dict[bytes, bytes], message: dict[str, Any]):
Expand All @@ -28,43 +33,66 @@ class ProfilerMiddleware:

def __init__(self, app: FastAPI):
self._app: FastAPI = app
self._profile_header_trigger: str = "x-profile"
self._profile_header_trigger: Final[str] = "x-profile"

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self._app(scope, receive, send)
return

profiler: Profiler | None = None
request: Request = Request(scope)
request_headers = dict(request.headers)
response_headers: dict[bytes, bytes] = {}

if request_headers.get(self._profile_header_trigger) is not None:
if request_headers.get(self._profile_header_trigger) is None:
await self._app(scope, receive, send)
return

if _profiler.is_running or (_profiler.last_session is not None):
response = {
"type": "http.response.start",
"status": status.HTTP_500_INTERNAL_SERVER_ERROR,
"headers": [
(b"content-type", b"text/plain"),
],
}
await send(response)
response_body = {
"type": "http.response.body",
"body": b"Profiler is already running. Only a single request can be profiled at any give time.",
}
await send(response_body)
return

try:
request_headers.pop(self._profile_header_trigger)
scope["headers"] = [
(k.encode("utf8"), v.encode("utf8")) for k, v in request_headers.items()
]
profiler = Profiler(async_mode="enabled")
profiler.start()
_profiler.start()
_is_profiling.set(True)

async def _send_wrapper(message):
if isinstance(profiler, Profiler):
nonlocal response_headers
if message["type"] == "http.response.start":
response_headers = dict(message.get("headers"))
message["headers"] = check_response_headers(response_headers)
elif message["type"] == "http.response.body":
if is_last_response(response_headers, message):
profiler.stop()
message["body"] = append_profile(
message["body"].decode(),
profiler.output_text(
async def _send_wrapper(message):
if _is_profiling.get():
nonlocal response_headers
if message["type"] == "http.response.start":
response_headers = dict(message.get("headers"))
message["headers"] = check_response_headers(response_headers)
elif message["type"] == "http.response.body":
if is_last_response(response_headers, message):
_profiler.stop()
profile_text = _profiler.output_text(
unicode=True, color=True, show_all=True
),
).encode()
else:
message["more_body"] = True
await send(message)
)
_profiler.reset()
message["body"] = append_profile(
message["body"].decode(), profile_text
).encode()
else:
message["more_body"] = True
await send(message)

await self._app(scope, receive, _send_wrapper)

await self._app(scope, receive, _send_wrapper)
finally:
_profiler.reset()
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import contextvars
import json
from contextlib import contextmanager
from typing import Iterator

from pyinstrument import Profiler
from servicelib.mimetype_constants import (
MIMETYPE_APPLICATION_JSON,
MIMETYPE_APPLICATION_ND_JSON,
)

# Single module-wide Profiler shared by the aiohttp and FastAPI profiling
# middlewares — only one request can be profiled at a time.
_profiler = Profiler(async_mode="enabled")
# Context-local flag: True while the current request asked to be profiled.
_is_profiling = contextvars.ContextVar("_is_profiling", default=False)


def is_profiling() -> bool:
    """Return True when the current context has profiling enabled."""
    flag: bool = _is_profiling.get()
    return flag


@contextmanager
def profile(do_profile: bool | None = None) -> Iterator[None]:
    """Run the enclosed code under the shared request profiler.

    Arguments:
        do_profile: force profiling on/off; when ``None`` (default), fall back
            to the context-local ``_is_profiling`` flag set by the middleware.

    NOTE: the previous docstring ("temporarily removes request profiler from
    context") described ``dont_profile`` — it was a copy-paste error.
    """
    if do_profile is None:
        do_profile = _is_profiling.get()
    if do_profile:
        try:
            # start inside try so a failed start still triggers stop()
            _profiler.start()
            yield
        finally:
            _profiler.stop()
    else:
        yield


@contextmanager
def dont_profile() -> Iterator[None]:
    """Temporarily suspend the shared profiler while the enclosed code runs.

    No-op when the current context is not being profiled.
    """
    if not _is_profiling.get():
        yield
        return
    try:
        _profiler.stop()
        yield
    finally:
        # resume profiling of the surrounding request
        _profiler.start()


def append_profile(body: str, profile_text: str) -> str:
    """Return *body* with a ``{"profile": ...}`` JSON document appended.

    If *body* is itself valid JSON, a newline separator is inserted first
    (ND-JSON style); otherwise the profile document is appended as-is.
    """
    separator = ""
    try:
        json.loads(body)
        if not body.endswith("\n"):
            separator = "\n"
    except json.decoder.JSONDecodeError:
        # body is not JSON — append without a separator, as before
        pass
    return body + separator + json.dumps({"profile": profile_text})


def check_response_headers(
    response_headers: dict[bytes, bytes]
) -> list[tuple[bytes, bytes]]:
    """Rewrite the response headers so the content-type becomes ND-JSON.

    Asserts the original content-type was (ND-)JSON, since the profile is
    appended as an extra JSON line.
    """
    original_content_type: str = response_headers[b"content-type"].decode()
    assert original_content_type in {
        MIMETYPE_APPLICATION_JSON,
        MIMETYPE_APPLICATION_ND_JSON,
    }  # nosec
    return [(b"content-type", MIMETYPE_APPLICATION_ND_JSON.encode())]
10 changes: 8 additions & 2 deletions services/director-v2/tests/integration/02/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,19 @@ async def _port_forward_legacy_service( # pylint: disable=redefined-outer-name
# Legacy services are started --endpoint-mode dnsrr, it needs to
# be changed to vip otherwise the port forward will not work
result = run_command(f"docker service update {service_name} --endpoint-mode=vip")
assert "verify: Service converged" in result
assert (
"verify: Service converged" in result
or f"verify: Service {service_name} converged" in result
)

# Finally forward the port on a random assigned port.
result = run_command(
f"docker service update {service_name} --publish-add :{internal_port}"
)
assert "verify: Service converged" in result
assert (
"verify: Service converged" in result
or f"verify: Service {service_name} converged" in result
)

# inspect service and fetch the port
async with aiodocker.Docker() as docker_client:
Expand Down
1 change: 1 addition & 0 deletions tmp_urls.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
\"http://10.43.103.193.nip.io:8006/v0/me\" -X GET -H "accept: application/json" -H "Authorization: Basic dGVzdF9iODkxNjUwZmViZjY2OTNlZjc3MToxNzliM2E4OTRiNTY0ZGY5NjExYzY5ZmE4NDcxNjNiYzhmYzdkMGY0"

0 comments on commit ab8bb89

Please sign in to comment.