diff --git a/.gitignore b/.gitignore index 2ce352d..563750a 100644 --- a/.gitignore +++ b/.gitignore @@ -105,5 +105,7 @@ venv.bak/ # IDE Settings .idea/ +.vscode +.devcontainer .DS_Store diff --git a/docs/adapter.md b/docs/adapter.md index 6801fb2..04b179e 100644 --- a/docs/adapter.md +++ b/docs/adapter.md @@ -7,6 +7,8 @@ handler = Mangum( app, lifespan="auto", api_gateway_base_path=None, + custom_handlers=None, + text_mime_types=None, ) ``` diff --git a/mangum/adapter.py b/mangum/adapter.py index 7a93735..31d2d1f 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -26,6 +26,15 @@ LambdaAtEdge, ] +DEFAULT_TEXT_MIME_TYPES: List[str] = [ + "text/", + "application/json", + "application/javascript", + "application/xml", + "application/vnd.api+json", + "application/vnd.oai.openapi", +] + class Mangum: def __init__( @@ -34,6 +43,7 @@ def __init__( lifespan: LifespanMode = "auto", api_gateway_base_path: str = "/", custom_handlers: Optional[List[Type[LambdaHandler]]] = None, + text_mime_types: Optional[List[str]] = None, ) -> None: if lifespan not in ("auto", "on", "off"): raise ConfigurationError( @@ -42,24 +52,22 @@ def __init__( self.app = app self.lifespan = lifespan - self.api_gateway_base_path = api_gateway_base_path or "/" - self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path) self.custom_handlers = custom_handlers or [] + self.config = LambdaConfig( + api_gateway_base_path=api_gateway_base_path or "/", + text_mime_types=text_mime_types or [*DEFAULT_TEXT_MIME_TYPES], + ) def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler: for handler_cls in chain(self.custom_handlers, HANDLERS): if handler_cls.infer(event, context, self.config): - handler = handler_cls(event, context, self.config) - break - else: - raise RuntimeError( # pragma: no cover - "The adapter was unable to infer a handler to use for the event. This " - "is likely related to how the Lambda function was invoked. (Are you " - "testing locally? Make sure the request payload is valid for a " - "supported handler.)" - ) - - return handler + return handler_cls(event, context, self.config) + raise RuntimeError( # pragma: no cover + "The adapter was unable to infer a handler to use for the event. This " + "is likely related to how the Lambda function was invoked. (Are you " + "testing locally? Make sure the request payload is valid for a " + "supported handler.)" + ) def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict: handler = self.infer(event, context) diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py index 02ef0a9..41378ed 100644 --- a/mangum/handlers/alb.py +++ b/mangum/handlers/alb.py @@ -153,7 +153,7 @@ def __call__(self, response: Response) -> dict: finalized_headers = case_mutated_headers(multi_value_headers) finalized_body, is_base64_encoded = handle_base64_response_body( - response["body"], finalized_headers + response["body"], finalized_headers, self.config["text_mime_types"] ) out = { diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py index 9bca9e2..bd58a7d 100644 --- a/mangum/handlers/api_gateway.py +++ b/mangum/handlers/api_gateway.py @@ -115,7 +115,7 @@ def __call__(self, response: Response) -> dict: response["headers"] ) finalized_body, is_base64_encoded = handle_base64_response_body( - response["body"], finalized_headers + response["body"], finalized_headers, self.config["text_mime_types"] ) return { @@ -204,7 +204,7 @@ def __call__(self, response: Response) -> dict: finalized_headers["content-type"] = "application/json" finalized_body, is_base64_encoded = handle_base64_response_body( - response["body"], finalized_headers + response["body"], finalized_headers, self.config["text_mime_types"] ) response_out = { "statusCode": response["status"], @@ -221,7 +221,7 @@ def __call__(self, response: Response) -> dict: response["headers"] ) finalized_body, is_base64_encoded = handle_base64_response_body( - response["body"], finalized_headers + response["body"], finalized_headers, self.config["text_mime_types"] ) return { "statusCode": response["status"], diff --git a/mangum/handlers/lambda_at_edge.py b/mangum/handlers/lambda_at_edge.py index 6d307f0..6737967 100644 --- a/mangum/handlers/lambda_at_edge.py +++ b/mangum/handlers/lambda_at_edge.py @@ -79,7 +79,7 @@ def scope(self) -> Scope: def __call__(self, response: Response) -> dict: multi_value_headers, _ = handle_multi_value_headers(response["headers"]) response_body, is_base64_encoded = handle_base64_response_body( - response["body"], multi_value_headers + response["body"], multi_value_headers, self.config["text_mime_types"] ) finalized_headers: Dict[str, List[Dict[str, str]]] = { key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}] diff --git a/mangum/handlers/utils.py b/mangum/handlers/utils.py index 91b84c1..c1cce0b 100644 --- a/mangum/handlers/utils.py +++ b/mangum/handlers/utils.py @@ -5,16 +5,6 @@ from mangum.types import Headers -DEFAULT_TEXT_MIME_TYPES = [ - "text/", - "application/json", - "application/javascript", - "application/xml", - "application/vnd.api+json", - "application/vnd.oai.openapi", -] - - def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes: body = body or b"" if is_base64: @@ -71,12 +61,14 @@ def handle_multi_value_headers( def handle_base64_response_body( - body: bytes, headers: Dict[str, str] + body: bytes, + headers: Dict[str, str], + text_mime_types: List[str], ) -> Tuple[str, bool]: is_base64_encoded = False output_body = "" if body != b"": - for text_mime_type in DEFAULT_TEXT_MIME_TYPES: + for text_mime_type in text_mime_types: if text_mime_type in headers.get("content-type", ""): try: output_body = body.decode() diff --git a/mangum/types.py b/mangum/types.py index 20e8095..b50b0b2 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -116,6 +116,7 @@ class Response(TypedDict): class LambdaConfig(TypedDict): api_gateway_base_path: str + text_mime_types: List[str] class LambdaHandler(Protocol): diff --git a/tests/handlers/test_alb.py b/tests/handlers/test_alb.py index 6213088..3804f9d 100644 --- a/tests/handlers/test_alb.py +++ b/tests/handlers/test_alb.py @@ -331,3 +331,44 @@ async def app(scope, receive, send): "headers": {"content-type": content_type.decode()}, "body": res_body, } + + +def test_aws_alb_response_extra_mime_types(): + content_type = b"application/x-yaml" + utf_res_body = "name: 'John Doe'" + raw_res_body = utf_res_body.encode() + b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw==" + + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", content_type]], + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_alb_event("GET", "/test", {}, None, None, False, False) + + # Test default behavior + handler = Mangum(app, lifespan="off") + response = handler(event, {}) + assert content_type.decode() not in handler.config["text_mime_types"] + assert response == { + "statusCode": 200, + "isBase64Encoded": True, + "headers": {"content-type": content_type.decode()}, + "body": b64_res_body, + } + + # Test with modified text mime types + handler = Mangum(app, lifespan="off") + handler.config["text_mime_types"].append(content_type.decode()) + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": content_type.decode()}, + "body": utf_res_body, + } diff --git a/tests/handlers/test_api_gateway.py b/tests/handlers/test_api_gateway.py index 4c7cf0a..1231bb0 100644 --- a/tests/handlers/test_api_gateway.py +++ b/tests/handlers/test_api_gateway.py @@ -358,3 +358,46 @@ async def app(scope, receive, send): "multiValueHeaders": {}, "body": res_body, } + + +def test_aws_api_gateway_response_extra_mime_types(): + content_type = b"application/x-yaml" + utf_res_body = "name: 'John Doe'" + raw_res_body = utf_res_body.encode() + b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw==" + + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", content_type]], + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_api_gateway_event("POST", "/test", {}, None, False) + + # Test default behavior + handler = Mangum(app, lifespan="off") + response = handler(event, {}) + assert content_type.decode() not in handler.config["text_mime_types"] + assert response == { + "statusCode": 200, + "isBase64Encoded": True, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": b64_res_body, + } + + # Test with modified text mime types + handler = Mangum(app, lifespan="off") + handler.config["text_mime_types"].append(content_type.decode()) + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": utf_res_body, + } diff --git a/tests/handlers/test_http_gateway.py b/tests/handlers/test_http_gateway.py index df42835..6549042 100644 --- a/tests/handlers/test_http_gateway.py +++ b/tests/handlers/test_http_gateway.py @@ -595,3 +595,95 @@ async def app(scope, receive, send): "headers": {"content-type": content_type.decode()}, "body": res_body, } + + +def test_aws_http_gateway_response_v1_extra_mime_types(): + content_type = b"application/x-yaml" + utf_res_body = "name: 'John Doe'" + raw_res_body = utf_res_body.encode() + b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw==" + + async def app(scope, receive, send): + headers = [] + if content_type is not None: + headers.append([b"content-type", content_type]) + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": headers, + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_http_gateway_event_v1("POST", "/test", {}, None, False) + + # Test default behavior + handler = Mangum(app, lifespan="off") + response = handler(event, {}) + assert content_type.decode() not in handler.config["text_mime_types"] + assert response == { + "statusCode": 200, + "isBase64Encoded": True, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": b64_res_body, + } + + # Test with modified text mime types + handler = Mangum(app, lifespan="off") + handler.config["text_mime_types"].append(content_type.decode()) + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": content_type.decode()}, + "multiValueHeaders": {}, + "body": utf_res_body, + } + + +def test_aws_http_gateway_response_v2_extra_mime_types(): + content_type = b"application/x-yaml" + utf_res_body = "name: 'John Doe'" + raw_res_body = utf_res_body.encode() + b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw==" + + async def app(scope, receive, send): + headers = [] + if content_type is not None: + headers.append([b"content-type", content_type]) + + await send( + { + "type": "http.response.start", + "status": 200, + "headers": headers, + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = get_mock_aws_http_gateway_event_v2("POST", "/test", {}, None, False) + + # Test default behavior + handler = Mangum(app, lifespan="off") + response = handler(event, {}) + assert content_type.decode() not in handler.config["text_mime_types"] + assert response == { + "statusCode": 200, + "isBase64Encoded": True, + "headers": {"content-type": content_type.decode()}, + "body": b64_res_body, + } + + # Test with modified text mime types + handler = Mangum(app, lifespan="off") + handler.config["text_mime_types"].append(content_type.decode()) + response = handler(event, {}) + assert response == { + "statusCode": 200, + "isBase64Encoded": False, + "headers": {"content-type": content_type.decode()}, + "body": utf_res_body, + } diff --git a/tests/handlers/test_lambda_at_edge.py b/tests/handlers/test_lambda_at_edge.py index 47a53f4..ffeb9bc 100644 --- a/tests/handlers/test_lambda_at_edge.py +++ b/tests/handlers/test_lambda_at_edge.py @@ -297,3 +297,48 @@ async def app(scope, receive, send): }, "body": res_body, } + + +def test_aws_lambda_at_edge_response_extra_mime_types(): + content_type = b"application/x-yaml" + utf_res_body = "name: 'John Doe'" + raw_res_body = utf_res_body.encode() + b64_res_body = "bmFtZTogJ0pvaG4gRG9lJw==" + + async def app(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", content_type]], + } + ) + await send({"type": "http.response.body", "body": raw_res_body}) + + event = mock_lambda_at_edge_event("POST", "/test", {}, None, False) + + # Test default behavior + handler = Mangum(app, lifespan="off") + response = handler(event, {}) + assert content_type.decode() not in handler.config["text_mime_types"] + assert response == { + "status": 200, + "isBase64Encoded": True, + "headers": { + "content-type": [{"key": "content-type", "value": content_type.decode()}] + }, + "body": b64_res_body, + } + + # Test with modified text mime types + handler = Mangum(app, lifespan="off") + handler.config["text_mime_types"].append(content_type.decode()) + response = handler(event, {}) + assert response == { + "status": 200, + "isBase64Encoded": False, + "headers": { + "content-type": [{"key": "content-type", "value": content_type.decode()}] + }, + "body": utf_res_body, + } diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 6a014b3..6b50fd6 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,7 +1,8 @@ import pytest -from mangum.exceptions import ConfigurationError from mangum import Mangum +from mangum.adapter import DEFAULT_TEXT_MIME_TYPES +from mangum.exceptions import ConfigurationError async def app(scope, receive, send): @@ -11,7 +12,8 @@ async def app(scope, receive, send): def test_default_settings(): handler = Mangum(app) assert handler.lifespan == "auto" - assert handler.api_gateway_base_path == "/" + assert handler.config["api_gateway_base_path"] == "/" + assert sorted(handler.config["text_mime_types"]) == sorted(DEFAULT_TEXT_MIME_TYPES) @pytest.mark.parametrize(