diff --git a/docs/adapter.md b/docs/adapter.md index 04b179e..8fa466f 100644 --- a/docs/adapter.md +++ b/docs/adapter.md @@ -6,7 +6,7 @@ The heart of Mangum is the adapter class. It is a configurable wrapper that allo handler = Mangum( app, lifespan="auto", - api_gateway_base_path=None, + base_path=None, custom_handlers=None, text_mime_types=None, ) diff --git a/mangum/adapter.py b/mangum/adapter.py index bb99cfb..97124ae 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -1,4 +1,5 @@ import logging +import warnings from itertools import chain from contextlib import ExitStack from typing import List, Optional, Type @@ -42,6 +43,7 @@ def __init__( app: ASGI, lifespan: LifespanMode = "auto", api_gateway_base_path: str = "/", + base_path: str = "/", custom_handlers: Optional[List[Type[LambdaHandler]]] = None, text_mime_types: Optional[List[str]] = None, exclude_headers: Optional[List[str]] = None, @@ -55,8 +57,16 @@ def __init__( self.lifespan = lifespan self.custom_handlers = custom_handlers or [] exclude_headers = exclude_headers or [] + if api_gateway_base_path and api_gateway_base_path != "/": + warnings.warn( + "`api_gateway_base_path` parameter is deprecated and will be " + "removed in future versions. Please use `base_path` instead.", + DeprecationWarning, + ) + base_path = api_gateway_base_path + self.config = LambdaConfig( - api_gateway_base_path=api_gateway_base_path or "/", + base_path=base_path or "/", text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES], exclude_headers=[header.lower() for header in exclude_headers], ) diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py index 875c4ee..95c5a3a 100644 --- a/mangum/handlers/alb.py +++ b/mangum/handlers/alb.py @@ -7,6 +7,7 @@ handle_base64_response_body, handle_exclude_headers, maybe_encode_body, + strip_base_path, ) from mangum.types import ( Response, @@ -104,7 +105,6 @@ def body(self) -> bytes: @property def scope(self) -> Scope: - headers = transform_headers(self.event) list_headers = [list(x) for x in headers] # Unique headers. If there are duplicates, it will use the last defined. @@ -130,7 +130,10 @@ def scope(self) -> Scope: "method": http_method, "http_version": "1.1", "headers": list_headers, - "path": path, + "path": strip_base_path( + path, + base_path=self.config["base_path"], + ), "raw_path": None, "root_path": "", "scheme": uq_headers.get("x-forwarded-proto", "https"), diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py index d9b30c0..989a86d 100644 --- a/mangum/handlers/api_gateway.py +++ b/mangum/handlers/api_gateway.py @@ -7,7 +7,7 @@ handle_exclude_headers, handle_multi_value_headers, maybe_encode_body, - strip_api_gateway_path, + strip_base_path, ) from mangum.types import ( Response, @@ -93,9 +93,9 @@ def scope(self) -> Scope: "http_version": "1.1", "method": self.event["httpMethod"], "headers": [[k.encode(), v.encode()] for k, v in headers.items()], - "path": strip_api_gateway_path( + "path": strip_base_path( self.event["path"], - api_gateway_base_path=self.config["api_gateway_base_path"], + base_path=self.config["base_path"], ), "raw_path": None, "root_path": "", @@ -175,9 +175,9 @@ def scope(self) -> Scope: http_method = self.event["httpMethod"] query_string = _encode_query_string_for_apigw(self.event) - path = strip_api_gateway_path( + path = strip_base_path( path, - api_gateway_base_path=self.config["api_gateway_base_path"], + base_path=self.config["base_path"], ) server = get_server_and_port(headers) client = (source_ip, 0) diff --git a/mangum/handlers/utils.py b/mangum/handlers/utils.py index 7e3e7b3..3deed1f 100644 --- a/mangum/handlers/utils.py +++ b/mangum/handlers/utils.py @@ -26,15 +26,15 @@ def get_server_and_port(headers: dict) -> Tuple[str, int]: return server -def strip_api_gateway_path(path: str, *, api_gateway_base_path: str) -> str: +def strip_base_path(path: str, *, base_path: str) -> str: if not path: return "/" - if api_gateway_base_path and api_gateway_base_path != "/": - if not api_gateway_base_path.startswith("/"): - api_gateway_base_path = f"/{api_gateway_base_path}" - if path.startswith(api_gateway_base_path): - path = path[len(api_gateway_base_path) :] + if base_path and base_path != "/": + if not base_path.startswith("/"): + base_path = f"/{base_path}" + if path.startswith(base_path): + path = path[len(base_path) :] return unquote(path) diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index b43b11b..942bb24 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -93,7 +93,6 @@ async def send(self, message: Message) -> None: self.state is HTTPCycleState.RESPONSE and message["type"] == "http.response.body" ): - body = message.get("body", b"") more_body = message.get("more_body", False) self.buffer.write(body) diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index ca87392..5b27b0b 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -98,14 +98,12 @@ async def run(self) -> None: async def receive(self) -> Message: """Awaited by the application to receive ASGI `lifespan` events.""" if self.state is LifespanCycleState.CONNECTING: - # Connection established. The next event returned by the queue will be # `lifespan.startup` to inform the application that the connection is # ready to receive lfiespan messages. self.state = LifespanCycleState.STARTUP elif self.state is LifespanCycleState.STARTUP: - # Connection shutting down. The next event returned by the queue will be # `lifespan.shutdown` to inform the application that the connection is now # closing so that it may perform cleanup. diff --git a/mangum/types.py b/mangum/types.py index 0ff436c..9796149 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -115,7 +115,7 @@ class Response(TypedDict): class LambdaConfig(TypedDict): - api_gateway_base_path: str + base_path: str text_mime_types: List[str] exclude_headers: List[str] diff --git a/tests/handlers/test_alb.py b/tests/handlers/test_alb.py index e75d2d9..87d24cb 100644 --- a/tests/handlers/test_alb.py +++ b/tests/handlers/test_alb.py @@ -3,6 +3,7 @@ 1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html 2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501 """ +import urllib from typing import Dict, List, Optional import pytest @@ -199,7 +200,7 @@ def test_aws_alb_scope_real( multi_value_headers, ) example_context = {} - handler = ALB(event, example_context, {"api_gateway_base_path": "/"}) + handler = ALB(event, example_context, {"base_path": "test"}) scope_path = path if scope_path == "": @@ -250,6 +251,81 @@ def test_aws_alb_scope_real( assert handler.body == b"" +@pytest.mark.parametrize( + "method,path,query_parameters,headers,req_body,body_base64_encoded," + "query_string,scope_body,multi_value_headers", + [ + ("GET", "/test/hello/world", None, None, None, False, b"", None, False), + ], +) +def test_aws_alb_base_path( + method, + path, + query_parameters, + headers, + req_body, + body_base64_encoded, + query_string, + scope_body, + multi_value_headers, +): + event = get_mock_aws_alb_event( + method, + path, + query_parameters, + headers, + req_body, + body_base64_encoded, + multi_value_headers, + ) + + async def app(scope, receive, send): + assert scope["type"] == "http" + assert scope["path"] == urllib.parse.unquote(event["path"]) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + handler = Mangum(app, lifespan="off", base_path=None) + response = handler(event, {}) + + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "isBase64Encoded": False, + "statusCode": 200, + } + + async def app(scope, receive, send): + assert scope["type"] == "http" + assert scope["path"] == urllib.parse.unquote( + event["path"][len(f"/{base_path}") :] + ) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + base_path = "test" + handler = Mangum(app, lifespan="off", base_path=base_path) + response = handler(event, {}) + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "isBase64Encoded": False, + "statusCode": 200, + } + + @pytest.mark.parametrize("multi_value_headers_enabled", (True, False)) def test_aws_alb_set_cookies(multi_value_headers_enabled) -> None: async def app(scope, receive, send): diff --git a/tests/handlers/test_api_gateway.py b/tests/handlers/test_api_gateway.py index e2458c2..d8d6724 100644 --- a/tests/handlers/test_api_gateway.py +++ b/tests/handlers/test_api_gateway.py @@ -103,7 +103,7 @@ def test_aws_api_gateway_scope_basic(): "isBase64Encoded": False, } example_context = {} - handler = APIGateway(example_event, example_context, {"api_gateway_base_path": "/"}) + handler = APIGateway(example_event, example_context, {"base_path": "/"}) assert type(handler.body) == bytes assert handler.scope == { @@ -200,7 +200,7 @@ def test_aws_api_gateway_scope_real( method, path, multi_value_query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = APIGateway(event, example_context, {"api_gateway_base_path": "/"}) + handler = APIGateway(event, example_context, {"base_path": "/"}) scope_path = path if scope_path == "": @@ -249,6 +249,75 @@ def test_aws_api_gateway_scope_real( assert handler.body == b"" +@pytest.mark.parametrize( + "method,path,multi_value_query_parameters,req_body,body_base64_encoded," + "query_string,scope_body", + [ + ("GET", "/test/hello", None, None, False, b"", None), + ], +) +def test_aws_base_path( + method, + path, + multi_value_query_parameters, + req_body, + body_base64_encoded, + query_string, + scope_body, +): + event = get_mock_aws_api_gateway_event( + method, path, multi_value_query_parameters, req_body, body_base64_encoded + ) + + async def app(scope, receive, send): + assert scope["type"] == "http" + assert scope["path"] == urllib.parse.unquote(event["path"]) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + handler = Mangum(app, lifespan="off", base_path=None) + response = handler(event, {}) + + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "multiValueHeaders": {}, + "isBase64Encoded": False, + "statusCode": 200, + } + + async def app(scope, receive, send): + assert scope["type"] == "http" + assert scope["path"] == urllib.parse.unquote( + event["path"][len(f"/{base_path}") :] + ) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello world!"}) + + base_path = "test" + handler = Mangum(app, lifespan="off", base_path=base_path) + response = handler(event, {}) + assert response == { + "body": "Hello world!", + "headers": {"content-type": "text/plain"}, + "multiValueHeaders": {}, + "isBase64Encoded": False, + "statusCode": 200, + } + + @pytest.mark.parametrize( "method,path,multi_value_query_parameters,req_body,body_base64_encoded," "query_string,scope_body", diff --git a/tests/handlers/test_custom.py b/tests/handlers/test_custom.py index c330bb6..129f7d0 100644 --- a/tests/handlers/test_custom.py +++ b/tests/handlers/test_custom.py @@ -51,7 +51,7 @@ def __call__(self, *, status: int, headers: Headers, body: bytes) -> dict: def test_custom_handler(): event = {"my-custom-key": 1} - handler = CustomHandler(event, {}, {"api_gateway_base_path": "/"}) + handler = CustomHandler(event, {}, {"base_path": "/"}) assert type(handler.body) == bytes assert handler.scope == { "asgi": {"version": "3.0", "spec_version": "2.0"}, diff --git a/tests/handlers/test_http_gateway.py b/tests/handlers/test_http_gateway.py index 6549042..870a21d 100644 --- a/tests/handlers/test_http_gateway.py +++ b/tests/handlers/test_http_gateway.py @@ -195,9 +195,7 @@ def test_aws_http_gateway_scope_basic_v1(): } example_context = {} - handler = HTTPGateway( - example_event, example_context, {"api_gateway_base_path": "/"} - ) + handler = HTTPGateway(example_event, example_context, {"base_path": "/"}) assert type(handler.body) == bytes assert handler.scope == { @@ -228,9 +226,7 @@ def test_aws_http_gateway_scope_v1_only_non_multi_headers(): ) del example_event["multiValueQueryStringParameters"] example_context = {} - handler = HTTPGateway( - example_event, example_context, {"api_gateway_base_path": "/"} - ) + handler = HTTPGateway(example_event, example_context, {"base_path": "/"}) assert handler.scope["query_string"] == b"hello=world" @@ -244,9 +240,7 @@ def test_aws_http_gateway_scope_v1_no_headers(): del example_event["multiValueQueryStringParameters"] del example_event["queryStringParameters"] example_context = {} - handler = HTTPGateway( - example_event, example_context, {"api_gateway_base_path": "/"} - ) + handler = HTTPGateway(example_event, example_context, {"base_path": "/"}) assert handler.scope["query_string"] == b"" @@ -304,9 +298,7 @@ def test_aws_http_gateway_scope_basic_v2(): "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, } example_context = {} - handler = HTTPGateway( - example_event, example_context, {"api_gateway_base_path": "/"} - ) + handler = HTTPGateway(example_event, example_context, {"base_path": "/"}) assert type(handler.body) == bytes assert handler.scope == { @@ -363,7 +355,7 @@ def test_aws_http_gateway_scope_real_v1( method, path, query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"}) + handler = HTTPGateway(event, example_context, {"base_path": "/"}) scope_path = path if scope_path == "": @@ -429,7 +421,7 @@ def test_aws_http_gateway_scope_real_v2( method, path, query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"}) + handler = HTTPGateway(event, example_context, {"base_path": "/"}) scope_path = path if scope_path == "": diff --git a/tests/handlers/test_lambda_at_edge.py b/tests/handlers/test_lambda_at_edge.py index 563e144..a638855 100644 --- a/tests/handlers/test_lambda_at_edge.py +++ b/tests/handlers/test_lambda_at_edge.py @@ -134,9 +134,7 @@ def test_aws_cf_lambda_at_edge_scope_basic(): ] } example_context = {} - handler = LambdaAtEdge( - example_event, example_context, {"api_gateway_base_path": "/"} - ) + handler = LambdaAtEdge(example_event, example_context, {"base_path": "/"}) assert type(handler.body) == bytes assert handler.scope == { @@ -225,7 +223,7 @@ def test_aws_api_gateway_scope_real( method, path, multi_value_query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = LambdaAtEdge(event, example_context, {"api_gateway_base_path": "/"}) + handler = LambdaAtEdge(event, example_context, {"base_path": "/"}) assert handler.scope == { "asgi": {"version": "3.0", "spec_version": "2.0"}, diff --git a/tests/test_adapter.py b/tests/test_adapter.py index de36049..9ae9aae 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -12,7 +12,7 @@ async def app(scope, receive, send): def test_default_settings(): handler = Mangum(app) assert handler.lifespan == "auto" - assert handler.config["api_gateway_base_path"] == "/" + assert handler.config["base_path"] == "/" assert sorted(handler.config["text_mime_types"]) == sorted(DEFAULT_TEXT_MIME_TYPES) assert handler.config["exclude_headers"] == []