Skip to content

Commit

Permalink
Refactor middlewares to avoid using BaseHTTPMiddleware (#1555)
Browse files Browse the repository at this point in the history
  • Loading branch information
purplesmoke05 authored Sep 5, 2024
1 parent b1a10c7 commit 4a3d44d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 44 deletions.
8 changes: 2 additions & 6 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from pydantic_core import ArgsKwargs, ErrorDetails
from sqlalchemy.exc import OperationalError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware

from app import log
Expand Down Expand Up @@ -139,18 +138,15 @@ def root():
# MIDDLEWARE
###############################################################

strip_trailing_slash = StripTrailingSlashMiddleware()
response_logger = ResponseLoggerMiddleware()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(BaseHTTPMiddleware, dispatch=strip_trailing_slash)
app.add_middleware(BaseHTTPMiddleware, dispatch=response_logger)
app.add_middleware(ResponseLoggerMiddleware)
app.add_middleware(StripTrailingSlashMiddleware)


###############################################################
Expand Down
52 changes: 30 additions & 22 deletions app/middleware/response_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
"""

import logging
from datetime import UTC, datetime
import time

from fastapi import Request, Response
from starlette.middleware.base import RequestResponseEndpoint
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from app import config, log

Expand All @@ -43,27 +42,36 @@
class ResponseLoggerMiddleware:
"""Response Logger Middleware"""

def __init__(self):
pass
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

async def __call__(
self, req: Request, call_next: RequestResponseEndpoint
) -> Response:
# Before process request
request_start_time = datetime.now(UTC).replace(tzinfo=None)
request_start_time = time.monotonic()

method = scope["method"]
path = scope["path"]
query_string = scope["query_string"].decode("utf-8")
status_code = None

async def send_wrapper(message: Message):
nonlocal status_code
if message["type"] == "http.response.start":
status_code = message["status"]
await send(message)

# Process request
res: Response = await call_next(req)
await self.app(scope, receive, send_wrapper)

# After process request
if req.url.path != "/":
response_time = (
datetime.now(UTC).replace(tzinfo=None) - request_start_time
).total_seconds()
if req.url.query:
log_msg = f"{req.method} {req.url.path}?{req.url.query} {res.status_code} ({response_time}sec)"
else:
log_msg = f"{req.method} {req.url.path} {res.status_code} ({response_time}sec)"
ACCESS_LOG.info(log_msg)

return res
response_time = time.monotonic() - request_start_time
if query_string:
log_msg = (
f"{method} {path}?{query_string} {status_code} ({response_time:.6f}sec)"
)
else:
log_msg = f"{method} {path} {status_code} ({response_time:.6f}sec)"
ACCESS_LOG.info(log_msg)
40 changes: 24 additions & 16 deletions app/middleware/strip_trailing_slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
SPDX-License-Identifier: Apache-2.0
"""

from fastapi import Request, Response
from starlette.middleware.base import RequestResponseEndpoint
from urllib.parse import urlparse, urlunparse

from starlette.types import ASGIApp, Receive, Scope, Send

from app import log

Expand All @@ -34,22 +35,29 @@ class StripTrailingSlashMiddleware:
* this middleware replaces it to "/Admin/Tokens" for avoiding redirect(307).
"""

def __init__(self):
pass
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return

async def __call__(
self, req: Request, call_next: RequestResponseEndpoint
) -> Response:
# Before process request
if req.url.path != "/" and req.url.path[-1] == "/":
replace_path = req.url.path[:-1]
req._url = req.url.replace(path=replace_path)
req.scope["path"] = replace_path
req.scope["raw_path"] = replace_path.encode()
path = scope["path"]
if path != "/" and path.endswith("/"):
# Remove trailing slash
new_path = path.rstrip("/")

# Process request
res: Response = await call_next(req)
# Update scope
scope["path"] = new_path
scope["raw_path"] = new_path.encode()

# After process request
# Update url in scope
if "url" in scope:
url_parts = list(urlparse(scope["url"]))
url_parts[2] = new_path # Update path
scope["url"] = urlunparse(url_parts)

return res
# Process request
await self.app(scope, receive, send)

0 comments on commit 4a3d44d

Please sign in to comment.