diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 19a1acdd..7d28d50b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,28 +11,28 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.6', '3.7', '3.8', '3.9', '3.10'] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - uses: actions/cache@v2 - name: Configure pip caching - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} - restore-keys: | - ${{ runner.os }}-pip- - - name: Install dependencies - run: | - pip install -U -r requirements.txt - - name: Run tests - run: | - scripts/test - - name: Run linters - run: | - scripts/lint - - name: Run codecov - run: codecov + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - uses: actions/cache@v2 + name: Configure pip caching + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install dependencies + run: | + pip install -U -r requirements.txt + - name: Run tests + run: | + scripts/test + - name: Run linters + run: | + scripts/lint + - name: Run codecov + run: codecov diff --git a/README.md b/README.md index e871acfb..af6ba7a5 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,13 @@ PyPI - Python Version -Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) applications with AWS Lambda & API Gateway. It is intended to provide an easy-to-use, configurable wrapper for any ASGI application deployed in an AWS Lambda function to handle API Gateway requests and responses. +Mangum is an adapter for running [ASGI](https://asgi.readthedocs.io/en/latest/) applications in AWS Lambda to handle API Gateway, ALB, and Lambda@Edge events. ***Documentation***: https://mangum.io/ ## Features -- API Gateway support for [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs. +- Event handlers for API Gateway [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs, [Application Load Balancer](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html), and [CloudFront Lambda@Edge](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html). - Compatibility with ASGI application frameworks, such as [Starlette](https://www.starlette.io/), [FastAPI](https://fastapi.tiangolo.com/), and [Quart](https://pgjones.gitlab.io/quart/). @@ -26,7 +26,7 @@ Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) ap ## Requirements -Python 3.6+ +Python 3.7+ ## Installation @@ -53,7 +53,7 @@ async def app(scope, receive, send): handler = Mangum(app) ``` -Or using a framework. +Or using a framework: ```python from fastapi import FastAPI diff --git a/docs/asgi-frameworks.md b/docs/asgi-frameworks.md index bc8385af..f24509ea 100644 --- a/docs/asgi-frameworks.md +++ b/docs/asgi-frameworks.md @@ -30,7 +30,7 @@ None of the framework details are important here. The routing decorator, request ```python class Application(Protocol): - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None: ... ``` diff --git a/docs/index.md b/docs/index.md index e871acfb..af6ba7a5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,13 +8,13 @@ PyPI - Python Version -Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) applications with AWS Lambda & API Gateway. It is intended to provide an easy-to-use, configurable wrapper for any ASGI application deployed in an AWS Lambda function to handle API Gateway requests and responses. +Mangum is an adapter for running [ASGI](https://asgi.readthedocs.io/en/latest/) applications in AWS Lambda to handle API Gateway, ALB, and Lambda@Edge events. ***Documentation***: https://mangum.io/ ## Features -- API Gateway support for [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs. +- Event handlers for API Gateway [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs, [Application Load Balancer](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html), and [CloudFront Lambda@Edge](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html). - Compatibility with ASGI application frameworks, such as [Starlette](https://www.starlette.io/), [FastAPI](https://fastapi.tiangolo.com/), and [Quart](https://pgjones.gitlab.io/quart/). @@ -26,7 +26,7 @@ Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) ap ## Requirements -Python 3.6+ +Python 3.7+ ## Installation @@ -53,7 +53,7 @@ async def app(scope, receive, send): handler = Mangum(app) ``` -Or using a framework. +Or using a framework: ```python from fastapi import FastAPI diff --git a/mangum/__init__.py b/mangum/__init__.py index a81cdfb0..fa5058b8 100644 --- a/mangum/__init__.py +++ b/mangum/__init__.py @@ -1,4 +1,3 @@ -from .types import Request, Response -from .adapter import Mangum # noqa: F401 +from mangum.adapter import Mangum -__all__ = ["Mangum", "Request", "Response"] +__all__ = ["Mangum"] diff --git a/mangum/adapter.py b/mangum/adapter.py index de1f9c6d..31a0b468 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -1,66 +1,92 @@ +from itertools import chain import logging from contextlib import ExitStack +from typing import List, Optional, Type +import warnings - -from mangum.exceptions import ConfigurationError -from mangum.handlers.abstract_handler import AbstractHandler from mangum.protocols import HTTPCycle, LifespanCycle -from mangum.types import ASGIApp, LambdaEvent, LambdaContext +from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge +from mangum.exceptions import ConfigurationError +from mangum.types import ( + ASGIApp, + LifespanMode, + LambdaConfig, + LambdaEvent, + LambdaContext, + LambdaHandler, +) -DEFAULT_TEXT_MIME_TYPES = [ - "text/", - "application/json", - "application/javascript", - "application/xml", - "application/vnd.api+json", -] +logger = logging.getLogger("mangum") -logger = logging.getLogger("mangum") +HANDLERS: List[Type[LambdaHandler]] = [ + ALB, + HTTPGateway, + APIGateway, + LambdaAtEdge, +] class Mangum: - """ - Creates an adapter instance. - - * **app** - An asynchronous callable that conforms to version 3.0 of the ASGI - specification. This will usually be an ASGI framework application instance. - * **lifespan** - A string to configure lifespan support. Choices are `auto`, `on`, - and `off`. Default is `auto`. - * **text_mime_types** - A list of MIME types to include with the defaults that - should not return a binary response in API Gateway. - * **api_gateway_base_path** - A string specifying the part of the url path after - which the server routing begins. - """ - def __init__( self, app: ASGIApp, - lifespan: str = "auto", + lifespan: LifespanMode = "auto", api_gateway_base_path: str = "/", + custom_handlers: Optional[List[Type[LambdaHandler]]] = None, ) -> None: + if lifespan not in ("auto", "on", "off"): + raise ConfigurationError( + "Invalid argument supplied for `lifespan`. Choices are: auto|on|off" + ) + self.app = app self.lifespan = lifespan - self.api_gateway_base_path = api_gateway_base_path + self.api_gateway_base_path = api_gateway_base_path or "/" + self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path) - if self.lifespan not in ("auto", "on", "off"): - raise ConfigurationError( - "Invalid argument supplied for `lifespan`. Choices are: auto|on|off" + if custom_handlers is not None: + warnings.warn( # pragma: no cover + "Support for custom event handlers is currently provisional and may " + "drastically change (or be removed entirely) in the future.", + FutureWarning, ) - def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict: - logger.debug("Event received.") + self.custom_handlers = custom_handlers or [] + + def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler: + for handler_cls in chain( + self.custom_handlers, + HANDLERS, + ): + handler = handler_cls.infer( + event, + context, + self.config, + ) + if handler: + break + else: + raise RuntimeError( # pragma: no cover + "The adapter was unable to infer a handler to use for the event. This " + "is likely related to how the Lambda function was invoked. (Are you " + "testing locally? Make sure the request payload is valid for a " + "supported handler.)" + ) + + return handler + def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict: + handler = self.infer(event, context) with ExitStack() as stack: - if self.lifespan != "off": + if self.lifespan in ("auto", "on"): lifespan_cycle = LifespanCycle(self.app, self.lifespan) stack.enter_context(lifespan_cycle) - handler = AbstractHandler.from_trigger( - event, context, self.api_gateway_base_path - ) - http_cycle = HTTPCycle(handler.request) - response = http_cycle(self.app, handler.body) + http_cycle = HTTPCycle(handler.scope, handler.body) + http_response = http_cycle(self.app) + + return handler(http_response) - return handler.transform_response(response) + assert False, "unreachable" # pragma: no cover diff --git a/mangum/handlers/__init__.py b/mangum/handlers/__init__.py index e69de29b..d92f8641 100644 --- a/mangum/handlers/__init__.py +++ b/mangum/handlers/__init__.py @@ -0,0 +1,6 @@ +from mangum.handlers.api_gateway import APIGateway, HTTPGateway +from mangum.handlers.alb import ALB +from mangum.handlers.lambda_at_edge import LambdaAtEdge + + +__all__ = ["APIGateway", "HTTPGateway", "ALB", "LambdaAtEdge"] diff --git a/mangum/handlers/abstract_handler.py b/mangum/handlers/abstract_handler.py deleted file mode 100644 index 666f633a..00000000 --- a/mangum/handlers/abstract_handler.py +++ /dev/null @@ -1,135 +0,0 @@ -import base64 -from abc import ABCMeta, abstractmethod -from typing import Dict, Any, Tuple, List - -from mangum.types import Response, Request, LambdaEvent, LambdaContext - - -class AbstractHandler(metaclass=ABCMeta): - def __init__( - self, - trigger_event: LambdaEvent, - trigger_context: LambdaContext, - ): - self.trigger_event = trigger_event - self.trigger_context = trigger_context - - @property - @abstractmethod - def request(self) -> Request: - """ - Parse an ASGI scope from the request event - """ - - @property - @abstractmethod - def body(self) -> bytes: - """ - Get the raw body from the request event - """ - - @abstractmethod - def transform_response(self, response: Response) -> Dict[str, Any]: - """ - After running our application, transform the response to the correct format for - this handler - """ - - @staticmethod - def from_trigger( - trigger_event: LambdaEvent, - trigger_context: LambdaContext, - api_gateway_base_path: str = "/", - ) -> "AbstractHandler": - """ - A factory method that determines which handler to use. All this code should - probably stay in one place to make sure we are able to uniquely find each - handler correctly. - """ - - # These should be ordered from most specific to least for best accuracy - if ( - "requestContext" in trigger_event - and "elb" in trigger_event["requestContext"] - ): - from mangum.handlers.aws_alb import AwsAlb - - return AwsAlb(trigger_event, trigger_context) - if ( - "Records" in trigger_event - and len(trigger_event["Records"]) > 0 - and "cf" in trigger_event["Records"][0] - ): - from mangum.handlers.aws_cf_lambda_at_edge import AwsCfLambdaAtEdge - - return AwsCfLambdaAtEdge(trigger_event, trigger_context) - - if "version" in trigger_event and "requestContext" in trigger_event: - from mangum.handlers.aws_http_gateway import AwsHttpGateway - - return AwsHttpGateway( - trigger_event, - trigger_context, - api_gateway_base_path, - ) - - if "resource" in trigger_event: - from mangum.handlers.aws_api_gateway import AwsApiGateway - - return AwsApiGateway( - trigger_event, - trigger_context, - api_gateway_base_path, - ) - - raise TypeError("Unable to determine handler from trigger event") - - @staticmethod - def _handle_multi_value_headers( - response_headers: List[List[bytes]], - ) -> Tuple[Dict[str, str], Dict[str, List[str]]]: - headers: Dict[str, str] = {} - multi_value_headers: Dict[str, List[str]] = {} - for key, value in response_headers: - lower_key = key.decode().lower() - if lower_key in multi_value_headers: - multi_value_headers[lower_key].append(value.decode()) - elif lower_key in headers: - # Move existing to multi_value_headers and append current - multi_value_headers[lower_key] = [ - headers[lower_key], - value.decode(), - ] - del headers[lower_key] - else: - headers[lower_key] = value.decode() - return headers, multi_value_headers - - @staticmethod - def _handle_base64_response_body( - body: bytes, headers: Dict[str, str] - ) -> Tuple[str, bool]: - """ - To ease debugging for our users, try and return strings where we can, - otherwise to ensure maximum compatibility with binary data, base64 encode it. - """ - is_base64_encoded = False - output_body = "" - if body != b"": - from ..adapter import DEFAULT_TEXT_MIME_TYPES - - for text_mime_type in DEFAULT_TEXT_MIME_TYPES: - if text_mime_type in headers.get("content-type", ""): - try: - output_body = body.decode() - except UnicodeDecodeError: - # Can't decode it, base64 it and be done - output_body = base64.b64encode(body).decode() - is_base64_encoded = True - break - else: - # Not text, base64 encode - output_body = base64.b64encode(body).decode() - is_base64_encoded = True - - return output_body, is_base64_encoded diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py new file mode 100644 index 00000000..23dc62b9 --- /dev/null +++ b/mangum/handlers/alb.py @@ -0,0 +1,178 @@ +from itertools import islice +from typing import Dict, Generator, List, Optional, Tuple +from urllib.parse import urlencode, unquote, unquote_plus + + +from mangum.handlers.utils import ( + get_server_and_port, + handle_base64_response_body, + maybe_encode_body, +) +from mangum.types import ( + HTTPResponse, + HTTPScope, + LambdaConfig, + LambdaEvent, + LambdaContext, + LambdaHandler, + QueryParams, +) + + +def all_casings(input_string: str) -> Generator[str, None, None]: + """ + Permute all casings of a given string. + A pretty algoritm, via @Amber + http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python + """ + if not input_string: + yield "" + else: + first = input_string[:1] + if first.lower() == first.upper(): + for sub_casing in all_casings(input_string[1:]): + yield first + sub_casing + else: + for sub_casing in all_casings(input_string[1:]): + yield first.lower() + sub_casing + yield first.upper() + sub_casing + + +def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]: + """Create str/str key/value headers, with duplicate keys case mutated.""" + headers: Dict[str, str] = {} + for key, values in multi_value_headers.items(): + if len(values) > 0: + casings = list(islice(all_casings(key), len(values))) + for value, cased_key in zip(values, casings): + headers[cased_key] = value + return headers + + +def encode_query_string_for_alb(params: QueryParams) -> bytes: + """Encode the query string parameters for the ALB event. The parameters must be + decoded and then encoded again to prevent double encoding. + + According to the docs: + + "If the query parameters are URL-encoded, the load balancer does not decode + "them. You must decode them in your Lambda function." + """ + params = { + unquote_plus(key): unquote_plus(value) + if isinstance(value, str) + else tuple(unquote_plus(element) for element in value) + for key, value in params.items() + } + query_string = urlencode(params, doseq=True).encode() + + return query_string + + +def transform_headers(event: LambdaEvent) -> List[Tuple[bytes, bytes]]: + headers: List[Tuple[bytes, bytes]] = [] + if "multiValueHeaders" in event: + for k, v in event["multiValueHeaders"].items(): + for inner_v in v: + headers.append((k.lower().encode(), inner_v.encode())) + else: + for k, v in event["headers"].items(): + headers.append((k.lower().encode(), v.encode())) + + return headers + + +class ALB: + @classmethod + def infer( + cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> Optional[LambdaHandler]: + if "requestContext" in event and "elb" in event["requestContext"]: + return cls(event, context, config) + + return None + + def __init__( + self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> None: + self.event = event + self.context = context + self.config = config + + @property + def body(self) -> bytes: + return maybe_encode_body( + self.event.get("body", b""), + is_base64=self.event.get("isBase64Encoded", False), + ) + + @property + def scope(self) -> HTTPScope: + + headers = transform_headers(self.event) + list_headers = [list(x) for x in headers] + # Unique headers. If there are duplicates, it will use the last defined. + uq_headers = {k.decode(): v.decode() for k, v in headers} + source_ip = uq_headers.get("x-forwarded-for", "") + path = unquote(self.event["path"]) if self.event["path"] else "/" + http_method = self.event["httpMethod"] + + params = self.event.get( + "multiValueQueryStringParameters", + self.event.get("queryStringParameters", {}), + ) + if not params: + query_string = b"" + else: + query_string = encode_query_string_for_alb(params) + + server = get_server_and_port(uq_headers) + client = (source_ip, 0) + + scope: HTTPScope = { + "type": "http", + "method": http_method, + "http_version": "1.1", + "headers": list_headers, + "path": path, + "raw_path": None, + "root_path": "", + "scheme": uq_headers.get("x-forwarded-proto", "https"), + "query_string": query_string, + "server": server, + "client": client, + "asgi": {"version": "3.0", "spec_version": "2.0"}, + "aws.event": self.event, + "aws.context": self.context, + } + + return scope + + def __call__(self, response: HTTPResponse) -> dict: + multi_value_headers: Dict[str, List[str]] = {} + for key, value in response["headers"]: + lower_key = key.decode().lower() + if lower_key not in multi_value_headers: + multi_value_headers[lower_key] = [] + multi_value_headers[lower_key].append(value.decode()) + + finalized_headers = case_mutated_headers(multi_value_headers) + finalized_body, is_base64_encoded = handle_base64_response_body( + response["body"], finalized_headers + ) + + out = { + "statusCode": response["status"], + "body": finalized_body, + "isBase64Encoded": is_base64_encoded, + } + + # You must use multiValueHeaders if you have enabled multi-value headers and + # headers otherwise. + multi_value_headers_enabled = "multiValueHeaders" in self.scope["aws.event"] + if multi_value_headers_enabled: + out["multiValueHeaders"] = multi_value_headers + else: + out["headers"] = finalized_headers + + return out diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py new file mode 100644 index 00000000..682642e2 --- /dev/null +++ b/mangum/handlers/api_gateway.py @@ -0,0 +1,239 @@ +from typing import Dict, List, Optional, Tuple +from urllib.parse import urlencode + +from mangum.handlers.utils import ( + get_server_and_port, + handle_base64_response_body, + handle_multi_value_headers, + maybe_encode_body, + strip_api_gateway_path, +) +from mangum.types import ( + HTTPResponse, + LambdaConfig, + Headers, + LambdaEvent, + LambdaContext, + LambdaHandler, + QueryParams, + HTTPScope, +) + + +def _encode_query_string_for_apigw(event: LambdaEvent) -> bytes: + params: QueryParams = event.get("multiValueQueryStringParameters", {}) + if not params: + params = event.get("queryStringParameters", {}) + if not params: + return b"" + + return urlencode(params, doseq=True).encode() + + +def _handle_multi_value_headers_for_request(event: LambdaEvent) -> Dict[str, str]: + headers = event.get("headers", {}) or {} + headers = {k.lower(): v for k, v in headers.items()} + if event.get("multiValueHeaders"): + headers.update( + { + k.lower(): ", ".join(v) if isinstance(v, list) else "" + for k, v in event.get("multiValueHeaders", {}).items() + } + ) + + return headers + + +def _combine_headers_v2( + input_headers: Headers, +) -> Tuple[Dict[str, str], List[str]]: + output_headers: Dict[str, str] = {} + cookies: List[str] = [] + for key, value in input_headers: + normalized_key: str = key.decode().lower() + normalized_value: str = value.decode() + if normalized_key == "set-cookie": + cookies.append(normalized_value) + else: + if normalized_key in output_headers: + normalized_value = ( + f"{output_headers[normalized_key]},{normalized_value}" + ) + output_headers[normalized_key] = normalized_value + + return output_headers, cookies + + +class APIGateway: + @classmethod + def infer( + cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> Optional[LambdaHandler]: + if "resource" in event and "requestContext" in event: + return cls(event, context, config) + + return None + + def __init__( + self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> None: + self.event = event + self.context = context + self.config = config + + @property + def body(self) -> bytes: + return maybe_encode_body( + self.event.get("body", b""), + is_base64=self.event.get("isBase64Encoded", False), + ) + + @property + def scope(self) -> HTTPScope: + headers = _handle_multi_value_headers_for_request(self.event) + return { + "type": "http", + "http_version": "1.1", + "method": self.event["httpMethod"], + "headers": [[k.encode(), v.encode()] for k, v in headers.items()], + "path": strip_api_gateway_path( + self.event["path"], + api_gateway_base_path=self.config["api_gateway_base_path"], + ), + "raw_path": None, + "root_path": "", + "scheme": headers.get("x-forwarded-proto", "https"), + "query_string": _encode_query_string_for_apigw(self.event), + "server": get_server_and_port(headers), + "client": ( + self.event["requestContext"].get("identity", {}).get("sourceIp"), + 0, + ), + "asgi": {"version": "3.0", "spec_version": "2.0"}, + "aws.event": self.event, + "aws.context": self.context, + } + + def __call__(self, response: HTTPResponse) -> dict: + finalized_headers, multi_value_headers = handle_multi_value_headers( + response["headers"] + ) + finalized_body, is_base64_encoded = handle_base64_response_body( + response["body"], finalized_headers + ) + + return { + "statusCode": response["status"], + "headers": finalized_headers, + "multiValueHeaders": multi_value_headers, + "body": finalized_body, + "isBase64Encoded": is_base64_encoded, + } + + +class HTTPGateway: + @classmethod + def infer( + cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> Optional[LambdaHandler]: + if "version" in event and "requestContext" in event: + return cls(event, context, config) + + return None + + def __init__( + self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> None: + self.event = event + self.context = context + self.config = config + + @property + def body(self) -> bytes: + return maybe_encode_body( + self.event.get("body", b""), + is_base64=self.event.get("isBase64Encoded", False), + ) + + @property + def scope(self) -> HTTPScope: + request_context = self.event["requestContext"] + event_version = self.event["version"] + + # API Gateway v2 + if event_version == "2.0": + headers = {k.lower(): v for k, v in self.event.get("headers", {}).items()} + source_ip = request_context["http"]["sourceIp"] + path = request_context["http"]["path"] + http_method = request_context["http"]["method"] + query_string = self.event.get("rawQueryString", "").encode() + + if self.event.get("cookies"): + headers["cookie"] = "; ".join(self.event.get("cookies", [])) + + # API Gateway v1 + else: + headers = _handle_multi_value_headers_for_request(self.event) + source_ip = request_context.get("identity", {}).get("sourceIp") + path = self.event["path"] + http_method = self.event["httpMethod"] + query_string = _encode_query_string_for_apigw(self.event) + + path = strip_api_gateway_path( + path, + api_gateway_base_path=self.config["api_gateway_base_path"], + ) + server = get_server_and_port(headers) + client = (source_ip, 0) + + return { + "type": "http", + "method": http_method, + "http_version": "1.1", + "headers": [[k.encode(), v.encode()] for k, v in headers.items()], + "path": path, + "raw_path": None, + "root_path": "", + "scheme": headers.get("x-forwarded-proto", "https"), + "query_string": query_string, + "server": server, + "client": client, + "asgi": {"version": "3.0", "spec_version": "2.0"}, + "aws.event": self.event, + "aws.context": self.context, + } + + def __call__(self, response: HTTPResponse) -> dict: + if self.scope["aws.event"]["version"] == "2.0": + finalized_headers, cookies = _combine_headers_v2(response["headers"]) + + if "content-type" not in finalized_headers and response["body"] is not None: + finalized_headers["content-type"] = "application/json" + + finalized_body, is_base64_encoded = handle_base64_response_body( + response["body"], finalized_headers + ) + response_out = { + "statusCode": response["status"], + "body": finalized_body, + "headers": finalized_headers or None, + "cookies": cookies or None, + "isBase64Encoded": is_base64_encoded, + } + return { + key: value for key, value in response_out.items() if value is not None + } + + finalized_headers, multi_value_headers = handle_multi_value_headers( + response["headers"] + ) + finalized_body, is_base64_encoded = handle_base64_response_body( + response["body"], finalized_headers + ) + return { + "statusCode": response["status"], + "headers": finalized_headers, + "multiValueHeaders": multi_value_headers, + "body": finalized_body, + "isBase64Encoded": is_base64_encoded, + } diff --git a/mangum/handlers/aws_alb.py b/mangum/handlers/aws_alb.py deleted file mode 100644 index c93f5bde..00000000 --- a/mangum/handlers/aws_alb.py +++ /dev/null @@ -1,161 +0,0 @@ -import base64 -from urllib.parse import urlencode, unquote, unquote_plus -from typing import Any, Dict, Generator, List, Tuple -from itertools import islice - -from mangum.types import Response, Request, QueryParams -from mangum.handlers.abstract_handler import AbstractHandler - - -def all_casings(input_string: str) -> Generator[str, None, None]: - """ - Permute all casings of a given string. - A pretty algoritm, via @Amber - http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python - """ - if not input_string: - yield "" - else: - first = input_string[:1] - if first.lower() == first.upper(): - for sub_casing in all_casings(input_string[1:]): - yield first + sub_casing - else: - for sub_casing in all_casings(input_string[1:]): - yield first.lower() + sub_casing - yield first.upper() + sub_casing - - -def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]: - """Create str/str key/value headers, with duplicate keys case mutated.""" - headers: Dict[str, str] = {} - for key, values in multi_value_headers.items(): - if len(values) > 0: - casings = list(islice(all_casings(key), len(values))) - for value, cased_key in zip(values, casings): - headers[cased_key] = value - return headers - - -class AwsAlb(AbstractHandler): - def _encode_query_string(self) -> bytes: - """ - Encodes the queryStringParameters. - The parameters must be decoded, and then encoded again to prevent double - encoding. - - https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501 - "If the query parameters are URL-encoded, the load balancer does not decode - them. You must decode them in your Lambda function." - - Issue: https://github.com/jordaneremieff/mangum/issues/178 - """ - - params: QueryParams = self.trigger_event.get( - "multiValueQueryStringParameters", {} - ) - if not params: - params = self.trigger_event.get("queryStringParameters", {}) - if not params: - return b"" - params = { - unquote_plus(key): unquote_plus(value) - if isinstance(value, str) - else tuple(unquote_plus(element) for element in value) - for key, value in params.items() - } - return urlencode(params, doseq=True).encode() - - def transform_headers(self) -> List[Tuple[bytes, bytes]]: - """Convert headers to a list of two-tuples per ASGI spec. - - Only one of `multiValueHeaders` or `headers` should be defined in the - trigger event. However, we act as though they both might exist and pull - headers out of both. - """ - headers: List[Tuple[bytes, bytes]] = [] - if "multiValueHeaders" in self.trigger_event: - for k, v in self.trigger_event["multiValueHeaders"].items(): - for inner_v in v: - headers.append((k.lower().encode(), inner_v.encode())) - else: - for k, v in self.trigger_event["headers"].items(): - headers.append((k.lower().encode(), v.encode())) - return headers - - @property - def request(self) -> Request: - event = self.trigger_event - - headers = self.transform_headers() - list_headers = [list(x) for x in headers] - # Unique headers. If there are duplicates, it will use the last defined. - uq_headers = {k.decode(): v.decode() for k, v in headers} - - source_ip = uq_headers.get("x-forwarded-for", "") - path = unquote(event["path"]) if event["path"] else "/" - http_method = event["httpMethod"] - query_string = self._encode_query_string() - - server_name = uq_headers.get("host", "mangum") - if ":" not in server_name: - server_port = uq_headers.get("x-forwarded-port", 80) - else: - server_name, server_port = server_name.split(":") # pragma: no cover - server = (server_name, int(server_port)) - client = (source_ip, 0) - - return Request( - method=http_method, - headers=list_headers, - path=path, - scheme=uq_headers.get("x-forwarded-proto", "https"), - query_string=query_string, - server=server, - client=client, - trigger_event=self.trigger_event, - trigger_context=self.trigger_context, - ) - - @property - def body(self) -> bytes: - body = self.trigger_event.get("body", b"") or b"" - - if self.trigger_event.get("isBase64Encoded", False): - return base64.b64decode(body) - if not isinstance(body, bytes): - body = body.encode() - - return body - - def transform_response(self, response: Response) -> Dict[str, Any]: - - multi_value_headers: Dict[str, List[str]] = {} - for key, value in response.headers: - lower_key = key.decode().lower() - if lower_key not in multi_value_headers: - multi_value_headers[lower_key] = [] - multi_value_headers[lower_key].append(value.decode()) - - headers = case_mutated_headers(multi_value_headers) - - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - - out = { - "statusCode": response.status, - "body": body, - "isBase64Encoded": is_base64_encoded, - } - - # "You must use multiValueHeaders if you have enabled multi-value headers - # and headers otherwise" - # https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html - multi_value_headers_enabled = "multiValueHeaders" in self.trigger_event - if multi_value_headers_enabled: - out["multiValueHeaders"] = multi_value_headers - else: - out["headers"] = headers - - return out diff --git a/mangum/handlers/aws_api_gateway.py b/mangum/handlers/aws_api_gateway.py deleted file mode 100644 index c97b684b..00000000 --- a/mangum/handlers/aws_api_gateway.py +++ /dev/null @@ -1,116 +0,0 @@ -import base64 -from urllib.parse import urlencode, unquote -from typing import Dict, Any - -from mangum.handlers.abstract_handler import AbstractHandler -from mangum.types import Response, Request, LambdaEvent, LambdaContext, QueryParams - - -class AwsApiGateway(AbstractHandler): - def __init__( - self, - trigger_event: LambdaEvent, - trigger_context: LambdaContext, - api_gateway_base_path: str, - ): - super().__init__(trigger_event, trigger_context) - self.api_gateway_base_path = api_gateway_base_path - - @property - def request(self) -> Request: - event = self.trigger_event - - # See this for more info on headers: - # https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#apigateway-multivalue-headers-and-parameters - headers = {} - # Read headers - if event.get("headers"): - headers.update({k.lower(): v for k, v in event.get("headers", {}).items()}) - # Read multiValueHeaders - # This overrides headers that have the same name - # That means that multiValue versions of headers take precedence - # over their plain versions - if event.get("multiValueHeaders"): - headers.update( - { - k.lower(): ", ".join(v) if isinstance(v, list) else "" - for k, v in event.get("multiValueHeaders", {}).items() - } - ) - - request_context = event["requestContext"] - - source_ip = request_context.get("identity", {}).get("sourceIp") - path = unquote(self._strip_base_path(event["path"])) if event["path"] else "/" - http_method = event["httpMethod"] - query_string = self._encode_query_string() - - server_name = headers.get("host", "mangum") - if ":" not in server_name: - server_port = headers.get("x-forwarded-port", 80) - else: - server_name, server_port = server_name.split(":") # pragma: no cover - server = (server_name, int(server_port)) - client = (source_ip, 0) - - return Request( - method=http_method, - headers=[[k.encode(), v.encode()] for k, v in headers.items()], - path=path, - scheme=headers.get("x-forwarded-proto", "https"), - query_string=query_string, - server=server, - client=client, - trigger_event=self.trigger_event, - trigger_context=self.trigger_context, - ) - - def _encode_query_string(self) -> bytes: - """ - Encodes the queryStringParameters. - """ - - params: QueryParams = self.trigger_event.get( - "multiValueQueryStringParameters", {} - ) - if not params: - params = self.trigger_event.get("queryStringParameters", {}) - if not params: - return b"" - return urlencode(params, doseq=True).encode() - - def _strip_base_path(self, path: str) -> str: - if self.api_gateway_base_path and self.api_gateway_base_path != "/": - if not self.api_gateway_base_path.startswith("/"): - self.api_gateway_base_path = f"/{self.api_gateway_base_path}" - if path.startswith(self.api_gateway_base_path): - path = path[len(self.api_gateway_base_path) :] - return path - - @property - def body(self) -> bytes: - body = self.trigger_event.get("body", b"") or b"" - - if self.trigger_event.get("isBase64Encoded", False): - return base64.b64decode(body) - if not isinstance(body, bytes): - body = body.encode() - - return body - - def transform_response(self, response: Response) -> Dict[str, Any]: - headers, multi_value_headers = self._handle_multi_value_headers( - response.headers - ) - - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - - return { - "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, - "body": body, - "isBase64Encoded": is_base64_encoded, - } diff --git a/mangum/handlers/aws_cf_lambda_at_edge.py b/mangum/handlers/aws_cf_lambda_at_edge.py deleted file mode 100644 index def851ef..00000000 --- a/mangum/handlers/aws_cf_lambda_at_edge.py +++ /dev/null @@ -1,73 +0,0 @@ -import base64 -from typing import Dict, Any, List - -from mangum.handlers.abstract_handler import AbstractHandler -from mangum.types import Response, Request - - -class AwsCfLambdaAtEdge(AbstractHandler): - @property - def request(self) -> Request: - event = self.trigger_event - - cf_request = event["Records"][0]["cf"]["request"] - - scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}]) - scheme = scheme_header[0].get("value", "https") - - host_header = cf_request["headers"].get("host", [{}]) - server_name = host_header[0].get("value", "mangum") - if ":" not in server_name: - forwarded_port_header = cf_request["headers"].get("x-forwarded-port", [{}]) - server_port = forwarded_port_header[0].get("value", 80) - else: - server_name, server_port = server_name.split(":") # pragma: no cover - server = (server_name, int(server_port)) - - source_ip = cf_request["clientIp"] - client = (source_ip, 0) - - return Request( - method=cf_request["method"], - headers=[ - [k.encode(), v[0]["value"].encode()] - for k, v in cf_request["headers"].items() - ], - path=cf_request["uri"], - scheme=scheme, - query_string=cf_request["querystring"].encode(), - server=server, - client=client, - trigger_event=self.trigger_event, - trigger_context=self.trigger_context, - ) - - @property - def body(self) -> bytes: - request = self.trigger_event["Records"][0]["cf"]["request"] - body = request.get("body", {}).get("data", None) or b"" - - if request.get("body", {}).get("encoding", "") == "base64": - return base64.b64decode(body) - if not isinstance(body, bytes): - body = body.encode() - - return body - - def transform_response(self, response: Response) -> Dict[str, Any]: - headers_dict, _ = self._handle_multi_value_headers(response.headers) - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers_dict - ) - - # Expand headers to weird list of Dict[str, List[Dict[str, str]]] - headers_expanded: Dict[str, List[Dict[str, str]]] = { - key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}] - for key, val in response.headers - } - return { - "status": response.status, - "headers": headers_expanded, - "body": body, - "isBase64Encoded": is_base64_encoded, - } diff --git a/mangum/handlers/aws_http_gateway.py b/mangum/handlers/aws_http_gateway.py deleted file mode 100644 index feed1efb..00000000 --- a/mangum/handlers/aws_http_gateway.py +++ /dev/null @@ -1,170 +0,0 @@ -import base64 -import urllib.parse -from typing import Dict, Any, List, Tuple - -from mangum.handlers.aws_api_gateway import AwsApiGateway -from mangum.types import Response, Request - - -class AwsHttpGateway(AwsApiGateway): - """ - Handles AWS HTTP Gateway events (v1.0 and v2.0), transforming them into ASGI Scope - and handling responses - - See: https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html # noqa: E501 - """ - - TYPE = "AWS_HTTP_GATEWAY" - - @property - def event_version(self) -> str: - return self.trigger_event.get("version", "") - - @property - def request(self) -> Request: - event = self.trigger_event - - headers = {} - if event.get("headers"): - headers = {k.lower(): v for k, v in event.get("headers", {}).items()} - - request_context = event["requestContext"] - - # API Gateway v2 - if self.event_version == "2.0": - source_ip = request_context["http"]["sourceIp"] - path = request_context["http"]["path"] - http_method = request_context["http"]["method"] - query_string = event.get("rawQueryString", "").encode() - - if event.get("cookies"): - headers["cookie"] = "; ".join(event.get("cookies", [])) - - # API Gateway v1 - elif self.event_version == "1.0": - # v1.0 of the HTTP Gateway supports multiValueHeaders - if event.get("multiValueHeaders"): - headers.update( - { - k.lower(): ", ".join(v) if isinstance(v, list) else "" - for k, v in event.get("multiValueHeaders", {}).items() - } - ) - - source_ip = request_context.get("identity", {}).get("sourceIp") - path = event["path"] - http_method = event["httpMethod"] - query_string = self._encode_query_string() - else: - raise RuntimeError( - "Unsupported version of HTTP Gateway Spec, only v1.0 and v2.0 are " - "supported." - ) - - server_name = headers.get("host", "mangum") - if ":" not in server_name: - server_port = headers.get("x-forwarded-port", 80) - else: - server_name, server_port = server_name.split(":") # pragma: no cover - server = (server_name, int(server_port)) - client = (source_ip, 0) - - if not path: - path = "/" - else: - path = self._strip_base_path(path) - - return Request( - method=http_method, - headers=[[k.encode(), v.encode()] for k, v in headers.items()], - path=urllib.parse.unquote(path), - scheme=headers.get("x-forwarded-proto", "https"), - query_string=query_string, - server=server, - client=client, - trigger_event=self.trigger_event, - trigger_context=self.trigger_context, - ) - - @property - def body(self) -> bytes: - body = self.trigger_event.get("body", b"") or b"" - - if self.trigger_event.get("isBase64Encoded", False): - return base64.b64decode(body) - if not isinstance(body, bytes): - body = body.encode() - - return body - - def transform_response(self, response: Response) -> Dict[str, Any]: - """ - This handles some unnecessary magic from AWS - - > API Gateway can infer the response format for you - Boooooo - - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.response - """ - if self.event_version == "1.0": - return self.transform_response_v1(response) - elif self.event_version == "2.0": - return self.transform_response_v2(response) - raise RuntimeError( # pragma: no cover - "Misconfigured event unable to return value, unsupported version." - ) - - def transform_response_v1(self, response: Response) -> Dict[str, Any]: - headers, multi_value_headers = self._handle_multi_value_headers( - response.headers - ) - - body, is_base64_encoded = self._handle_base64_response_body( - response.body, headers - ) - return { - "statusCode": response.status, - "headers": headers, - "multiValueHeaders": multi_value_headers, - "body": body, - "isBase64Encoded": is_base64_encoded, - } - - def _combine_headers_v2( - self, input_headers: List[List[bytes]] - ) -> Tuple[Dict[str, str], List[str]]: - output_headers: Dict[str, str] = {} - cookies: List[str] = [] - for key, value in input_headers: - normalized_key: str = key.decode().lower() - normalized_value: str = value.decode() - if normalized_key == "set-cookie": - cookies.append(normalized_value) - else: - if normalized_key in output_headers: - normalized_value = ( - f"{output_headers[normalized_key]},{normalized_value}" - ) - output_headers[normalized_key] = normalized_value - return output_headers, cookies - - def transform_response_v2(self, response_in: Response) -> Dict[str, Any]: - # The API Gateway will infer stuff for us, but we'll just do that inference - # here and keep the output consistent - - headers, cookies = self._combine_headers_v2(response_in.headers) - - if "content-type" not in headers and response_in.body is not None: - headers["content-type"] = "application/json" - - body, is_base64_encoded = self._handle_base64_response_body( - response_in.body, headers - ) - response_out = { - "statusCode": response_in.status, - "body": body, - "headers": headers or None, - "cookies": cookies or None, - "isBase64Encoded": is_base64_encoded, - } - return {key: value for key, value in response_out.items() if value is not None} diff --git a/mangum/handlers/lambda_at_edge.py b/mangum/handlers/lambda_at_edge.py new file mode 100644 index 00000000..7a91147b --- /dev/null +++ b/mangum/handlers/lambda_at_edge.py @@ -0,0 +1,102 @@ +from typing import Dict, List, Optional + +from mangum.handlers.utils import ( + handle_base64_response_body, + handle_multi_value_headers, + maybe_encode_body, +) +from mangum.types import ( + HTTPScope, + HTTPResponse, + LambdaConfig, + LambdaEvent, + LambdaContext, + LambdaHandler, +) + + +class LambdaAtEdge: + @classmethod + def infer( + cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> Optional[LambdaHandler]: + if ( + "Records" in event + and len(event["Records"]) > 0 + and "cf" in event["Records"][0] + ): + return cls(event, context, config) + + # FIXME: Since this is the last in the chain it doesn't get coverage by default, + # just ignoring it for now. + return None # pragma: nocover + + def __init__( + self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> None: + self.event = event + self.context = context + self.config = config + + @property + def body(self) -> bytes: + cf_request_body = self.event["Records"][0]["cf"]["request"].get("body", {}) + return maybe_encode_body( + cf_request_body.get("data"), + is_base64=cf_request_body.get("encoding", "") == "base64", + ) + + @property + def scope(self) -> HTTPScope: + cf_request = self.event["Records"][0]["cf"]["request"] + scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}]) + scheme = scheme_header[0].get("value", "https") + host_header = cf_request["headers"].get("host", [{}]) + server_name = host_header[0].get("value", "mangum") + if ":" not in server_name: + forwarded_port_header = cf_request["headers"].get("x-forwarded-port", [{}]) + server_port = forwarded_port_header[0].get("value", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + + server = (server_name, int(server_port)) + source_ip = cf_request["clientIp"] + client = (source_ip, 0) + http_method = cf_request["method"] + + return { + "type": "http", + "method": http_method, + "http_version": "1.1", + "headers": [ + [k.encode(), v[0]["value"].encode()] + for k, v in cf_request["headers"].items() + ], + "path": cf_request["uri"], + "raw_path": None, + "root_path": "", + "scheme": scheme, + "query_string": cf_request["querystring"].encode(), + "server": server, + "client": client, + "asgi": {"version": "3.0", "spec_version": "2.0"}, + "aws.event": self.event, + "aws.context": self.context, + } + + def __call__(self, response: HTTPResponse) -> dict: + multi_value_headers, _ = handle_multi_value_headers(response["headers"]) + response_body, is_base64_encoded = handle_base64_response_body( + response["body"], multi_value_headers + ) + finalized_headers: Dict[str, List[Dict[str, str]]] = { + key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}] + for key, val in response["headers"] + } + + return { + "status": response["status"], + "headers": finalized_headers, + "body": response_body, + "isBase64Encoded": is_base64_encoded, + } diff --git a/mangum/handlers/utils.py b/mangum/handlers/utils.py new file mode 100644 index 00000000..1ec1dfb6 --- /dev/null +++ b/mangum/handlers/utils.py @@ -0,0 +1,90 @@ +import base64 +from typing import Dict, List, Tuple, Union +from urllib.parse import unquote + +from mangum.types import Headers + + +DEFAULT_TEXT_MIME_TYPES = [ + "text/", + "application/json", + "application/javascript", + "application/xml", + "application/vnd.api+json", +] + + +def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes: + body = body or b"" + if is_base64: + body = base64.b64decode(body) + elif not isinstance(body, bytes): + body = body.encode() + + return body + + +def get_server_and_port(headers: dict) -> Tuple[str, int]: + server_name = headers.get("host", "mangum") + if ":" not in server_name: + server_port = headers.get("x-forwarded-port", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + server = (server_name, int(server_port)) + + return server + + +def strip_api_gateway_path(path: str, *, api_gateway_base_path: str) -> str: + if not path: + return "/" + + if api_gateway_base_path and api_gateway_base_path != "/": + if not api_gateway_base_path.startswith("/"): + api_gateway_base_path = f"/{api_gateway_base_path}" + if path.startswith(api_gateway_base_path): + path = path[len(api_gateway_base_path) :] + + return unquote(path) + + +def handle_multi_value_headers( + response_headers: Headers, +) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + headers: Dict[str, str] = {} + multi_value_headers: Dict[str, List[str]] = {} + for key, value in response_headers: + lower_key = key.decode().lower() + if lower_key in multi_value_headers: + multi_value_headers[lower_key].append(value.decode()) + elif lower_key in headers: + # Move existing to multi_value_headers and append current + multi_value_headers[lower_key] = [ + headers[lower_key], + value.decode(), + ] + del headers[lower_key] + else: + headers[lower_key] = value.decode() + return headers, multi_value_headers + + +def handle_base64_response_body( + body: bytes, headers: Dict[str, str] +) -> Tuple[str, bool]: + is_base64_encoded = False + output_body = "" + if body != b"": + for text_mime_type in DEFAULT_TEXT_MIME_TYPES: + if text_mime_type in headers.get("content-type", ""): + try: + output_body = body.decode() + except UnicodeDecodeError: + output_body = base64.b64encode(body).decode() + is_base64_encoded = True + break + else: + output_body = base64.b64encode(body).decode() + is_base64_encoded = True + + return output_body, is_base64_encoded diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index 33db249e..fd452f5a 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -1,19 +1,23 @@ -import enum import asyncio -from typing import Optional +import enum import logging from io import BytesIO -from dataclasses import dataclass -from .. import Response, Request -from ..types import ASGIApp, Message -from ..exceptions import UnexpectedMessage + +from mangum.types import ( + ASGIApp, + ASGIReceiveEvent, + ASGISendEvent, + HTTPDisconnectEvent, + HTTPScope, + HTTPResponse, +) +from mangum.exceptions import UnexpectedMessage class HTTPCycleState(enum.Enum): """ The state of the ASGI `http` connection. - * **REQUEST** - Initial state. The ASGI application instance will be run with the connection scope containing the `http` type. * **RESPONSE** - The `http.response.start` event has been sent by the application. @@ -30,55 +34,38 @@ class HTTPCycleState(enum.Enum): COMPLETE = enum.auto() -@dataclass class HTTPCycle: - """ - Manages the application cycle for an ASGI `http` connection. - - * **request** - A request object containing the event and context for the connection - scope used to run the ASGI application instance. - * **state** - An enumerated `HTTPCycleState` type that indicates the state of the - ASGI connection. - * **app_queue** - An asyncio queue (FIFO) containing messages to be received by the - application. - * **response** - A dictionary containing the response data to return in AWS Lambda. - """ - - request: Request - state: HTTPCycleState = HTTPCycleState.REQUEST - response: Optional[Response] = None - - def __post_init__(self) -> None: - self.logger: logging.Logger = logging.getLogger("mangum.http") - self.loop = asyncio.get_event_loop() - self.app_queue: asyncio.Queue[Message] = asyncio.Queue() - self.body: BytesIO = BytesIO() - - def __call__(self, app: ASGIApp, initial_body: bytes) -> Response: - self.logger.debug("HTTP cycle starting.") + def __init__(self, scope: HTTPScope, body: bytes) -> None: + self.scope = scope + self.buffer = BytesIO() + self.state = HTTPCycleState.REQUEST + self.logger = logging.getLogger("mangum.http") + self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() self.app_queue.put_nowait( - {"type": "http.request", "body": initial_body, "more_body": False} + { + "type": "http.request", + "body": body, + "more_body": False, + } ) + + def __call__(self, app: ASGIApp) -> HTTPResponse: asgi_instance = self.run(app) - asgi_task = self.loop.create_task(asgi_instance) - self.loop.run_until_complete(asgi_task) + loop = asyncio.get_event_loop() + asgi_task = loop.create_task(asgi_instance) + loop.run_until_complete(asgi_task) - if self.response is None: # pragma: nocover - self.response = Response( - status=500, - body=b"Internal Server Error", - headers=[[b"content-type", b"text/plain; charset=utf-8"]], - ) - return self.response + return { + "status": self.status, + "headers": self.headers, + "body": self.body, + } async def run(self, app: ASGIApp) -> None: - """ - Calls the application with the `http` connection scope. - """ try: - await app(self.request.scope, self.receive, self.send) - except BaseException as exc: - self.logger.error("Exception in 'http' protocol.", exc_info=exc) + await app(self.scope, self.receive, self.send) + except BaseException: + self.logger.exception("An error occurred running the application.") if self.state is HTTPCycleState.REQUEST: await self.send( { @@ -88,63 +75,48 @@ async def run(self, app: ASGIApp) -> None: } ) await self.send( - {"type": "http.response.body", "body": b"Internal Server Error"} + { + "type": "http.response.body", + "body": b"Internal Server Error", + "more_body": False, + } ) elif self.state is not HTTPCycleState.COMPLETE: - self.response = Response( - status=500, - body=b"Internal Server Error", - headers=[[b"content-type", b"text/plain; charset=utf-8"]], - ) + self.status = 500 + self.body = b"Internal Server Error" + self.headers = [[b"content-type", b"text/plain; charset=utf-8"]] - async def receive(self) -> Message: - """ - Awaited by the application to receive ASGI `http` events. - """ + async def receive(self) -> ASGIReceiveEvent: return await self.app_queue.get() # pragma: no cover - async def send(self, message: Message) -> None: - """ - Awaited by the application to send ASGI `http` events. - """ - message_type = message["type"] - + async def send(self, message: ASGISendEvent) -> None: if ( self.state is HTTPCycleState.REQUEST - and message_type == "http.response.start" + and message["type"] == "http.response.start" ): - if self.response is None: - self.response = Response( - status=message["status"], - headers=message.get("headers", []), - body=b"", - ) + self.status = message["status"] + self.headers = message.get("headers", []) self.state = HTTPCycleState.RESPONSE elif ( self.state is HTTPCycleState.RESPONSE - and message_type == "http.response.body" + and message["type"] == "http.response.body" ): + body = message.get("body", b"") more_body = message.get("more_body", False) - - # The body must be completely read before returning the response. - self.body.write(body) - - if not more_body and self.response is not None: - body = self.body.getvalue() - self.body.close() - self.response.body = body + self.buffer.write(body) + if not more_body: + self.body = self.buffer.getvalue() + self.buffer.close() self.state = HTTPCycleState.COMPLETE - await self.app_queue.put({"type": "http.disconnect"}) + await self.app_queue.put(HTTPDisconnectEvent(type="http.disconnect")) self.logger.info( "%s %s %s", - self.request.method, - self.request.path, - self.response.status, + self.scope["method"], + self.scope["path"], + self.status, ) else: - raise UnexpectedMessage( - f"{self.state}: Unexpected '{message_type}' event received." - ) + raise UnexpectedMessage(f"Unexpected {message['type']}") diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index 4ceb91dd..f1f2a6e2 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -1,12 +1,19 @@ import asyncio +import enum import logging from types import TracebackType from typing import Optional, Type -import enum -from dataclasses import dataclass -from ..types import ASGIApp, Message -from ..exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage + +from mangum.types import ( + ASGIApp, + LifespanMode, + ASGIReceiveEvent, + ASGISendEvent, + LifespanShutdownEvent, + LifespanStartupEvent, +) +from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage class LifespanCycleState(enum.Enum): @@ -33,7 +40,6 @@ class LifespanCycleState(enum.Enum): UNSUPPORTED = enum.auto() -@dataclass class LifespanCycle: """ Manages the application cycle for an ASGI `lifespan` connection. @@ -54,22 +60,19 @@ class LifespanCycle: shutdown flow. """ - app: ASGIApp - lifespan: str - state: LifespanCycleState = LifespanCycleState.CONNECTING - exception: Optional[BaseException] = None - - def __post_init__(self) -> None: - self.logger = logging.getLogger("mangum.lifespan") + def __init__(self, app: ASGIApp, lifespan: LifespanMode) -> None: + self.app = app + self.lifespan = lifespan + self.state: LifespanCycleState = LifespanCycleState.CONNECTING + self.exception: Optional[BaseException] = None self.loop = asyncio.get_event_loop() - self.app_queue: asyncio.Queue[Message] = asyncio.Queue() + self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() self.startup_event: asyncio.Event = asyncio.Event() self.shutdown_event: asyncio.Event = asyncio.Event() + self.logger = logging.getLogger("mangum.lifespan") def __enter__(self) -> None: - """ - Runs the event loop for application startup. - """ + """Runs the event loop for application startup.""" self.loop.create_task(self.run()) self.loop.run_until_complete(self.startup()) @@ -79,17 +82,17 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - """ - Runs the event loop for application shutdown. - """ + """Runs the event loop for application shutdown.""" self.loop.run_until_complete(self.shutdown()) async def run(self) -> None: - """ - Calls the application with the `lifespan` connection scope. - """ + """Calls the application with the `lifespan` connection scope.""" try: - await self.app({"type": "lifespan"}, self.receive, self.send) + await self.app( + {"type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}}, + self.receive, + self.send, + ) except LifespanUnsupported: self.logger.info("ASGI 'lifespan' protocol appears unsupported.") except (LifespanFailure, UnexpectedMessage) as exc: @@ -100,10 +103,8 @@ async def run(self) -> None: self.startup_event.set() self.shutdown_event.set() - async def receive(self) -> Message: - """ - Awaited by the application to receive ASGI `lifespan` events. - """ + async def receive(self) -> ASGIReceiveEvent: + """Awaited by the application to receive ASGI `lifespan` events.""" if self.state is LifespanCycleState.CONNECTING: # Connection established. The next event returned by the queue will be @@ -120,17 +121,14 @@ async def receive(self) -> Message: return await self.app_queue.get() - async def send(self, message: Message) -> None: - """ - Awaited by the application to send ASGI `lifespan` events. - """ + async def send(self, message: ASGISendEvent) -> None: + """Awaited by the application to send ASGI `lifespan` events.""" message_type = message["type"] self.logger.info( "%s: '%s' event received from application.", self.state, message_type ) if self.state is LifespanCycleState.CONNECTING: - if self.lifespan == "on": raise LifespanFailure( "Lifespan connection failed during startup and lifespan is 'on'." @@ -156,8 +154,8 @@ async def send(self, message: Message) -> None: elif message_type == "lifespan.startup.failed": self.state = LifespanCycleState.FAILED self.startup_event.set() - message = message.get("message", "") - raise LifespanFailure(f"Lifespan startup failure. {message}") + message_value = message.get("message", "") + raise LifespanFailure(f"Lifespan startup failure. {message_value}") elif self.state is LifespanCycleState.SHUTDOWN: if message_type == "lifespan.shutdown.complete": @@ -165,15 +163,13 @@ async def send(self, message: Message) -> None: elif message_type == "lifespan.shutdown.failed": self.state = LifespanCycleState.FAILED self.shutdown_event.set() - message = message.get("message", "") - raise LifespanFailure(f"Lifespan shutdown failure. {message}") + message_value = message.get("message", "") + raise LifespanFailure(f"Lifespan shutdown failure. {message_value}") async def startup(self) -> None: - """ - Pushes the `lifespan` startup event to application queue and handles errors. - """ + """Pushes the `lifespan` startup event to the queue and handles errors.""" self.logger.info("Waiting for application startup.") - await self.app_queue.put({"type": "lifespan.startup"}) + await self.app_queue.put(LifespanStartupEvent(type="lifespan.startup")) await self.startup_event.wait() if self.state is LifespanCycleState.FAILED: raise LifespanFailure(self.exception) @@ -184,11 +180,9 @@ async def startup(self) -> None: self.logger.info("Application startup failed.") async def shutdown(self) -> None: - """ - Pushes the `lifespan` shutdown event to application queue and handles errors. - """ + """Pushes the `lifespan` shutdown event to the queue and handles errors.""" self.logger.info("Waiting for application shutdown.") - await self.app_queue.put({"type": "lifespan.shutdown"}) + await self.app_queue.put(LifespanShutdownEvent(type="lifespan.shutdown")) await self.shutdown_event.wait() if self.state is LifespanCycleState.FAILED: raise LifespanFailure(self.exception) diff --git a/mangum/types.py b/mangum/types.py index 1f8b8b14..bcf4a0d0 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from __future__ import annotations + from typing import ( List, Tuple, @@ -11,21 +12,11 @@ Awaitable, Callable, ) -from typing_extensions import Protocol, TypeAlias - -QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]] -Message: TypeAlias = MutableMapping[str, Any] -Scope: TypeAlias = MutableMapping[str, Any] -Receive: TypeAlias = Callable[[], Awaitable[Message]] -Send: TypeAlias = Callable[[Message], Awaitable[None]] - - -class ASGIApp(Protocol): - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - ... # pragma: no cover +from typing_extensions import Literal, Protocol, TypedDict, TypeAlias LambdaEvent = Dict[str, Any] +QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]] class LambdaCognitoIdentity(Protocol): @@ -103,65 +94,150 @@ def get_remaining_time_in_millis(self) -> int: ... # pragma: no cover -@dataclass -class BaseRequest: - """ - A holder for an ASGI scope. Contains additional meta from the event that triggered - the Lambda function. - """ +Headers: TypeAlias = List[List[bytes]] - headers: List[List[bytes]] - path: str - scheme: str - query_string: bytes - server: Tuple[str, int] - client: Tuple[str, int] - # Invocation event - trigger_event: Dict[str, Any] - trigger_context: Union["LambdaContext", Dict[str, Any]] +class HTTPRequestEvent(TypedDict): + type: Literal["http.request"] + body: bytes + more_body: bool - raw_path: Optional[str] = None - root_path: str = "" - @property - def scope(self) -> Scope: - return { - "http_version": "1.1", - "headers": self.headers, - "path": self.path, - "raw_path": self.raw_path, - "root_path": self.root_path, - "scheme": self.scheme, - "query_string": self.query_string, - "server": self.server, - "client": self.client, - "asgi": {"version": "3.0"}, - "aws.event": self.trigger_event, - "aws.context": self.trigger_context, - } - - -@dataclass -class Request(BaseRequest): - """ - A holder for an ASGI scope. Specific for usage with HTTP connections. +class HTTPDisconnectEvent(TypedDict): + type: Literal["http.disconnect"] - https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope - """ - type: str = "http" - method: str = "GET" +class HTTPResponseStartEvent(TypedDict): + type: Literal["http.response.start"] + status: int + headers: Headers - @property - def scope(self) -> Scope: - scope = super().scope - scope.update({"type": self.type, "method": self.method}) - return scope +class HTTPResponseBodyEvent(TypedDict): + type: Literal["http.response.body"] + body: bytes + more_body: bool + + +class LifespanStartupEvent(TypedDict): + type: Literal["lifespan.startup"] + + +class LifespanStartupCompleteEvent(TypedDict): + type: Literal["lifespan.startup.complete"] + + +class LifespanStartupFailedEvent(TypedDict): + type: Literal["lifespan.startup.failed"] + message: str + + +class LifespanShutdownEvent(TypedDict): + type: Literal["lifespan.shutdown"] + + +class LifespanShutdownCompleteEvent(TypedDict): + type: Literal["lifespan.shutdown.complete"] + + +class LifespanShutdownFailedEvent(TypedDict): + type: Literal["lifespan.shutdown.failed"] + message: str + + +ASGIReceiveEvent: TypeAlias = Union[ + HTTPRequestEvent, + HTTPDisconnectEvent, + LifespanStartupEvent, + LifespanShutdownEvent, +] + +ASGISendEvent: TypeAlias = Union[ + HTTPResponseStartEvent, + HTTPResponseBodyEvent, + HTTPDisconnectEvent, + LifespanStartupCompleteEvent, + LifespanStartupFailedEvent, + LifespanShutdownCompleteEvent, + LifespanShutdownFailedEvent, +] -@dataclass -class Response: + +ASGIReceive: TypeAlias = Callable[[], Awaitable[ASGIReceiveEvent]] +ASGISend: TypeAlias = Callable[[ASGISendEvent], Awaitable[None]] + + +class ASGISpec(TypedDict): + spec_version: Literal["2.0"] + version: Literal["3.0"] + + +HTTPScope = TypedDict( + "HTTPScope", + { + "type": Literal["http"], + "asgi": ASGISpec, + "http_version": Literal["1.1"], + "scheme": str, + "method": str, + "path": str, + "raw_path": None, + "root_path": Literal[""], + "query_string": bytes, + "headers": Headers, + "client": Tuple[str, int], + "server": Tuple[str, int], + "aws.event": LambdaEvent, + "aws.context": LambdaContext, + }, +) + + +class LifespanScope(TypedDict): + type: Literal["lifespan"] + asgi: ASGISpec + + +LifespanMode: TypeAlias = Literal["auto", "on", "off"] +Scope: TypeAlias = Union[HTTPScope, LifespanScope] + + +class ASGIApp(Protocol): + async def __call__( + self, scope: Scope, receive: ASGIReceive, send: ASGISend + ) -> None: + ... # pragma: no cover + + +class HTTPResponse(TypedDict): status: int - headers: List[List[bytes]] # ex: [[b'content-type', b'text/plain; charset=utf-8']] + headers: Headers body: bytes + + +class LambdaConfig(TypedDict): + api_gateway_base_path: str + + +class LambdaHandler(Protocol): + @classmethod + def infer( + cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> Optional[LambdaHandler]: + ... # pragma: no cover + + def __init__( + self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> None: + ... # pragma: no cover + + @property + def body(self) -> bytes: + ... # pragma: no cover + + @property + def scope(self) -> HTTPScope: + ... # pragma: no cover + + def __call__(self, response: HTTPResponse) -> dict: + ... # pragma: no cover diff --git a/requirements.txt b/requirements.txt index d7ea423c..64398c30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,11 +5,9 @@ pytest-cov black flake8 starlette -quart; python_version >= '3.7' +quart mypy brotli brotli-asgi -types-dataclasses; python_version < '3.7' # For mypy -# Docs mkdocs mkdocs-material diff --git a/scripts/lint b/scripts/lint index de557e7d..5dd72263 100755 --- a/scripts/lint +++ b/scripts/lint @@ -8,5 +8,5 @@ fi set -x ${PREFIX}black mangum tests --check -${PREFIX}mypy mangum --disallow-untyped-defs --ignore-missing-imports +${PREFIX}mypy mangum ${PREFIX}flake8 mangum tests diff --git a/setup.cfg b/setup.cfg index 78c5ea9c..f3777857 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,3 +11,4 @@ disallow_untyped_calls = True disallow_incomplete_defs = True disallow_untyped_decorators = True ignore_missing_imports = True +show_error_codes = True \ No newline at end of file diff --git a/setup.py b/setup.py index 8506ec49..7641c9eb 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,10 @@ def get_long_description(): packages=find_packages(exclude=["tests*"]), license="MIT", url="https://github.com/jordaneremieff/mangum", - description="AWS Lambda & API Gateway support for ASGI", + description="AWS Lambda support for ASGI applications", long_description=get_long_description(), python_requires=">=3.6", - install_requires=["typing_extensions", "dataclasses; python_version < '3.7'"], + install_requires=["typing_extensions"], package_data={"mangum": ["py.typed"]}, long_description_content_type="text/markdown", author="Jordan Eremieff", @@ -23,7 +23,6 @@ def get_long_description(): "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", diff --git a/tests/handlers/test_abstract_handler.py b/tests/handlers/test_abstract_handler.py deleted file mode 100644 index b38366e3..00000000 --- a/tests/handlers/test_abstract_handler.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from mangum.handlers.abstract_handler import AbstractHandler - - -def test_abstract_handler_unkown_event(): - """ - Test an unknown event, ensure it fails in a consistent way - """ - example_event = {"hello": "world", "foo": "bar"} - example_context = {} - with pytest.raises(TypeError): - AbstractHandler.from_trigger(example_event, example_context) diff --git a/tests/handlers/test_aws_alb.py b/tests/handlers/test_alb.py similarity index 98% rename from tests/handlers/test_aws_alb.py rename to tests/handlers/test_alb.py index 105ae81c..6213088d 100644 --- a/tests/handlers/test_aws_alb.py +++ b/tests/handlers/test_alb.py @@ -8,7 +8,7 @@ import pytest from mangum import Mangum -from mangum.handlers.aws_alb import AwsAlb +from mangum.handlers.alb import ALB def get_mock_aws_alb_event( @@ -199,15 +199,15 @@ def test_aws_alb_scope_real( multi_value_headers, ) example_context = {} - handler = AwsAlb(event, example_context) + handler = ALB(event, example_context, {"api_gateway_base_path": "/"}) scope_path = path if scope_path == "": scope_path = "/" assert type(handler.body) == bytes - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": event, "client": ("72.12.164.125", 0), diff --git a/tests/handlers/test_aws_api_gateway.py b/tests/handlers/test_api_gateway.py similarity index 96% rename from tests/handlers/test_aws_api_gateway.py rename to tests/handlers/test_api_gateway.py index d87b669b..4c7cf0a6 100644 --- a/tests/handlers/test_aws_api_gateway.py +++ b/tests/handlers/test_api_gateway.py @@ -3,7 +3,7 @@ import pytest from mangum import Mangum -from mangum.handlers.aws_api_gateway import AwsApiGateway +from mangum.handlers.api_gateway import APIGateway def get_mock_aws_api_gateway_event( @@ -103,11 +103,11 @@ def test_aws_api_gateway_scope_basic(): "isBase64Encoded": False, } example_context = {} - handler = AwsApiGateway(example_event, example_context, "/") + handler = APIGateway(example_event, example_context, {"api_gateway_base_path": "/"}) assert type(handler.body) == bytes - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": example_event, "client": (None, 0), @@ -200,14 +200,14 @@ def test_aws_api_gateway_scope_real( method, path, multi_value_query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsApiGateway(event, example_context, "/") + handler = APIGateway(event, example_context, {"api_gateway_base_path": "/"}) scope_path = path if scope_path == "": scope_path = "/" - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": event, "client": ("192.168.100.1", 0), diff --git a/tests/handlers/test_custom.py b/tests/handlers/test_custom.py new file mode 100644 index 00000000..24286ce3 --- /dev/null +++ b/tests/handlers/test_custom.py @@ -0,0 +1,77 @@ +from typing import Optional + +from mangum.types import ( + HTTPScope, + Headers, + LambdaConfig, + LambdaContext, + LambdaEvent, + LambdaHandler, +) + + +class CustomHandler: + @classmethod + def infer( + cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> Optional[LambdaHandler]: + if "my-custom-key" in event: + return cls(event, context, config) + + return None + + def __init__( + self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig + ) -> None: + self.event = event + self.context = context + self.config = config + + @property + def body(self) -> bytes: + return b"My request body" + + @property + def scope(self) -> HTTPScope: + headers = {} + return { + "type": "http", + "http_version": "1.1", + "method": "GET", + "headers": [[k.encode(), v.encode()] for k, v in headers.items()], + "path": "/", + "raw_path": None, + "root_path": "", + "scheme": "https", + "query_string": b"", + "server": ("mangum", 8080), + "client": ("127.0.0.1", 0), + "asgi": {"version": "3.0", "spec_version": "2.0"}, + "aws.event": self.event, + "aws.context": self.context, + } + + def __call__(self, *, status: int, headers: Headers, body: bytes) -> dict: + return {"statusCode": status, "headers": {}, "body": body.decode()} + + +def test_custom_handler(): + event = {"my-custom-key": 1} + handler = CustomHandler(event, {}, {"api_gateway_base_path": "/"}) + assert type(handler.body) == bytes + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, + "aws.context": {}, + "aws.event": event, + "client": ("127.0.0.1", 0), + "headers": [], + "http_version": "1.1", + "method": "GET", + "path": "/", + "query_string": b"", + "raw_path": None, + "root_path": "", + "scheme": "https", + "server": ("mangum", 8080), + "type": "http", + } diff --git a/tests/handlers/test_aws_http_gateway.py b/tests/handlers/test_http_gateway.py similarity index 93% rename from tests/handlers/test_aws_http_gateway.py rename to tests/handlers/test_http_gateway.py index 2475014d..df428351 100644 --- a/tests/handlers/test_aws_http_gateway.py +++ b/tests/handlers/test_http_gateway.py @@ -3,7 +3,7 @@ import pytest from mangum import Mangum -from mangum.handlers.aws_http_gateway import AwsHttpGateway +from mangum.handlers.api_gateway import HTTPGateway def get_mock_aws_http_gateway_event_v1( @@ -193,12 +193,15 @@ def test_aws_http_gateway_scope_basic_v1(): "body": "Hello from Lambda!", "isBase64Encoded": False, } + example_context = {} - handler = AwsHttpGateway(example_event, example_context, "/") + handler = HTTPGateway( + example_event, example_context, {"api_gateway_base_path": "/"} + ) assert type(handler.body) == bytes - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": example_event, "client": ("IP", 0), @@ -225,8 +228,10 @@ def test_aws_http_gateway_scope_v1_only_non_multi_headers(): ) del example_event["multiValueQueryStringParameters"] example_context = {} - handler = AwsHttpGateway(example_event, example_context, "/") - assert handler.request.scope["query_string"] == b"hello=world" + handler = HTTPGateway( + example_event, example_context, {"api_gateway_base_path": "/"} + ) + assert handler.scope["query_string"] == b"hello=world" def test_aws_http_gateway_scope_v1_no_headers(): @@ -239,8 +244,10 @@ def test_aws_http_gateway_scope_v1_no_headers(): del example_event["multiValueQueryStringParameters"] del example_event["queryStringParameters"] example_context = {} - handler = AwsHttpGateway(example_event, example_context, "/") - assert handler.request.scope["query_string"] == b"" + handler = HTTPGateway( + example_event, example_context, {"api_gateway_base_path": "/"} + ) + assert handler.scope["query_string"] == b"" def test_aws_http_gateway_scope_basic_v2(): @@ -297,11 +304,13 @@ def test_aws_http_gateway_scope_basic_v2(): "stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"}, } example_context = {} - handler = AwsHttpGateway(example_event, example_context, "/") + handler = HTTPGateway( + example_event, example_context, {"api_gateway_base_path": "/"} + ) assert type(handler.body) == bytes - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": example_event, "client": ("IP", 0), @@ -322,21 +331,6 @@ def test_aws_http_gateway_scope_basic_v2(): } -def test_aws_http_gateway_scope_bad_version(): - """ - Set a version we don't support - - Version is the only thing that is different here, we should be checking that - specifically - """ - example_event = get_mock_aws_http_gateway_event_v2("GET", "/test", {}, None, False) - example_event["version"] = "9001.1" - example_context = {} - handler = AwsHttpGateway(example_event, example_context, "/") - with pytest.raises(RuntimeError): - handler.request.scope - - @pytest.mark.parametrize( "method,path,query_parameters,req_body,body_base64_encoded,query_string,scope_body", [ @@ -369,14 +363,14 @@ def test_aws_http_gateway_scope_real_v1( method, path, query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsHttpGateway(event, example_context, "/") + handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"}) scope_path = path if scope_path == "": scope_path = "/" - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": event, "client": ("192.168.100.1", 0), @@ -435,14 +429,14 @@ def test_aws_http_gateway_scope_real_v2( method, path, query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsHttpGateway(event, example_context, "/") + handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"}) scope_path = path if scope_path == "": scope_path = "/" - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": event, "client": ("192.168.100.1", 0), diff --git a/tests/handlers/test_aws_cf_lambda_at_edge.py b/tests/handlers/test_lambda_at_edge.py similarity index 95% rename from tests/handlers/test_aws_cf_lambda_at_edge.py rename to tests/handlers/test_lambda_at_edge.py index f1910ed1..47a53f4e 100644 --- a/tests/handlers/test_aws_cf_lambda_at_edge.py +++ b/tests/handlers/test_lambda_at_edge.py @@ -3,7 +3,7 @@ import pytest from mangum import Mangum -from mangum.handlers.aws_cf_lambda_at_edge import AwsCfLambdaAtEdge +from mangum.handlers.lambda_at_edge import LambdaAtEdge def mock_lambda_at_edge_event( @@ -134,11 +134,13 @@ def test_aws_cf_lambda_at_edge_scope_basic(): ] } example_context = {} - handler = AwsCfLambdaAtEdge(example_event, example_context) + handler = LambdaAtEdge( + example_event, example_context, {"api_gateway_base_path": "/"} + ) assert type(handler.body) == bytes - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": example_event, "client": ("203.0.113.178", 0), @@ -223,10 +225,10 @@ def test_aws_api_gateway_scope_real( method, path, multi_value_query_parameters, req_body, body_base64_encoded ) example_context = {} - handler = AwsCfLambdaAtEdge(event, example_context) + handler = LambdaAtEdge(event, example_context, {"api_gateway_base_path": "/"}) - assert handler.request.scope == { - "asgi": {"version": "3.0"}, + assert handler.scope == { + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": event, "client": ("192.168.100.1", 0), diff --git a/tests/test_http.py b/tests/test_http.py index 825b630e..41798055 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -22,7 +22,7 @@ def test_http_response(mock_aws_api_gateway_event) -> None: async def app(scope, receive, send): assert scope == { - "asgi": {"version": "3.0"}, + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": { "body": None, @@ -274,7 +274,7 @@ async def app(scope, receive, send): def test_set_cookies_v2(mock_http_api_event_v2) -> None: async def app(scope, receive, send): assert scope == { - "asgi": {"version": "3.0"}, + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": { "version": "2.0", @@ -391,7 +391,7 @@ async def app(scope, receive, send): def test_set_cookies_v1(mock_http_api_event_v1) -> None: async def app(scope, receive, send): assert scope == { - "asgi": {"version": "3.0"}, + "asgi": {"version": "3.0", "spec_version": "2.0"}, "aws.context": {}, "aws.event": { "version": "1.0",