Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow base path config for ALB handler #291

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self,
app: ASGI,
lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
jordaneremieff marked this conversation as resolved.
Show resolved Hide resolved
base_path: str = "/",
custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
text_mime_types: Optional[List[str]] = None,
exclude_headers: Optional[List[str]] = None,
Expand All @@ -56,7 +56,7 @@ def __init__(
self.custom_handlers = custom_handlers or []
exclude_headers = exclude_headers or []
self.config = LambdaConfig(
api_gateway_base_path=api_gateway_base_path or "/",
base_path=base_path or "/",
text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES],
exclude_headers=[header.lower() for header in exclude_headers],
)
Expand Down
6 changes: 5 additions & 1 deletion mangum/handlers/alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
handle_base64_response_body,
handle_exclude_headers,
maybe_encode_body,
strip_base_path,
)
from mangum.types import (
Response,
Expand Down Expand Up @@ -130,7 +131,10 @@ def scope(self) -> Scope:
"method": http_method,
"http_version": "1.1",
"headers": list_headers,
"path": path,
"path": strip_base_path(
path,
base_path=self.config["base_path"],
),
"raw_path": None,
"root_path": "",
"scheme": uq_headers.get("x-forwarded-proto", "https"),
Expand Down
10 changes: 5 additions & 5 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
handle_exclude_headers,
handle_multi_value_headers,
maybe_encode_body,
strip_api_gateway_path,
strip_base_path,
)
from mangum.types import (
Response,
Expand Down Expand Up @@ -93,9 +93,9 @@ def scope(self) -> Scope:
"http_version": "1.1",
"method": self.event["httpMethod"],
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": strip_api_gateway_path(
"path": strip_base_path(
self.event["path"],
api_gateway_base_path=self.config["api_gateway_base_path"],
base_path=self.config["base_path"],
),
"raw_path": None,
"root_path": "",
Expand Down Expand Up @@ -175,9 +175,9 @@ def scope(self) -> Scope:
http_method = self.event["httpMethod"]
query_string = _encode_query_string_for_apigw(self.event)

path = strip_api_gateway_path(
path = strip_base_path(
path,
api_gateway_base_path=self.config["api_gateway_base_path"],
base_path=self.config["base_path"],
)
server = get_server_and_port(headers)
client = (source_ip, 0)
Expand Down
12 changes: 6 additions & 6 deletions mangum/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def get_server_and_port(headers: dict) -> Tuple[str, int]:
return server


def strip_api_gateway_path(path: str, *, api_gateway_base_path: str) -> str:
def strip_base_path(path: str, *, base_path: str) -> str:
if not path:
return "/"

if api_gateway_base_path and api_gateway_base_path != "/":
if not api_gateway_base_path.startswith("/"):
api_gateway_base_path = f"/{api_gateway_base_path}"
if path.startswith(api_gateway_base_path):
path = path[len(api_gateway_base_path) :]
if base_path and base_path != "/":
if not base_path.startswith("/"):
base_path = f"/{base_path}"
if path.startswith(base_path):
path = path[len(base_path) :]

return unquote(path)

Expand Down
2 changes: 1 addition & 1 deletion mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class Response(TypedDict):


class LambdaConfig(TypedDict):
api_gateway_base_path: str
base_path: str
text_mime_types: List[str]
exclude_headers: List[str]

Expand Down
78 changes: 77 additions & 1 deletion tests/handlers/test_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
"""
import urllib
from typing import Dict, List, Optional

import pytest
Expand Down Expand Up @@ -199,7 +200,7 @@ def test_aws_alb_scope_real(
multi_value_headers,
)
example_context = {}
handler = ALB(event, example_context, {"api_gateway_base_path": "/"})
handler = ALB(event, example_context, {"base_path": "test"})

scope_path = path
if scope_path == "":
Expand Down Expand Up @@ -250,6 +251,81 @@ def test_aws_alb_scope_real(
assert handler.body == b""


@pytest.mark.parametrize(
"method,path,query_parameters,headers,req_body,body_base64_encoded,"
"query_string,scope_body,multi_value_headers",
[
("GET", "/test/hello/world", None, None, None, False, b"", None, False),
],
)
def test_aws_alb_base_path(
method,
path,
query_parameters,
headers,
req_body,
body_base64_encoded,
query_string,
scope_body,
multi_value_headers,
):
event = get_mock_aws_alb_event(
method,
path,
query_parameters,
headers,
req_body,
body_base64_encoded,
multi_value_headers,
)

async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(event["path"])
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

handler = Mangum(app, lifespan="off", base_path=None)
response = handler(event, {})

assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"isBase64Encoded": False,
"statusCode": 200,
}

async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(
event["path"][len(f"/{base_path}") :]
)
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello world!"})

base_path = "test"
handler = Mangum(app, lifespan="off", base_path=base_path)
response = handler(event, {})
assert response == {
"body": "Hello world!",
"headers": {"content-type": "text/plain"},
"isBase64Encoded": False,
"statusCode": 200,
}


@pytest.mark.parametrize("multi_value_headers_enabled", (True, False))
def test_aws_alb_set_cookies(multi_value_headers_enabled) -> None:
async def app(scope, receive, send):
Expand Down
14 changes: 7 additions & 7 deletions tests/handlers/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_aws_api_gateway_scope_basic():
"isBase64Encoded": False,
}
example_context = {}
handler = APIGateway(example_event, example_context, {"api_gateway_base_path": "/"})
handler = APIGateway(example_event, example_context, {"base_path": "/"})

assert type(handler.body) == bytes
assert handler.scope == {
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_aws_api_gateway_scope_real(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)
example_context = {}
handler = APIGateway(event, example_context, {"api_gateway_base_path": "/"})
handler = APIGateway(event, example_context, {"base_path": "/"})

scope_path = path
if scope_path == "":
Expand Down Expand Up @@ -256,7 +256,7 @@ def test_aws_api_gateway_scope_real(
("GET", "/test/hello", None, None, False, b"", None),
],
)
def test_aws_api_gateway_base_path(
def test_aws_base_path(
method,
path,
multi_value_query_parameters,
Expand All @@ -281,7 +281,7 @@ async def app(scope, receive, send):
)
await send({"type": "http.response.body", "body": b"Hello world!"})

handler = Mangum(app, lifespan="off", api_gateway_base_path=None)
handler = Mangum(app, lifespan="off", base_path=None)
response = handler(event, {})

assert response == {
Expand All @@ -295,7 +295,7 @@ async def app(scope, receive, send):
async def app(scope, receive, send):
assert scope["type"] == "http"
assert scope["path"] == urllib.parse.unquote(
event["path"][len(f"/{api_gateway_base_path}") :]
event["path"][len(f"/{base_path}") :]
)
await send(
{
Expand All @@ -306,8 +306,8 @@ async def app(scope, receive, send):
)
await send({"type": "http.response.body", "body": b"Hello world!"})

api_gateway_base_path = "test"
handler = Mangum(app, lifespan="off", api_gateway_base_path=api_gateway_base_path)
base_path = "test"
handler = Mangum(app, lifespan="off", base_path=base_path)
response = handler(event, {})
assert response == {
"body": "Hello world!",
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __call__(self, *, status: int, headers: Headers, body: bytes) -> dict:

def test_custom_handler():
event = {"my-custom-key": 1}
handler = CustomHandler(event, {}, {"api_gateway_base_path": "/"})
handler = CustomHandler(event, {}, {"base_path": "/"})
assert type(handler.body) == bytes
assert handler.scope == {
"asgi": {"version": "3.0", "spec_version": "2.0"},
Expand Down
12 changes: 6 additions & 6 deletions tests/handlers/test_http_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_aws_http_gateway_scope_basic_v1():

example_context = {}
handler = HTTPGateway(
example_event, example_context, {"api_gateway_base_path": "/"}
example_event, example_context, {"base_path": "/"}
)

assert type(handler.body) == bytes
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_aws_http_gateway_scope_v1_only_non_multi_headers():
del example_event["multiValueQueryStringParameters"]
example_context = {}
handler = HTTPGateway(
example_event, example_context, {"api_gateway_base_path": "/"}
example_event, example_context, {"base_path": "/"}
)
assert handler.scope["query_string"] == b"hello=world"

Expand All @@ -245,7 +245,7 @@ def test_aws_http_gateway_scope_v1_no_headers():
del example_event["queryStringParameters"]
example_context = {}
handler = HTTPGateway(
example_event, example_context, {"api_gateway_base_path": "/"}
example_event, example_context, {"base_path": "/"}
)
assert handler.scope["query_string"] == b""

Expand Down Expand Up @@ -305,7 +305,7 @@ def test_aws_http_gateway_scope_basic_v2():
}
example_context = {}
handler = HTTPGateway(
example_event, example_context, {"api_gateway_base_path": "/"}
example_event, example_context, {"base_path": "/"}
)

assert type(handler.body) == bytes
Expand Down Expand Up @@ -363,7 +363,7 @@ def test_aws_http_gateway_scope_real_v1(
method, path, query_parameters, req_body, body_base64_encoded
)
example_context = {}
handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"})
handler = HTTPGateway(event, example_context, {"base_path": "/"})

scope_path = path
if scope_path == "":
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_aws_http_gateway_scope_real_v2(
method, path, query_parameters, req_body, body_base64_encoded
)
example_context = {}
handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"})
handler = HTTPGateway(event, example_context, {"base_path": "/"})

scope_path = path
if scope_path == "":
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_lambda_at_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_aws_cf_lambda_at_edge_scope_basic():
}
example_context = {}
handler = LambdaAtEdge(
example_event, example_context, {"api_gateway_base_path": "/"}
example_event, example_context, {"base_path": "/"}
)

assert type(handler.body) == bytes
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_aws_api_gateway_scope_real(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)
example_context = {}
handler = LambdaAtEdge(event, example_context, {"api_gateway_base_path": "/"})
handler = LambdaAtEdge(event, example_context, {"base_path": "/"})

assert handler.scope == {
"asgi": {"version": "3.0", "spec_version": "2.0"},
Expand Down
2 changes: 1 addition & 1 deletion tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async def app(scope, receive, send):
def test_default_settings():
handler = Mangum(app)
assert handler.lifespan == "auto"
assert handler.config["api_gateway_base_path"] == "/"
assert handler.config["base_path"] == "/"
assert sorted(handler.config["text_mime_types"]) == sorted(DEFAULT_TEXT_MIME_TYPES)
assert handler.config["exclude_headers"] == []

Expand Down