Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow base path config for ALB handler #291

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The heart of Mangum is the adapter class. It is a configurable wrapper that allo
handler = Mangum(
app,
lifespan="auto",
api_gateway_base_path=None,
base_path=None,
custom_handlers=None,
text_mime_types=None,
)
Expand Down
12 changes: 11 additions & 1 deletion mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from itertools import chain
from contextlib import ExitStack
from typing import List, Optional, Type
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
app: ASGI,
lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
jordaneremieff marked this conversation as resolved.
Show resolved Hide resolved
base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
text_mime_types: Optional[List[str]] = None,
exclude_headers: Optional[List[str]] = None,
Expand All @@ -55,8 +57,16 @@ def __init__(
self.lifespan = lifespan
self.custom_handlers = custom_handlers or []
exclude_headers = exclude_headers or []
if api_gateway_base_path and api_gateway_base_path != "/":
warnings.warn(
"`api_gateway_base_path` parameter is deprecated and will be "
"removed in future versions. Please use `base_path` instead.",
DeprecationWarning,
)
base_path = api_gateway_base_path

self.config = LambdaConfig(
api_gateway_base_path=api_gateway_base_path or "/",
base_path=base_path or "/",
text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES],
exclude_headers=[header.lower() for header in exclude_headers],
)
Expand Down
7 changes: 5 additions & 2 deletions mangum/handlers/alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
handle_base64_response_body,
handle_exclude_headers,
maybe_encode_body,
strip_base_path,
)
from mangum.types import (
Response,
Expand Down Expand Up @@ -104,7 +105,6 @@ def body(self) -> bytes:

@property
def scope(self) -> Scope:

headers = transform_headers(self.event)
list_headers = [list(x) for x in headers]
# Unique headers. If there are duplicates, it will use the last defined.
Expand All @@ -130,7 +130,10 @@ def scope(self) -> Scope:
"method": http_method,
"http_version": "1.1",
"headers": list_headers,
"path": path,
"path": strip_base_path(
path,
base_path=self.config["base_path"],
),
"raw_path": None,
"root_path": "",
"scheme": uq_headers.get("x-forwarded-proto", "https"),
Expand Down
10 changes: 5 additions & 5 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
handle_exclude_headers,
handle_multi_value_headers,
maybe_encode_body,
strip_api_gateway_path,
strip_base_path,
)
from mangum.types import (
Response,
Expand Down Expand Up @@ -93,9 +93,9 @@ def scope(self) -> Scope:
"http_version": "1.1",
"method": self.event["httpMethod"],
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": strip_api_gateway_path(
"path": strip_base_path(
self.event["path"],
api_gateway_base_path=self.config["api_gateway_base_path"],
base_path=self.config["base_path"],
),
"raw_path": None,
"root_path": "",
Expand Down Expand Up @@ -175,9 +175,9 @@ def scope(self) -> Scope:
http_method = self.event["httpMethod"]
query_string = _encode_query_string_for_apigw(self.event)

path = strip_api_gateway_path(
path = strip_base_path(
path,
api_gateway_base_path=self.config["api_gateway_base_path"],
base_path=self.config["base_path"],
)
server = get_server_and_port(headers)
client = (source_ip, 0)
Expand Down
12 changes: 6 additions & 6 deletions mangum/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def get_server_and_port(headers: dict) -> Tuple[str, int]:
return server


def strip_api_gateway_path(path: str, *, api_gateway_base_path: str) -> str:
def strip_base_path(path: str, *, base_path: str) -> str:
if not path:
return "/"

if api_gateway_base_path and api_gateway_base_path != "/":
if not api_gateway_base_path.startswith("/"):
api_gateway_base_path = f"/{api_gateway_base_path}"
if path.startswith(api_gateway_base_path):
path = path[len(api_gateway_base_path) :]
if base_path and base_path != "/":
if not base_path.startswith("/"):
base_path = f"/{base_path}"
if path.startswith(base_path):
path = path[len(base_path) :]

return unquote(path)

Expand Down
1 change: 0 additions & 1 deletion mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ async def send(self, message: Message) -> None:
self.state is HTTPCycleState.RESPONSE
and message["type"] == "http.response.body"
):

body = message.get("body", b"")
more_body = message.get("more_body", False)
self.buffer.write(body)
Expand Down
2 changes: 0 additions & 2 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,12 @@ async def run(self) -> None:
async def receive(self) -> Message:
"""Awaited by the application to receive ASGI `lifespan` events."""
if self.state is LifespanCycleState.CONNECTING:

# Connection established. The next event returned by the queue will be
# `lifespan.startup` to inform the application that the connection is
# ready to receive lfiespan messages.
self.state = LifespanCycleState.STARTUP

elif self.state is LifespanCycleState.STARTUP:

# Connection shutting down. The next event returned by the queue will be
# `lifespan.shutdown` to inform the application that the connection is now
# closing so that it may perform cleanup.
Expand Down
2 changes: 1 addition & 1 deletion mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class Response(TypedDict):


class LambdaConfig(TypedDict):
api_gateway_base_path: str
base_path: str
text_mime_types: List[str]
exclude_headers: List[str]

Expand Down
78 changes: 77 additions & 1 deletion tests/handlers/test_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
"""
import urllib
from typing import Dict, List, Optional

import pytest
Expand Down Expand Up @@ -199,7 +200,7 @@ def test_aws_alb_scope_real(
multi_value_headers,
)
example_context = {}
handler = ALB(event, example_context, {"api_gateway_base_path": "/"})
handler = ALB(event, example_context, {"base_path": "test"})

scope_path = path
if scope_path == "":
Expand Down Expand Up @@ -250,6 +251,81 @@ def test_aws_alb_scope_real(
assert handler.body == b""


@pytest.mark.parametrize(
"method,path,query_parameters,headers,req_body,body_base64_encoded,"
"query_string,scope_body,multi_value_headers",
[
("GET", "/test/hello/world", None, None, None, False, b"", None, False),
],
)
def test_aws_alb_base_path(
method,
path,
query_parameters,
headers,
req_body,
body_base64_encoded,
query_string,
scope_body,
multi_value_headers,
):
event = get_mock_aws_alb_event(
method,
path,
query_parameters,
headers,
req_body,
body_base64_encoded,
multi_value_headers,
)

async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(event["path"])
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

handler = Mangum(app, lifespan="off", base_path=None)
response = handler(event, {})

assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"isBase64Encoded": False,
"statusCode": 200,
}

async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(
event["path"][len(f"/{base_path}") :]
)
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

base_path = "test"
handler = Mangum(app, lifespan="off", base_path=base_path)
response = handler(event, {})
assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"isBase64Encoded": False,
"statusCode": 200,
}


@pytest.mark.parametrize("multi_value_headers_enabled", (True, False))
def test_aws_alb_set_cookies(multi_value_headers_enabled) -> None:
async def app(scope, receive, send):
Expand Down
73 changes: 71 additions & 2 deletions tests/handlers/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_aws_api_gateway_scope_basic():
"isBase64Encoded": False,
}
example_context = {}
handler = APIGateway(example_event, example_context, {"api_gateway_base_path": "/"})
handler = APIGateway(example_event, example_context, {"base_path": "/"})

assert type(handler.body) == bytes
assert handler.scope == {
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_aws_api_gateway_scope_real(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)
example_context = {}
handler = APIGateway(event, example_context, {"api_gateway_base_path": "/"})
handler = APIGateway(event, example_context, {"base_path": "/"})

scope_path = path
if scope_path == "":
Expand Down Expand Up @@ -249,6 +249,75 @@ def test_aws_api_gateway_scope_real(
assert handler.body == b""


@pytest.mark.parametrize(
"method,path,multi_value_query_parameters,req_body,body_base64_encoded,"
"query_string,scope_body",
[
("GET", "/test/hello", None, None, False, b"", None),
],
)
def test_aws_base_path(
method,
path,
multi_value_query_parameters,
req_body,
body_base64_encoded,
query_string,
scope_body,
):
event = get_mock_aws_api_gateway_event(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)

async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(event["path"])
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

handler = Mangum(app, lifespan="off", base_path=None)
response = handler(event, {})

assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"multiValueHeaders": {},
"isBase64Encoded": False,
"statusCode": 200,
}

async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(
event["path"][len(f"/{base_path}") :]
)
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

base_path = "test"
handler = Mangum(app, lifespan="off", base_path=base_path)
response = handler(event, {})
assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"multiValueHeaders": {},
"isBase64Encoded": False,
"statusCode": 200,
}


@pytest.mark.parametrize(
"method,path,multi_value_query_parameters,req_body,body_base64_encoded,"
"query_string,scope_body",
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __call__(self, *, status: int, headers: Headers, body: bytes) -> dict:

def test_custom_handler():
event = {"my-custom-key": 1}
handler = CustomHandler(event, {}, {"api_gateway_base_path": "/"})
handler = CustomHandler(event, {}, {"base_path": "/"})
assert type(handler.body) == bytes
assert handler.scope == {
"asgi": {"version": "3.0", "spec_version": "2.0"},
Expand Down
Loading