diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 19a1acdd..7d28d50b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -11,28 +11,28 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- - uses: actions/checkout@v2
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- - uses: actions/cache@v2
- name: Configure pip caching
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
- restore-keys: |
- ${{ runner.os }}-pip-
- - name: Install dependencies
- run: |
- pip install -U -r requirements.txt
- - name: Run tests
- run: |
- scripts/test
- - name: Run linters
- run: |
- scripts/lint
- - name: Run codecov
- run: codecov
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/cache@v2
+ name: Configure pip caching
+ with:
+ path: ~/.cache/pip
+ key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
+ restore-keys: |
+ ${{ runner.os }}-pip-
+ - name: Install dependencies
+ run: |
+ pip install -U -r requirements.txt
+ - name: Run tests
+ run: |
+ scripts/test
+ - name: Run linters
+ run: |
+ scripts/lint
+ - name: Run codecov
+ run: codecov
diff --git a/README.md b/README.md
index e871acfb..af6ba7a5 100644
--- a/README.md
+++ b/README.md
@@ -8,13 +8,13 @@
-Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) applications with AWS Lambda & API Gateway. It is intended to provide an easy-to-use, configurable wrapper for any ASGI application deployed in an AWS Lambda function to handle API Gateway requests and responses.
+Mangum is an adapter for running [ASGI](https://asgi.readthedocs.io/en/latest/) applications in AWS Lambda to handle API Gateway, ALB, and Lambda@Edge events.
***Documentation***: https://mangum.io/
## Features
-- API Gateway support for [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs.
+- Event handlers for API Gateway [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs, [Application Load Balancer](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html), and [CloudFront Lambda@Edge](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html).
- Compatibility with ASGI application frameworks, such as [Starlette](https://www.starlette.io/), [FastAPI](https://fastapi.tiangolo.com/), and [Quart](https://pgjones.gitlab.io/quart/).
@@ -26,7 +26,7 @@ Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) ap
## Requirements
-Python 3.6+
+Python 3.7+
## Installation
@@ -53,7 +53,7 @@ async def app(scope, receive, send):
handler = Mangum(app)
```
-Or using a framework.
+Or using a framework:
```python
from fastapi import FastAPI
diff --git a/docs/asgi-frameworks.md b/docs/asgi-frameworks.md
index bc8385af..f24509ea 100644
--- a/docs/asgi-frameworks.md
+++ b/docs/asgi-frameworks.md
@@ -30,7 +30,7 @@ None of the framework details are important here. The routing decorator, request
```python
class Application(Protocol):
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ async def __call__(self, scope: Scope, receive: ASGIReceive, send: ASGISend) -> None:
...
```
diff --git a/docs/index.md b/docs/index.md
index e871acfb..af6ba7a5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -8,13 +8,13 @@
-Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) applications with AWS Lambda & API Gateway. It is intended to provide an easy-to-use, configurable wrapper for any ASGI application deployed in an AWS Lambda function to handle API Gateway requests and responses.
+Mangum is an adapter for running [ASGI](https://asgi.readthedocs.io/en/latest/) applications in AWS Lambda to handle API Gateway, ALB, and Lambda@Edge events.
***Documentation***: https://mangum.io/
## Features
-- API Gateway support for [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs.
+- Event handlers for API Gateway [HTTP](https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api.html) and [REST](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-rest-api.html) APIs, [Application Load Balancer](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html), and [CloudFront Lambda@Edge](https://docs.aws.amazon.com/lambda/latest/dg/lambda-edge.html).
- Compatibility with ASGI application frameworks, such as [Starlette](https://www.starlette.io/), [FastAPI](https://fastapi.tiangolo.com/), and [Quart](https://pgjones.gitlab.io/quart/).
@@ -26,7 +26,7 @@ Mangum is an adapter for using [ASGI](https://asgi.readthedocs.io/en/latest/) ap
## Requirements
-Python 3.6+
+Python 3.7+
## Installation
@@ -53,7 +53,7 @@ async def app(scope, receive, send):
handler = Mangum(app)
```
-Or using a framework.
+Or using a framework:
```python
from fastapi import FastAPI
diff --git a/mangum/__init__.py b/mangum/__init__.py
index a81cdfb0..fa5058b8 100644
--- a/mangum/__init__.py
+++ b/mangum/__init__.py
@@ -1,4 +1,3 @@
-from .types import Request, Response
-from .adapter import Mangum # noqa: F401
+from mangum.adapter import Mangum
-__all__ = ["Mangum", "Request", "Response"]
+__all__ = ["Mangum"]
diff --git a/mangum/adapter.py b/mangum/adapter.py
index de1f9c6d..31a0b468 100644
--- a/mangum/adapter.py
+++ b/mangum/adapter.py
@@ -1,66 +1,92 @@
+from itertools import chain
import logging
from contextlib import ExitStack
+from typing import List, Optional, Type
+import warnings
-
-from mangum.exceptions import ConfigurationError
-from mangum.handlers.abstract_handler import AbstractHandler
from mangum.protocols import HTTPCycle, LifespanCycle
-from mangum.types import ASGIApp, LambdaEvent, LambdaContext
+from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge
+from mangum.exceptions import ConfigurationError
+from mangum.types import (
+ ASGIApp,
+ LifespanMode,
+ LambdaConfig,
+ LambdaEvent,
+ LambdaContext,
+ LambdaHandler,
+)
-DEFAULT_TEXT_MIME_TYPES = [
- "text/",
- "application/json",
- "application/javascript",
- "application/xml",
- "application/vnd.api+json",
-]
+logger = logging.getLogger("mangum")
-logger = logging.getLogger("mangum")
+HANDLERS: List[Type[LambdaHandler]] = [
+ ALB,
+ HTTPGateway,
+ APIGateway,
+ LambdaAtEdge,
+]
class Mangum:
- """
- Creates an adapter instance.
-
- * **app** - An asynchronous callable that conforms to version 3.0 of the ASGI
- specification. This will usually be an ASGI framework application instance.
- * **lifespan** - A string to configure lifespan support. Choices are `auto`, `on`,
- and `off`. Default is `auto`.
- * **text_mime_types** - A list of MIME types to include with the defaults that
- should not return a binary response in API Gateway.
- * **api_gateway_base_path** - A string specifying the part of the url path after
- which the server routing begins.
- """
-
def __init__(
self,
app: ASGIApp,
- lifespan: str = "auto",
+ lifespan: LifespanMode = "auto",
api_gateway_base_path: str = "/",
+ custom_handlers: Optional[List[Type[LambdaHandler]]] = None,
) -> None:
+ if lifespan not in ("auto", "on", "off"):
+ raise ConfigurationError(
+ "Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
+ )
+
self.app = app
self.lifespan = lifespan
- self.api_gateway_base_path = api_gateway_base_path
+ self.api_gateway_base_path = api_gateway_base_path or "/"
+ self.config = LambdaConfig(api_gateway_base_path=self.api_gateway_base_path)
- if self.lifespan not in ("auto", "on", "off"):
- raise ConfigurationError(
- "Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
+ if custom_handlers is not None:
+ warnings.warn( # pragma: no cover
+ "Support for custom event handlers is currently provisional and may "
+ "drastically change (or be removed entirely) in the future.",
+ FutureWarning,
)
- def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
- logger.debug("Event received.")
+ self.custom_handlers = custom_handlers or []
+
+ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
+ for handler_cls in chain(
+ self.custom_handlers,
+ HANDLERS,
+ ):
+ handler = handler_cls.infer(
+ event,
+ context,
+ self.config,
+ )
+ if handler:
+ break
+ else:
+ raise RuntimeError( # pragma: no cover
+ "The adapter was unable to infer a handler to use for the event. This "
+ "is likely related to how the Lambda function was invoked. (Are you "
+ "testing locally? Make sure the request payload is valid for a "
+ "supported handler.)"
+ )
+
+ return handler
+ def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
+ handler = self.infer(event, context)
with ExitStack() as stack:
- if self.lifespan != "off":
+ if self.lifespan in ("auto", "on"):
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)
- handler = AbstractHandler.from_trigger(
- event, context, self.api_gateway_base_path
- )
- http_cycle = HTTPCycle(handler.request)
- response = http_cycle(self.app, handler.body)
+ http_cycle = HTTPCycle(handler.scope, handler.body)
+ http_response = http_cycle(self.app)
+
+ return handler(http_response)
- return handler.transform_response(response)
+ assert False, "unreachable" # pragma: no cover
diff --git a/mangum/handlers/__init__.py b/mangum/handlers/__init__.py
index e69de29b..d92f8641 100644
--- a/mangum/handlers/__init__.py
+++ b/mangum/handlers/__init__.py
@@ -0,0 +1,6 @@
+from mangum.handlers.api_gateway import APIGateway, HTTPGateway
+from mangum.handlers.alb import ALB
+from mangum.handlers.lambda_at_edge import LambdaAtEdge
+
+
+__all__ = ["APIGateway", "HTTPGateway", "ALB", "LambdaAtEdge"]
diff --git a/mangum/handlers/abstract_handler.py b/mangum/handlers/abstract_handler.py
deleted file mode 100644
index 666f633a..00000000
--- a/mangum/handlers/abstract_handler.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import base64
-from abc import ABCMeta, abstractmethod
-from typing import Dict, Any, Tuple, List
-
-from mangum.types import Response, Request, LambdaEvent, LambdaContext
-
-
-class AbstractHandler(metaclass=ABCMeta):
- def __init__(
- self,
- trigger_event: LambdaEvent,
- trigger_context: LambdaContext,
- ):
- self.trigger_event = trigger_event
- self.trigger_context = trigger_context
-
- @property
- @abstractmethod
- def request(self) -> Request:
- """
- Parse an ASGI scope from the request event
- """
-
- @property
- @abstractmethod
- def body(self) -> bytes:
- """
- Get the raw body from the request event
- """
-
- @abstractmethod
- def transform_response(self, response: Response) -> Dict[str, Any]:
- """
- After running our application, transform the response to the correct format for
- this handler
- """
-
- @staticmethod
- def from_trigger(
- trigger_event: LambdaEvent,
- trigger_context: LambdaContext,
- api_gateway_base_path: str = "/",
- ) -> "AbstractHandler":
- """
- A factory method that determines which handler to use. All this code should
- probably stay in one place to make sure we are able to uniquely find each
- handler correctly.
- """
-
- # These should be ordered from most specific to least for best accuracy
- if (
- "requestContext" in trigger_event
- and "elb" in trigger_event["requestContext"]
- ):
- from mangum.handlers.aws_alb import AwsAlb
-
- return AwsAlb(trigger_event, trigger_context)
- if (
- "Records" in trigger_event
- and len(trigger_event["Records"]) > 0
- and "cf" in trigger_event["Records"][0]
- ):
- from mangum.handlers.aws_cf_lambda_at_edge import AwsCfLambdaAtEdge
-
- return AwsCfLambdaAtEdge(trigger_event, trigger_context)
-
- if "version" in trigger_event and "requestContext" in trigger_event:
- from mangum.handlers.aws_http_gateway import AwsHttpGateway
-
- return AwsHttpGateway(
- trigger_event,
- trigger_context,
- api_gateway_base_path,
- )
-
- if "resource" in trigger_event:
- from mangum.handlers.aws_api_gateway import AwsApiGateway
-
- return AwsApiGateway(
- trigger_event,
- trigger_context,
- api_gateway_base_path,
- )
-
- raise TypeError("Unable to determine handler from trigger event")
-
- @staticmethod
- def _handle_multi_value_headers(
- response_headers: List[List[bytes]],
- ) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
- headers: Dict[str, str] = {}
- multi_value_headers: Dict[str, List[str]] = {}
- for key, value in response_headers:
- lower_key = key.decode().lower()
- if lower_key in multi_value_headers:
- multi_value_headers[lower_key].append(value.decode())
- elif lower_key in headers:
- # Move existing to multi_value_headers and append current
- multi_value_headers[lower_key] = [
- headers[lower_key],
- value.decode(),
- ]
- del headers[lower_key]
- else:
- headers[lower_key] = value.decode()
- return headers, multi_value_headers
-
- @staticmethod
- def _handle_base64_response_body(
- body: bytes, headers: Dict[str, str]
- ) -> Tuple[str, bool]:
- """
- To ease debugging for our users, try and return strings where we can,
- otherwise to ensure maximum compatibility with binary data, base64 encode it.
- """
- is_base64_encoded = False
- output_body = ""
- if body != b"":
- from ..adapter import DEFAULT_TEXT_MIME_TYPES
-
- for text_mime_type in DEFAULT_TEXT_MIME_TYPES:
- if text_mime_type in headers.get("content-type", ""):
- try:
- output_body = body.decode()
- except UnicodeDecodeError:
- # Can't decode it, base64 it and be done
- output_body = base64.b64encode(body).decode()
- is_base64_encoded = True
- break
- else:
- # Not text, base64 encode
- output_body = base64.b64encode(body).decode()
- is_base64_encoded = True
-
- return output_body, is_base64_encoded
diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py
new file mode 100644
index 00000000..23dc62b9
--- /dev/null
+++ b/mangum/handlers/alb.py
@@ -0,0 +1,178 @@
+from itertools import islice
+from typing import Dict, Generator, List, Optional, Tuple
+from urllib.parse import urlencode, unquote, unquote_plus
+
+
+from mangum.handlers.utils import (
+ get_server_and_port,
+ handle_base64_response_body,
+ maybe_encode_body,
+)
+from mangum.types import (
+ HTTPResponse,
+ HTTPScope,
+ LambdaConfig,
+ LambdaEvent,
+ LambdaContext,
+ LambdaHandler,
+ QueryParams,
+)
+
+
+def all_casings(input_string: str) -> Generator[str, None, None]:
+ """
+ Permute all casings of a given string.
+ A pretty algoritm, via @Amber
+ http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
+ """
+ if not input_string:
+ yield ""
+ else:
+ first = input_string[:1]
+ if first.lower() == first.upper():
+ for sub_casing in all_casings(input_string[1:]):
+ yield first + sub_casing
+ else:
+ for sub_casing in all_casings(input_string[1:]):
+ yield first.lower() + sub_casing
+ yield first.upper() + sub_casing
+
+
+def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]:
+ """Create str/str key/value headers, with duplicate keys case mutated."""
+ headers: Dict[str, str] = {}
+ for key, values in multi_value_headers.items():
+ if len(values) > 0:
+ casings = list(islice(all_casings(key), len(values)))
+ for value, cased_key in zip(values, casings):
+ headers[cased_key] = value
+ return headers
+
+
+def encode_query_string_for_alb(params: QueryParams) -> bytes:
+ """Encode the query string parameters for the ALB event. The parameters must be
+ decoded and then encoded again to prevent double encoding.
+
+ According to the docs:
+
+ "If the query parameters are URL-encoded, the load balancer does not decode
+ "them. You must decode them in your Lambda function."
+ """
+ params = {
+ unquote_plus(key): unquote_plus(value)
+ if isinstance(value, str)
+ else tuple(unquote_plus(element) for element in value)
+ for key, value in params.items()
+ }
+ query_string = urlencode(params, doseq=True).encode()
+
+ return query_string
+
+
+def transform_headers(event: LambdaEvent) -> List[Tuple[bytes, bytes]]:
+ headers: List[Tuple[bytes, bytes]] = []
+ if "multiValueHeaders" in event:
+ for k, v in event["multiValueHeaders"].items():
+ for inner_v in v:
+ headers.append((k.lower().encode(), inner_v.encode()))
+ else:
+ for k, v in event["headers"].items():
+ headers.append((k.lower().encode(), v.encode()))
+
+ return headers
+
+
+class ALB:
+ @classmethod
+ def infer(
+ cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> Optional[LambdaHandler]:
+ if "requestContext" in event and "elb" in event["requestContext"]:
+ return cls(event, context, config)
+
+ return None
+
+ def __init__(
+ self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> None:
+ self.event = event
+ self.context = context
+ self.config = config
+
+ @property
+ def body(self) -> bytes:
+ return maybe_encode_body(
+ self.event.get("body", b""),
+ is_base64=self.event.get("isBase64Encoded", False),
+ )
+
+ @property
+ def scope(self) -> HTTPScope:
+
+ headers = transform_headers(self.event)
+ list_headers = [list(x) for x in headers]
+ # Unique headers. If there are duplicates, it will use the last defined.
+ uq_headers = {k.decode(): v.decode() for k, v in headers}
+ source_ip = uq_headers.get("x-forwarded-for", "")
+ path = unquote(self.event["path"]) if self.event["path"] else "/"
+ http_method = self.event["httpMethod"]
+
+ params = self.event.get(
+ "multiValueQueryStringParameters",
+ self.event.get("queryStringParameters", {}),
+ )
+ if not params:
+ query_string = b""
+ else:
+ query_string = encode_query_string_for_alb(params)
+
+ server = get_server_and_port(uq_headers)
+ client = (source_ip, 0)
+
+ scope: HTTPScope = {
+ "type": "http",
+ "method": http_method,
+ "http_version": "1.1",
+ "headers": list_headers,
+ "path": path,
+ "raw_path": None,
+ "root_path": "",
+ "scheme": uq_headers.get("x-forwarded-proto", "https"),
+ "query_string": query_string,
+ "server": server,
+ "client": client,
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
+ "aws.event": self.event,
+ "aws.context": self.context,
+ }
+
+ return scope
+
+ def __call__(self, response: HTTPResponse) -> dict:
+ multi_value_headers: Dict[str, List[str]] = {}
+ for key, value in response["headers"]:
+ lower_key = key.decode().lower()
+ if lower_key not in multi_value_headers:
+ multi_value_headers[lower_key] = []
+ multi_value_headers[lower_key].append(value.decode())
+
+ finalized_headers = case_mutated_headers(multi_value_headers)
+ finalized_body, is_base64_encoded = handle_base64_response_body(
+ response["body"], finalized_headers
+ )
+
+ out = {
+ "statusCode": response["status"],
+ "body": finalized_body,
+ "isBase64Encoded": is_base64_encoded,
+ }
+
+ # You must use multiValueHeaders if you have enabled multi-value headers and
+ # headers otherwise.
+ multi_value_headers_enabled = "multiValueHeaders" in self.scope["aws.event"]
+ if multi_value_headers_enabled:
+ out["multiValueHeaders"] = multi_value_headers
+ else:
+ out["headers"] = finalized_headers
+
+ return out
diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py
new file mode 100644
index 00000000..682642e2
--- /dev/null
+++ b/mangum/handlers/api_gateway.py
@@ -0,0 +1,239 @@
+from typing import Dict, List, Optional, Tuple
+from urllib.parse import urlencode
+
+from mangum.handlers.utils import (
+ get_server_and_port,
+ handle_base64_response_body,
+ handle_multi_value_headers,
+ maybe_encode_body,
+ strip_api_gateway_path,
+)
+from mangum.types import (
+ HTTPResponse,
+ LambdaConfig,
+ Headers,
+ LambdaEvent,
+ LambdaContext,
+ LambdaHandler,
+ QueryParams,
+ HTTPScope,
+)
+
+
+def _encode_query_string_for_apigw(event: LambdaEvent) -> bytes:
+ params: QueryParams = event.get("multiValueQueryStringParameters", {})
+ if not params:
+ params = event.get("queryStringParameters", {})
+ if not params:
+ return b""
+
+ return urlencode(params, doseq=True).encode()
+
+
+def _handle_multi_value_headers_for_request(event: LambdaEvent) -> Dict[str, str]:
+ headers = event.get("headers", {}) or {}
+ headers = {k.lower(): v for k, v in headers.items()}
+ if event.get("multiValueHeaders"):
+ headers.update(
+ {
+ k.lower(): ", ".join(v) if isinstance(v, list) else ""
+ for k, v in event.get("multiValueHeaders", {}).items()
+ }
+ )
+
+ return headers
+
+
+def _combine_headers_v2(
+ input_headers: Headers,
+) -> Tuple[Dict[str, str], List[str]]:
+ output_headers: Dict[str, str] = {}
+ cookies: List[str] = []
+ for key, value in input_headers:
+ normalized_key: str = key.decode().lower()
+ normalized_value: str = value.decode()
+ if normalized_key == "set-cookie":
+ cookies.append(normalized_value)
+ else:
+ if normalized_key in output_headers:
+ normalized_value = (
+ f"{output_headers[normalized_key]},{normalized_value}"
+ )
+ output_headers[normalized_key] = normalized_value
+
+ return output_headers, cookies
+
+
+class APIGateway:
+ @classmethod
+ def infer(
+ cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> Optional[LambdaHandler]:
+ if "resource" in event and "requestContext" in event:
+ return cls(event, context, config)
+
+ return None
+
+ def __init__(
+ self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> None:
+ self.event = event
+ self.context = context
+ self.config = config
+
+ @property
+ def body(self) -> bytes:
+ return maybe_encode_body(
+ self.event.get("body", b""),
+ is_base64=self.event.get("isBase64Encoded", False),
+ )
+
+ @property
+ def scope(self) -> HTTPScope:
+ headers = _handle_multi_value_headers_for_request(self.event)
+ return {
+ "type": "http",
+ "http_version": "1.1",
+ "method": self.event["httpMethod"],
+ "headers": [[k.encode(), v.encode()] for k, v in headers.items()],
+ "path": strip_api_gateway_path(
+ self.event["path"],
+ api_gateway_base_path=self.config["api_gateway_base_path"],
+ ),
+ "raw_path": None,
+ "root_path": "",
+ "scheme": headers.get("x-forwarded-proto", "https"),
+ "query_string": _encode_query_string_for_apigw(self.event),
+ "server": get_server_and_port(headers),
+ "client": (
+ self.event["requestContext"].get("identity", {}).get("sourceIp"),
+ 0,
+ ),
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
+ "aws.event": self.event,
+ "aws.context": self.context,
+ }
+
+ def __call__(self, response: HTTPResponse) -> dict:
+ finalized_headers, multi_value_headers = handle_multi_value_headers(
+ response["headers"]
+ )
+ finalized_body, is_base64_encoded = handle_base64_response_body(
+ response["body"], finalized_headers
+ )
+
+ return {
+ "statusCode": response["status"],
+ "headers": finalized_headers,
+ "multiValueHeaders": multi_value_headers,
+ "body": finalized_body,
+ "isBase64Encoded": is_base64_encoded,
+ }
+
+
+class HTTPGateway:
+ @classmethod
+ def infer(
+ cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> Optional[LambdaHandler]:
+ if "version" in event and "requestContext" in event:
+ return cls(event, context, config)
+
+ return None
+
+ def __init__(
+ self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> None:
+ self.event = event
+ self.context = context
+ self.config = config
+
+ @property
+ def body(self) -> bytes:
+ return maybe_encode_body(
+ self.event.get("body", b""),
+ is_base64=self.event.get("isBase64Encoded", False),
+ )
+
+ @property
+ def scope(self) -> HTTPScope:
+ request_context = self.event["requestContext"]
+ event_version = self.event["version"]
+
+ # API Gateway v2
+ if event_version == "2.0":
+ headers = {k.lower(): v for k, v in self.event.get("headers", {}).items()}
+ source_ip = request_context["http"]["sourceIp"]
+ path = request_context["http"]["path"]
+ http_method = request_context["http"]["method"]
+ query_string = self.event.get("rawQueryString", "").encode()
+
+ if self.event.get("cookies"):
+ headers["cookie"] = "; ".join(self.event.get("cookies", []))
+
+ # API Gateway v1
+ else:
+ headers = _handle_multi_value_headers_for_request(self.event)
+ source_ip = request_context.get("identity", {}).get("sourceIp")
+ path = self.event["path"]
+ http_method = self.event["httpMethod"]
+ query_string = _encode_query_string_for_apigw(self.event)
+
+ path = strip_api_gateway_path(
+ path,
+ api_gateway_base_path=self.config["api_gateway_base_path"],
+ )
+ server = get_server_and_port(headers)
+ client = (source_ip, 0)
+
+ return {
+ "type": "http",
+ "method": http_method,
+ "http_version": "1.1",
+ "headers": [[k.encode(), v.encode()] for k, v in headers.items()],
+ "path": path,
+ "raw_path": None,
+ "root_path": "",
+ "scheme": headers.get("x-forwarded-proto", "https"),
+ "query_string": query_string,
+ "server": server,
+ "client": client,
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
+ "aws.event": self.event,
+ "aws.context": self.context,
+ }
+
+ def __call__(self, response: HTTPResponse) -> dict:
+ if self.scope["aws.event"]["version"] == "2.0":
+ finalized_headers, cookies = _combine_headers_v2(response["headers"])
+
+ if "content-type" not in finalized_headers and response["body"] is not None:
+ finalized_headers["content-type"] = "application/json"
+
+ finalized_body, is_base64_encoded = handle_base64_response_body(
+ response["body"], finalized_headers
+ )
+ response_out = {
+ "statusCode": response["status"],
+ "body": finalized_body,
+ "headers": finalized_headers or None,
+ "cookies": cookies or None,
+ "isBase64Encoded": is_base64_encoded,
+ }
+ return {
+ key: value for key, value in response_out.items() if value is not None
+ }
+
+ finalized_headers, multi_value_headers = handle_multi_value_headers(
+ response["headers"]
+ )
+ finalized_body, is_base64_encoded = handle_base64_response_body(
+ response["body"], finalized_headers
+ )
+ return {
+ "statusCode": response["status"],
+ "headers": finalized_headers,
+ "multiValueHeaders": multi_value_headers,
+ "body": finalized_body,
+ "isBase64Encoded": is_base64_encoded,
+ }
diff --git a/mangum/handlers/aws_alb.py b/mangum/handlers/aws_alb.py
deleted file mode 100644
index c93f5bde..00000000
--- a/mangum/handlers/aws_alb.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import base64
-from urllib.parse import urlencode, unquote, unquote_plus
-from typing import Any, Dict, Generator, List, Tuple
-from itertools import islice
-
-from mangum.types import Response, Request, QueryParams
-from mangum.handlers.abstract_handler import AbstractHandler
-
-
-def all_casings(input_string: str) -> Generator[str, None, None]:
- """
- Permute all casings of a given string.
- A pretty algoritm, via @Amber
- http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
- """
- if not input_string:
- yield ""
- else:
- first = input_string[:1]
- if first.lower() == first.upper():
- for sub_casing in all_casings(input_string[1:]):
- yield first + sub_casing
- else:
- for sub_casing in all_casings(input_string[1:]):
- yield first.lower() + sub_casing
- yield first.upper() + sub_casing
-
-
-def case_mutated_headers(multi_value_headers: Dict[str, List[str]]) -> Dict[str, str]:
- """Create str/str key/value headers, with duplicate keys case mutated."""
- headers: Dict[str, str] = {}
- for key, values in multi_value_headers.items():
- if len(values) > 0:
- casings = list(islice(all_casings(key), len(values)))
- for value, cased_key in zip(values, casings):
- headers[cased_key] = value
- return headers
-
-
-class AwsAlb(AbstractHandler):
- def _encode_query_string(self) -> bytes:
- """
- Encodes the queryStringParameters.
- The parameters must be decoded, and then encoded again to prevent double
- encoding.
-
- https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
- "If the query parameters are URL-encoded, the load balancer does not decode
- them. You must decode them in your Lambda function."
-
- Issue: https://github.com/jordaneremieff/mangum/issues/178
- """
-
- params: QueryParams = self.trigger_event.get(
- "multiValueQueryStringParameters", {}
- )
- if not params:
- params = self.trigger_event.get("queryStringParameters", {})
- if not params:
- return b""
- params = {
- unquote_plus(key): unquote_plus(value)
- if isinstance(value, str)
- else tuple(unquote_plus(element) for element in value)
- for key, value in params.items()
- }
- return urlencode(params, doseq=True).encode()
-
- def transform_headers(self) -> List[Tuple[bytes, bytes]]:
- """Convert headers to a list of two-tuples per ASGI spec.
-
- Only one of `multiValueHeaders` or `headers` should be defined in the
- trigger event. However, we act as though they both might exist and pull
- headers out of both.
- """
- headers: List[Tuple[bytes, bytes]] = []
- if "multiValueHeaders" in self.trigger_event:
- for k, v in self.trigger_event["multiValueHeaders"].items():
- for inner_v in v:
- headers.append((k.lower().encode(), inner_v.encode()))
- else:
- for k, v in self.trigger_event["headers"].items():
- headers.append((k.lower().encode(), v.encode()))
- return headers
-
- @property
- def request(self) -> Request:
- event = self.trigger_event
-
- headers = self.transform_headers()
- list_headers = [list(x) for x in headers]
- # Unique headers. If there are duplicates, it will use the last defined.
- uq_headers = {k.decode(): v.decode() for k, v in headers}
-
- source_ip = uq_headers.get("x-forwarded-for", "")
- path = unquote(event["path"]) if event["path"] else "/"
- http_method = event["httpMethod"]
- query_string = self._encode_query_string()
-
- server_name = uq_headers.get("host", "mangum")
- if ":" not in server_name:
- server_port = uq_headers.get("x-forwarded-port", 80)
- else:
- server_name, server_port = server_name.split(":") # pragma: no cover
- server = (server_name, int(server_port))
- client = (source_ip, 0)
-
- return Request(
- method=http_method,
- headers=list_headers,
- path=path,
- scheme=uq_headers.get("x-forwarded-proto", "https"),
- query_string=query_string,
- server=server,
- client=client,
- trigger_event=self.trigger_event,
- trigger_context=self.trigger_context,
- )
-
- @property
- def body(self) -> bytes:
- body = self.trigger_event.get("body", b"") or b""
-
- if self.trigger_event.get("isBase64Encoded", False):
- return base64.b64decode(body)
- if not isinstance(body, bytes):
- body = body.encode()
-
- return body
-
- def transform_response(self, response: Response) -> Dict[str, Any]:
-
- multi_value_headers: Dict[str, List[str]] = {}
- for key, value in response.headers:
- lower_key = key.decode().lower()
- if lower_key not in multi_value_headers:
- multi_value_headers[lower_key] = []
- multi_value_headers[lower_key].append(value.decode())
-
- headers = case_mutated_headers(multi_value_headers)
-
- body, is_base64_encoded = self._handle_base64_response_body(
- response.body, headers
- )
-
- out = {
- "statusCode": response.status,
- "body": body,
- "isBase64Encoded": is_base64_encoded,
- }
-
- # "You must use multiValueHeaders if you have enabled multi-value headers
- # and headers otherwise"
- # https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html
- multi_value_headers_enabled = "multiValueHeaders" in self.trigger_event
- if multi_value_headers_enabled:
- out["multiValueHeaders"] = multi_value_headers
- else:
- out["headers"] = headers
-
- return out
diff --git a/mangum/handlers/aws_api_gateway.py b/mangum/handlers/aws_api_gateway.py
deleted file mode 100644
index c97b684b..00000000
--- a/mangum/handlers/aws_api_gateway.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import base64
-from urllib.parse import urlencode, unquote
-from typing import Dict, Any
-
-from mangum.handlers.abstract_handler import AbstractHandler
-from mangum.types import Response, Request, LambdaEvent, LambdaContext, QueryParams
-
-
-class AwsApiGateway(AbstractHandler):
- def __init__(
- self,
- trigger_event: LambdaEvent,
- trigger_context: LambdaContext,
- api_gateway_base_path: str,
- ):
- super().__init__(trigger_event, trigger_context)
- self.api_gateway_base_path = api_gateway_base_path
-
- @property
- def request(self) -> Request:
- event = self.trigger_event
-
- # See this for more info on headers:
- # https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#apigateway-multivalue-headers-and-parameters
- headers = {}
- # Read headers
- if event.get("headers"):
- headers.update({k.lower(): v for k, v in event.get("headers", {}).items()})
- # Read multiValueHeaders
- # This overrides headers that have the same name
- # That means that multiValue versions of headers take precedence
- # over their plain versions
- if event.get("multiValueHeaders"):
- headers.update(
- {
- k.lower(): ", ".join(v) if isinstance(v, list) else ""
- for k, v in event.get("multiValueHeaders", {}).items()
- }
- )
-
- request_context = event["requestContext"]
-
- source_ip = request_context.get("identity", {}).get("sourceIp")
- path = unquote(self._strip_base_path(event["path"])) if event["path"] else "/"
- http_method = event["httpMethod"]
- query_string = self._encode_query_string()
-
- server_name = headers.get("host", "mangum")
- if ":" not in server_name:
- server_port = headers.get("x-forwarded-port", 80)
- else:
- server_name, server_port = server_name.split(":") # pragma: no cover
- server = (server_name, int(server_port))
- client = (source_ip, 0)
-
- return Request(
- method=http_method,
- headers=[[k.encode(), v.encode()] for k, v in headers.items()],
- path=path,
- scheme=headers.get("x-forwarded-proto", "https"),
- query_string=query_string,
- server=server,
- client=client,
- trigger_event=self.trigger_event,
- trigger_context=self.trigger_context,
- )
-
- def _encode_query_string(self) -> bytes:
- """
- Encodes the queryStringParameters.
- """
-
- params: QueryParams = self.trigger_event.get(
- "multiValueQueryStringParameters", {}
- )
- if not params:
- params = self.trigger_event.get("queryStringParameters", {})
- if not params:
- return b""
- return urlencode(params, doseq=True).encode()
-
- def _strip_base_path(self, path: str) -> str:
- if self.api_gateway_base_path and self.api_gateway_base_path != "/":
- if not self.api_gateway_base_path.startswith("/"):
- self.api_gateway_base_path = f"/{self.api_gateway_base_path}"
- if path.startswith(self.api_gateway_base_path):
- path = path[len(self.api_gateway_base_path) :]
- return path
-
- @property
- def body(self) -> bytes:
- body = self.trigger_event.get("body", b"") or b""
-
- if self.trigger_event.get("isBase64Encoded", False):
- return base64.b64decode(body)
- if not isinstance(body, bytes):
- body = body.encode()
-
- return body
-
- def transform_response(self, response: Response) -> Dict[str, Any]:
- headers, multi_value_headers = self._handle_multi_value_headers(
- response.headers
- )
-
- body, is_base64_encoded = self._handle_base64_response_body(
- response.body, headers
- )
-
- return {
- "statusCode": response.status,
- "headers": headers,
- "multiValueHeaders": multi_value_headers,
- "body": body,
- "isBase64Encoded": is_base64_encoded,
- }
diff --git a/mangum/handlers/aws_cf_lambda_at_edge.py b/mangum/handlers/aws_cf_lambda_at_edge.py
deleted file mode 100644
index def851ef..00000000
--- a/mangum/handlers/aws_cf_lambda_at_edge.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import base64
-from typing import Dict, Any, List
-
-from mangum.handlers.abstract_handler import AbstractHandler
-from mangum.types import Response, Request
-
-
-class AwsCfLambdaAtEdge(AbstractHandler):
- @property
- def request(self) -> Request:
- event = self.trigger_event
-
- cf_request = event["Records"][0]["cf"]["request"]
-
- scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}])
- scheme = scheme_header[0].get("value", "https")
-
- host_header = cf_request["headers"].get("host", [{}])
- server_name = host_header[0].get("value", "mangum")
- if ":" not in server_name:
- forwarded_port_header = cf_request["headers"].get("x-forwarded-port", [{}])
- server_port = forwarded_port_header[0].get("value", 80)
- else:
- server_name, server_port = server_name.split(":") # pragma: no cover
- server = (server_name, int(server_port))
-
- source_ip = cf_request["clientIp"]
- client = (source_ip, 0)
-
- return Request(
- method=cf_request["method"],
- headers=[
- [k.encode(), v[0]["value"].encode()]
- for k, v in cf_request["headers"].items()
- ],
- path=cf_request["uri"],
- scheme=scheme,
- query_string=cf_request["querystring"].encode(),
- server=server,
- client=client,
- trigger_event=self.trigger_event,
- trigger_context=self.trigger_context,
- )
-
- @property
- def body(self) -> bytes:
- request = self.trigger_event["Records"][0]["cf"]["request"]
- body = request.get("body", {}).get("data", None) or b""
-
- if request.get("body", {}).get("encoding", "") == "base64":
- return base64.b64decode(body)
- if not isinstance(body, bytes):
- body = body.encode()
-
- return body
-
- def transform_response(self, response: Response) -> Dict[str, Any]:
- headers_dict, _ = self._handle_multi_value_headers(response.headers)
- body, is_base64_encoded = self._handle_base64_response_body(
- response.body, headers_dict
- )
-
- # Expand headers to weird list of Dict[str, List[Dict[str, str]]]
- headers_expanded: Dict[str, List[Dict[str, str]]] = {
- key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}]
- for key, val in response.headers
- }
- return {
- "status": response.status,
- "headers": headers_expanded,
- "body": body,
- "isBase64Encoded": is_base64_encoded,
- }
diff --git a/mangum/handlers/aws_http_gateway.py b/mangum/handlers/aws_http_gateway.py
deleted file mode 100644
index feed1efb..00000000
--- a/mangum/handlers/aws_http_gateway.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import base64
-import urllib.parse
-from typing import Dict, Any, List, Tuple
-
-from mangum.handlers.aws_api_gateway import AwsApiGateway
-from mangum.types import Response, Request
-
-
-class AwsHttpGateway(AwsApiGateway):
- """
- Handles AWS HTTP Gateway events (v1.0 and v2.0), transforming them into ASGI Scope
- and handling responses
-
- See: https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html # noqa: E501
- """
-
- TYPE = "AWS_HTTP_GATEWAY"
-
- @property
- def event_version(self) -> str:
- return self.trigger_event.get("version", "")
-
- @property
- def request(self) -> Request:
- event = self.trigger_event
-
- headers = {}
- if event.get("headers"):
- headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
-
- request_context = event["requestContext"]
-
- # API Gateway v2
- if self.event_version == "2.0":
- source_ip = request_context["http"]["sourceIp"]
- path = request_context["http"]["path"]
- http_method = request_context["http"]["method"]
- query_string = event.get("rawQueryString", "").encode()
-
- if event.get("cookies"):
- headers["cookie"] = "; ".join(event.get("cookies", []))
-
- # API Gateway v1
- elif self.event_version == "1.0":
- # v1.0 of the HTTP Gateway supports multiValueHeaders
- if event.get("multiValueHeaders"):
- headers.update(
- {
- k.lower(): ", ".join(v) if isinstance(v, list) else ""
- for k, v in event.get("multiValueHeaders", {}).items()
- }
- )
-
- source_ip = request_context.get("identity", {}).get("sourceIp")
- path = event["path"]
- http_method = event["httpMethod"]
- query_string = self._encode_query_string()
- else:
- raise RuntimeError(
- "Unsupported version of HTTP Gateway Spec, only v1.0 and v2.0 are "
- "supported."
- )
-
- server_name = headers.get("host", "mangum")
- if ":" not in server_name:
- server_port = headers.get("x-forwarded-port", 80)
- else:
- server_name, server_port = server_name.split(":") # pragma: no cover
- server = (server_name, int(server_port))
- client = (source_ip, 0)
-
- if not path:
- path = "/"
- else:
- path = self._strip_base_path(path)
-
- return Request(
- method=http_method,
- headers=[[k.encode(), v.encode()] for k, v in headers.items()],
- path=urllib.parse.unquote(path),
- scheme=headers.get("x-forwarded-proto", "https"),
- query_string=query_string,
- server=server,
- client=client,
- trigger_event=self.trigger_event,
- trigger_context=self.trigger_context,
- )
-
- @property
- def body(self) -> bytes:
- body = self.trigger_event.get("body", b"") or b""
-
- if self.trigger_event.get("isBase64Encoded", False):
- return base64.b64decode(body)
- if not isinstance(body, bytes):
- body = body.encode()
-
- return body
-
- def transform_response(self, response: Response) -> Dict[str, Any]:
- """
- This handles some unnecessary magic from AWS
-
- > API Gateway can infer the response format for you
- Boooooo
-
- https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.response
- """
- if self.event_version == "1.0":
- return self.transform_response_v1(response)
- elif self.event_version == "2.0":
- return self.transform_response_v2(response)
- raise RuntimeError( # pragma: no cover
- "Misconfigured event unable to return value, unsupported version."
- )
-
- def transform_response_v1(self, response: Response) -> Dict[str, Any]:
- headers, multi_value_headers = self._handle_multi_value_headers(
- response.headers
- )
-
- body, is_base64_encoded = self._handle_base64_response_body(
- response.body, headers
- )
- return {
- "statusCode": response.status,
- "headers": headers,
- "multiValueHeaders": multi_value_headers,
- "body": body,
- "isBase64Encoded": is_base64_encoded,
- }
-
- def _combine_headers_v2(
- self, input_headers: List[List[bytes]]
- ) -> Tuple[Dict[str, str], List[str]]:
- output_headers: Dict[str, str] = {}
- cookies: List[str] = []
- for key, value in input_headers:
- normalized_key: str = key.decode().lower()
- normalized_value: str = value.decode()
- if normalized_key == "set-cookie":
- cookies.append(normalized_value)
- else:
- if normalized_key in output_headers:
- normalized_value = (
- f"{output_headers[normalized_key]},{normalized_value}"
- )
- output_headers[normalized_key] = normalized_value
- return output_headers, cookies
-
- def transform_response_v2(self, response_in: Response) -> Dict[str, Any]:
- # The API Gateway will infer stuff for us, but we'll just do that inference
- # here and keep the output consistent
-
- headers, cookies = self._combine_headers_v2(response_in.headers)
-
- if "content-type" not in headers and response_in.body is not None:
- headers["content-type"] = "application/json"
-
- body, is_base64_encoded = self._handle_base64_response_body(
- response_in.body, headers
- )
- response_out = {
- "statusCode": response_in.status,
- "body": body,
- "headers": headers or None,
- "cookies": cookies or None,
- "isBase64Encoded": is_base64_encoded,
- }
- return {key: value for key, value in response_out.items() if value is not None}
diff --git a/mangum/handlers/lambda_at_edge.py b/mangum/handlers/lambda_at_edge.py
new file mode 100644
index 00000000..7a91147b
--- /dev/null
+++ b/mangum/handlers/lambda_at_edge.py
@@ -0,0 +1,102 @@
+from typing import Dict, List, Optional
+
+from mangum.handlers.utils import (
+ handle_base64_response_body,
+ handle_multi_value_headers,
+ maybe_encode_body,
+)
+from mangum.types import (
+ HTTPScope,
+ HTTPResponse,
+ LambdaConfig,
+ LambdaEvent,
+ LambdaContext,
+ LambdaHandler,
+)
+
+
+class LambdaAtEdge:
+ @classmethod
+ def infer(
+ cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> Optional[LambdaHandler]:
+ if (
+ "Records" in event
+ and len(event["Records"]) > 0
+ and "cf" in event["Records"][0]
+ ):
+ return cls(event, context, config)
+
+ # FIXME: Since this is the last in the chain it doesn't get coverage by default,
+ # just ignoring it for now.
+ return None # pragma: nocover
+
+ def __init__(
+ self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> None:
+ self.event = event
+ self.context = context
+ self.config = config
+
+ @property
+ def body(self) -> bytes:
+ cf_request_body = self.event["Records"][0]["cf"]["request"].get("body", {})
+ return maybe_encode_body(
+ cf_request_body.get("data"),
+ is_base64=cf_request_body.get("encoding", "") == "base64",
+ )
+
+ @property
+ def scope(self) -> HTTPScope:
+ cf_request = self.event["Records"][0]["cf"]["request"]
+ scheme_header = cf_request["headers"].get("cloudfront-forwarded-proto", [{}])
+ scheme = scheme_header[0].get("value", "https")
+ host_header = cf_request["headers"].get("host", [{}])
+ server_name = host_header[0].get("value", "mangum")
+ if ":" not in server_name:
+ forwarded_port_header = cf_request["headers"].get("x-forwarded-port", [{}])
+ server_port = forwarded_port_header[0].get("value", 80)
+ else:
+ server_name, server_port = server_name.split(":") # pragma: no cover
+
+ server = (server_name, int(server_port))
+ source_ip = cf_request["clientIp"]
+ client = (source_ip, 0)
+ http_method = cf_request["method"]
+
+ return {
+ "type": "http",
+ "method": http_method,
+ "http_version": "1.1",
+ "headers": [
+ [k.encode(), v[0]["value"].encode()]
+ for k, v in cf_request["headers"].items()
+ ],
+ "path": cf_request["uri"],
+ "raw_path": None,
+ "root_path": "",
+ "scheme": scheme,
+ "query_string": cf_request["querystring"].encode(),
+ "server": server,
+ "client": client,
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
+ "aws.event": self.event,
+ "aws.context": self.context,
+ }
+
+ def __call__(self, response: HTTPResponse) -> dict:
+ multi_value_headers, _ = handle_multi_value_headers(response["headers"])
+ response_body, is_base64_encoded = handle_base64_response_body(
+ response["body"], multi_value_headers
+ )
+ finalized_headers: Dict[str, List[Dict[str, str]]] = {
+ key.decode().lower(): [{"key": key.decode().lower(), "value": val.decode()}]
+ for key, val in response["headers"]
+ }
+
+ return {
+ "status": response["status"],
+ "headers": finalized_headers,
+ "body": response_body,
+ "isBase64Encoded": is_base64_encoded,
+ }
diff --git a/mangum/handlers/utils.py b/mangum/handlers/utils.py
new file mode 100644
index 00000000..1ec1dfb6
--- /dev/null
+++ b/mangum/handlers/utils.py
@@ -0,0 +1,90 @@
+import base64
+from typing import Dict, List, Tuple, Union
+from urllib.parse import unquote
+
+from mangum.types import Headers
+
+
+DEFAULT_TEXT_MIME_TYPES = [
+ "text/",
+ "application/json",
+ "application/javascript",
+ "application/xml",
+ "application/vnd.api+json",
+]
+
+
+def maybe_encode_body(body: Union[str, bytes], *, is_base64: bool) -> bytes:
+ body = body or b""
+ if is_base64:
+ body = base64.b64decode(body)
+ elif not isinstance(body, bytes):
+ body = body.encode()
+
+ return body
+
+
+def get_server_and_port(headers: dict) -> Tuple[str, int]:
+ server_name = headers.get("host", "mangum")
+ if ":" not in server_name:
+ server_port = headers.get("x-forwarded-port", 80)
+ else:
+ server_name, server_port = server_name.split(":") # pragma: no cover
+ server = (server_name, int(server_port))
+
+ return server
+
+
+def strip_api_gateway_path(path: str, *, api_gateway_base_path: str) -> str:
+ if not path:
+ return "/"
+
+ if api_gateway_base_path and api_gateway_base_path != "/":
+ if not api_gateway_base_path.startswith("/"):
+ api_gateway_base_path = f"/{api_gateway_base_path}"
+ if path.startswith(api_gateway_base_path):
+ path = path[len(api_gateway_base_path) :]
+
+ return unquote(path)
+
+
+def handle_multi_value_headers(
+ response_headers: Headers,
+) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
+ headers: Dict[str, str] = {}
+ multi_value_headers: Dict[str, List[str]] = {}
+ for key, value in response_headers:
+ lower_key = key.decode().lower()
+ if lower_key in multi_value_headers:
+ multi_value_headers[lower_key].append(value.decode())
+ elif lower_key in headers:
+ # Move existing to multi_value_headers and append current
+ multi_value_headers[lower_key] = [
+ headers[lower_key],
+ value.decode(),
+ ]
+ del headers[lower_key]
+ else:
+ headers[lower_key] = value.decode()
+ return headers, multi_value_headers
+
+
+def handle_base64_response_body(
+ body: bytes, headers: Dict[str, str]
+) -> Tuple[str, bool]:
+ is_base64_encoded = False
+ output_body = ""
+ if body != b"":
+ for text_mime_type in DEFAULT_TEXT_MIME_TYPES:
+ if text_mime_type in headers.get("content-type", ""):
+ try:
+ output_body = body.decode()
+ except UnicodeDecodeError:
+ output_body = base64.b64encode(body).decode()
+ is_base64_encoded = True
+ break
+ else:
+ output_body = base64.b64encode(body).decode()
+ is_base64_encoded = True
+
+ return output_body, is_base64_encoded
diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py
index 33db249e..fd452f5a 100644
--- a/mangum/protocols/http.py
+++ b/mangum/protocols/http.py
@@ -1,19 +1,23 @@
-import enum
import asyncio
-from typing import Optional
+import enum
import logging
from io import BytesIO
-from dataclasses import dataclass
-from .. import Response, Request
-from ..types import ASGIApp, Message
-from ..exceptions import UnexpectedMessage
+
+from mangum.types import (
+ ASGIApp,
+ ASGIReceiveEvent,
+ ASGISendEvent,
+ HTTPDisconnectEvent,
+ HTTPScope,
+ HTTPResponse,
+)
+from mangum.exceptions import UnexpectedMessage
class HTTPCycleState(enum.Enum):
"""
The state of the ASGI `http` connection.
-
* **REQUEST** - Initial state. The ASGI application instance will be run with the
connection scope containing the `http` type.
* **RESPONSE** - The `http.response.start` event has been sent by the application.
@@ -30,55 +34,38 @@ class HTTPCycleState(enum.Enum):
COMPLETE = enum.auto()
-@dataclass
class HTTPCycle:
- """
- Manages the application cycle for an ASGI `http` connection.
-
- * **request** - A request object containing the event and context for the connection
- scope used to run the ASGI application instance.
- * **state** - An enumerated `HTTPCycleState` type that indicates the state of the
- ASGI connection.
- * **app_queue** - An asyncio queue (FIFO) containing messages to be received by the
- application.
- * **response** - A dictionary containing the response data to return in AWS Lambda.
- """
-
- request: Request
- state: HTTPCycleState = HTTPCycleState.REQUEST
- response: Optional[Response] = None
-
- def __post_init__(self) -> None:
- self.logger: logging.Logger = logging.getLogger("mangum.http")
- self.loop = asyncio.get_event_loop()
- self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
- self.body: BytesIO = BytesIO()
-
- def __call__(self, app: ASGIApp, initial_body: bytes) -> Response:
- self.logger.debug("HTTP cycle starting.")
+ def __init__(self, scope: HTTPScope, body: bytes) -> None:
+ self.scope = scope
+ self.buffer = BytesIO()
+ self.state = HTTPCycleState.REQUEST
+ self.logger = logging.getLogger("mangum.http")
+ self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
self.app_queue.put_nowait(
- {"type": "http.request", "body": initial_body, "more_body": False}
+ {
+ "type": "http.request",
+ "body": body,
+ "more_body": False,
+ }
)
+
+ def __call__(self, app: ASGIApp) -> HTTPResponse:
asgi_instance = self.run(app)
- asgi_task = self.loop.create_task(asgi_instance)
- self.loop.run_until_complete(asgi_task)
+ loop = asyncio.get_event_loop()
+ asgi_task = loop.create_task(asgi_instance)
+ loop.run_until_complete(asgi_task)
- if self.response is None: # pragma: nocover
- self.response = Response(
- status=500,
- body=b"Internal Server Error",
- headers=[[b"content-type", b"text/plain; charset=utf-8"]],
- )
- return self.response
+ return {
+ "status": self.status,
+ "headers": self.headers,
+ "body": self.body,
+ }
async def run(self, app: ASGIApp) -> None:
- """
- Calls the application with the `http` connection scope.
- """
try:
- await app(self.request.scope, self.receive, self.send)
- except BaseException as exc:
- self.logger.error("Exception in 'http' protocol.", exc_info=exc)
+ await app(self.scope, self.receive, self.send)
+ except BaseException:
+ self.logger.exception("An error occurred running the application.")
if self.state is HTTPCycleState.REQUEST:
await self.send(
{
@@ -88,63 +75,48 @@ async def run(self, app: ASGIApp) -> None:
}
)
await self.send(
- {"type": "http.response.body", "body": b"Internal Server Error"}
+ {
+ "type": "http.response.body",
+ "body": b"Internal Server Error",
+ "more_body": False,
+ }
)
elif self.state is not HTTPCycleState.COMPLETE:
- self.response = Response(
- status=500,
- body=b"Internal Server Error",
- headers=[[b"content-type", b"text/plain; charset=utf-8"]],
- )
+ self.status = 500
+ self.body = b"Internal Server Error"
+ self.headers = [[b"content-type", b"text/plain; charset=utf-8"]]
- async def receive(self) -> Message:
- """
- Awaited by the application to receive ASGI `http` events.
- """
+ async def receive(self) -> ASGIReceiveEvent:
return await self.app_queue.get() # pragma: no cover
- async def send(self, message: Message) -> None:
- """
- Awaited by the application to send ASGI `http` events.
- """
- message_type = message["type"]
-
+ async def send(self, message: ASGISendEvent) -> None:
if (
self.state is HTTPCycleState.REQUEST
- and message_type == "http.response.start"
+ and message["type"] == "http.response.start"
):
- if self.response is None:
- self.response = Response(
- status=message["status"],
- headers=message.get("headers", []),
- body=b"",
- )
+ self.status = message["status"]
+ self.headers = message.get("headers", [])
self.state = HTTPCycleState.RESPONSE
elif (
self.state is HTTPCycleState.RESPONSE
- and message_type == "http.response.body"
+ and message["type"] == "http.response.body"
):
+
body = message.get("body", b"")
more_body = message.get("more_body", False)
-
- # The body must be completely read before returning the response.
- self.body.write(body)
-
- if not more_body and self.response is not None:
- body = self.body.getvalue()
- self.body.close()
- self.response.body = body
+ self.buffer.write(body)
+ if not more_body:
+ self.body = self.buffer.getvalue()
+ self.buffer.close()
self.state = HTTPCycleState.COMPLETE
- await self.app_queue.put({"type": "http.disconnect"})
+ await self.app_queue.put(HTTPDisconnectEvent(type="http.disconnect"))
self.logger.info(
"%s %s %s",
- self.request.method,
- self.request.path,
- self.response.status,
+ self.scope["method"],
+ self.scope["path"],
+ self.status,
)
else:
- raise UnexpectedMessage(
- f"{self.state}: Unexpected '{message_type}' event received."
- )
+ raise UnexpectedMessage(f"Unexpected {message['type']}")
diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py
index 4ceb91dd..f1f2a6e2 100644
--- a/mangum/protocols/lifespan.py
+++ b/mangum/protocols/lifespan.py
@@ -1,12 +1,19 @@
import asyncio
+import enum
import logging
from types import TracebackType
from typing import Optional, Type
-import enum
-from dataclasses import dataclass
-from ..types import ASGIApp, Message
-from ..exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage
+
+from mangum.types import (
+ ASGIApp,
+ LifespanMode,
+ ASGIReceiveEvent,
+ ASGISendEvent,
+ LifespanShutdownEvent,
+ LifespanStartupEvent,
+)
+from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage
class LifespanCycleState(enum.Enum):
@@ -33,7 +40,6 @@ class LifespanCycleState(enum.Enum):
UNSUPPORTED = enum.auto()
-@dataclass
class LifespanCycle:
"""
Manages the application cycle for an ASGI `lifespan` connection.
@@ -54,22 +60,19 @@ class LifespanCycle:
shutdown flow.
"""
- app: ASGIApp
- lifespan: str
- state: LifespanCycleState = LifespanCycleState.CONNECTING
- exception: Optional[BaseException] = None
-
- def __post_init__(self) -> None:
- self.logger = logging.getLogger("mangum.lifespan")
+ def __init__(self, app: ASGIApp, lifespan: LifespanMode) -> None:
+ self.app = app
+ self.lifespan = lifespan
+ self.state: LifespanCycleState = LifespanCycleState.CONNECTING
+ self.exception: Optional[BaseException] = None
self.loop = asyncio.get_event_loop()
- self.app_queue: asyncio.Queue[Message] = asyncio.Queue()
+ self.app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue()
self.startup_event: asyncio.Event = asyncio.Event()
self.shutdown_event: asyncio.Event = asyncio.Event()
+ self.logger = logging.getLogger("mangum.lifespan")
def __enter__(self) -> None:
- """
- Runs the event loop for application startup.
- """
+ """Runs the event loop for application startup."""
self.loop.create_task(self.run())
self.loop.run_until_complete(self.startup())
@@ -79,17 +82,17 @@ def __exit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
- """
- Runs the event loop for application shutdown.
- """
+ """Runs the event loop for application shutdown."""
self.loop.run_until_complete(self.shutdown())
async def run(self) -> None:
- """
- Calls the application with the `lifespan` connection scope.
- """
+ """Calls the application with the `lifespan` connection scope."""
try:
- await self.app({"type": "lifespan"}, self.receive, self.send)
+ await self.app(
+ {"type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}},
+ self.receive,
+ self.send,
+ )
except LifespanUnsupported:
self.logger.info("ASGI 'lifespan' protocol appears unsupported.")
except (LifespanFailure, UnexpectedMessage) as exc:
@@ -100,10 +103,8 @@ async def run(self) -> None:
self.startup_event.set()
self.shutdown_event.set()
- async def receive(self) -> Message:
- """
- Awaited by the application to receive ASGI `lifespan` events.
- """
+ async def receive(self) -> ASGIReceiveEvent:
+ """Awaited by the application to receive ASGI `lifespan` events."""
if self.state is LifespanCycleState.CONNECTING:
# Connection established. The next event returned by the queue will be
@@ -120,17 +121,14 @@ async def receive(self) -> Message:
return await self.app_queue.get()
- async def send(self, message: Message) -> None:
- """
- Awaited by the application to send ASGI `lifespan` events.
- """
+ async def send(self, message: ASGISendEvent) -> None:
+ """Awaited by the application to send ASGI `lifespan` events."""
message_type = message["type"]
self.logger.info(
"%s: '%s' event received from application.", self.state, message_type
)
if self.state is LifespanCycleState.CONNECTING:
-
if self.lifespan == "on":
raise LifespanFailure(
"Lifespan connection failed during startup and lifespan is 'on'."
@@ -156,8 +154,8 @@ async def send(self, message: Message) -> None:
elif message_type == "lifespan.startup.failed":
self.state = LifespanCycleState.FAILED
self.startup_event.set()
- message = message.get("message", "")
- raise LifespanFailure(f"Lifespan startup failure. {message}")
+ message_value = message.get("message", "")
+ raise LifespanFailure(f"Lifespan startup failure. {message_value}")
elif self.state is LifespanCycleState.SHUTDOWN:
if message_type == "lifespan.shutdown.complete":
@@ -165,15 +163,13 @@ async def send(self, message: Message) -> None:
elif message_type == "lifespan.shutdown.failed":
self.state = LifespanCycleState.FAILED
self.shutdown_event.set()
- message = message.get("message", "")
- raise LifespanFailure(f"Lifespan shutdown failure. {message}")
+ message_value = message.get("message", "")
+ raise LifespanFailure(f"Lifespan shutdown failure. {message_value}")
async def startup(self) -> None:
- """
- Pushes the `lifespan` startup event to application queue and handles errors.
- """
+ """Pushes the `lifespan` startup event to the queue and handles errors."""
self.logger.info("Waiting for application startup.")
- await self.app_queue.put({"type": "lifespan.startup"})
+ await self.app_queue.put(LifespanStartupEvent(type="lifespan.startup"))
await self.startup_event.wait()
if self.state is LifespanCycleState.FAILED:
raise LifespanFailure(self.exception)
@@ -184,11 +180,9 @@ async def startup(self) -> None:
self.logger.info("Application startup failed.")
async def shutdown(self) -> None:
- """
- Pushes the `lifespan` shutdown event to application queue and handles errors.
- """
+ """Pushes the `lifespan` shutdown event to the queue and handles errors."""
self.logger.info("Waiting for application shutdown.")
- await self.app_queue.put({"type": "lifespan.shutdown"})
+ await self.app_queue.put(LifespanShutdownEvent(type="lifespan.shutdown"))
await self.shutdown_event.wait()
if self.state is LifespanCycleState.FAILED:
raise LifespanFailure(self.exception)
diff --git a/mangum/types.py b/mangum/types.py
index 1f8b8b14..bcf4a0d0 100644
--- a/mangum/types.py
+++ b/mangum/types.py
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+from __future__ import annotations
+
from typing import (
List,
Tuple,
@@ -11,21 +12,11 @@
Awaitable,
Callable,
)
-from typing_extensions import Protocol, TypeAlias
-
-QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]]
-Message: TypeAlias = MutableMapping[str, Any]
-Scope: TypeAlias = MutableMapping[str, Any]
-Receive: TypeAlias = Callable[[], Awaitable[Message]]
-Send: TypeAlias = Callable[[Message], Awaitable[None]]
-
-
-class ASGIApp(Protocol):
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- ... # pragma: no cover
+from typing_extensions import Literal, Protocol, TypedDict, TypeAlias
LambdaEvent = Dict[str, Any]
+QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]]
class LambdaCognitoIdentity(Protocol):
@@ -103,65 +94,150 @@ def get_remaining_time_in_millis(self) -> int:
... # pragma: no cover
-@dataclass
-class BaseRequest:
- """
- A holder for an ASGI scope. Contains additional meta from the event that triggered
- the Lambda function.
- """
+Headers: TypeAlias = List[List[bytes]]
- headers: List[List[bytes]]
- path: str
- scheme: str
- query_string: bytes
- server: Tuple[str, int]
- client: Tuple[str, int]
- # Invocation event
- trigger_event: Dict[str, Any]
- trigger_context: Union["LambdaContext", Dict[str, Any]]
+class HTTPRequestEvent(TypedDict):
+ type: Literal["http.request"]
+ body: bytes
+ more_body: bool
- raw_path: Optional[str] = None
- root_path: str = ""
- @property
- def scope(self) -> Scope:
- return {
- "http_version": "1.1",
- "headers": self.headers,
- "path": self.path,
- "raw_path": self.raw_path,
- "root_path": self.root_path,
- "scheme": self.scheme,
- "query_string": self.query_string,
- "server": self.server,
- "client": self.client,
- "asgi": {"version": "3.0"},
- "aws.event": self.trigger_event,
- "aws.context": self.trigger_context,
- }
-
-
-@dataclass
-class Request(BaseRequest):
- """
- A holder for an ASGI scope. Specific for usage with HTTP connections.
+class HTTPDisconnectEvent(TypedDict):
+ type: Literal["http.disconnect"]
- https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
- """
- type: str = "http"
- method: str = "GET"
+class HTTPResponseStartEvent(TypedDict):
+ type: Literal["http.response.start"]
+ status: int
+ headers: Headers
- @property
- def scope(self) -> Scope:
- scope = super().scope
- scope.update({"type": self.type, "method": self.method})
- return scope
+class HTTPResponseBodyEvent(TypedDict):
+ type: Literal["http.response.body"]
+ body: bytes
+ more_body: bool
+
+
+class LifespanStartupEvent(TypedDict):
+ type: Literal["lifespan.startup"]
+
+
+class LifespanStartupCompleteEvent(TypedDict):
+ type: Literal["lifespan.startup.complete"]
+
+
+class LifespanStartupFailedEvent(TypedDict):
+ type: Literal["lifespan.startup.failed"]
+ message: str
+
+
+class LifespanShutdownEvent(TypedDict):
+ type: Literal["lifespan.shutdown"]
+
+
+class LifespanShutdownCompleteEvent(TypedDict):
+ type: Literal["lifespan.shutdown.complete"]
+
+
+class LifespanShutdownFailedEvent(TypedDict):
+ type: Literal["lifespan.shutdown.failed"]
+ message: str
+
+
+ASGIReceiveEvent: TypeAlias = Union[
+ HTTPRequestEvent,
+ HTTPDisconnectEvent,
+ LifespanStartupEvent,
+ LifespanShutdownEvent,
+]
+
+ASGISendEvent: TypeAlias = Union[
+ HTTPResponseStartEvent,
+ HTTPResponseBodyEvent,
+ HTTPDisconnectEvent,
+ LifespanStartupCompleteEvent,
+ LifespanStartupFailedEvent,
+ LifespanShutdownCompleteEvent,
+ LifespanShutdownFailedEvent,
+]
-@dataclass
-class Response:
+
+ASGIReceive: TypeAlias = Callable[[], Awaitable[ASGIReceiveEvent]]
+ASGISend: TypeAlias = Callable[[ASGISendEvent], Awaitable[None]]
+
+
+class ASGISpec(TypedDict):
+ spec_version: Literal["2.0"]
+ version: Literal["3.0"]
+
+
+HTTPScope = TypedDict(
+ "HTTPScope",
+ {
+ "type": Literal["http"],
+ "asgi": ASGISpec,
+ "http_version": Literal["1.1"],
+ "scheme": str,
+ "method": str,
+ "path": str,
+ "raw_path": None,
+ "root_path": Literal[""],
+ "query_string": bytes,
+ "headers": Headers,
+ "client": Tuple[str, int],
+ "server": Tuple[str, int],
+ "aws.event": LambdaEvent,
+ "aws.context": LambdaContext,
+ },
+)
+
+
+class LifespanScope(TypedDict):
+ type: Literal["lifespan"]
+ asgi: ASGISpec
+
+
+LifespanMode: TypeAlias = Literal["auto", "on", "off"]
+Scope: TypeAlias = Union[HTTPScope, LifespanScope]
+
+
+class ASGIApp(Protocol):
+ async def __call__(
+ self, scope: Scope, receive: ASGIReceive, send: ASGISend
+ ) -> None:
+ ... # pragma: no cover
+
+
+class HTTPResponse(TypedDict):
status: int
- headers: List[List[bytes]] # ex: [[b'content-type', b'text/plain; charset=utf-8']]
+ headers: Headers
body: bytes
+
+
+class LambdaConfig(TypedDict):
+ api_gateway_base_path: str
+
+
+class LambdaHandler(Protocol):
+ @classmethod
+ def infer(
+ cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> Optional[LambdaHandler]:
+ ... # pragma: no cover
+
+ def __init__(
+ self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> None:
+ ... # pragma: no cover
+
+ @property
+ def body(self) -> bytes:
+ ... # pragma: no cover
+
+ @property
+ def scope(self) -> HTTPScope:
+ ... # pragma: no cover
+
+ def __call__(self, response: HTTPResponse) -> dict:
+ ... # pragma: no cover
diff --git a/requirements.txt b/requirements.txt
index d7ea423c..64398c30 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,11 +5,9 @@ pytest-cov
black
flake8
starlette
-quart; python_version >= '3.7'
+quart
mypy
brotli
brotli-asgi
-types-dataclasses; python_version < '3.7' # For mypy
-# Docs
mkdocs
mkdocs-material
diff --git a/scripts/lint b/scripts/lint
index de557e7d..5dd72263 100755
--- a/scripts/lint
+++ b/scripts/lint
@@ -8,5 +8,5 @@ fi
set -x
${PREFIX}black mangum tests --check
-${PREFIX}mypy mangum --disallow-untyped-defs --ignore-missing-imports
+${PREFIX}mypy mangum
${PREFIX}flake8 mangum tests
diff --git a/setup.cfg b/setup.cfg
index 78c5ea9c..f3777857 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,3 +11,4 @@ disallow_untyped_calls = True
disallow_incomplete_defs = True
disallow_untyped_decorators = True
ignore_missing_imports = True
+show_error_codes = True
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 8506ec49..7641c9eb 100644
--- a/setup.py
+++ b/setup.py
@@ -11,10 +11,10 @@ def get_long_description():
packages=find_packages(exclude=["tests*"]),
license="MIT",
url="https://github.com/jordaneremieff/mangum",
- description="AWS Lambda & API Gateway support for ASGI",
+ description="AWS Lambda support for ASGI applications",
long_description=get_long_description(),
python_requires=">=3.6",
- install_requires=["typing_extensions", "dataclasses; python_version < '3.7'"],
+ install_requires=["typing_extensions"],
package_data={"mangum": ["py.typed"]},
long_description_content_type="text/markdown",
author="Jordan Eremieff",
@@ -23,7 +23,6 @@ def get_long_description():
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python",
- "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
diff --git a/tests/handlers/test_abstract_handler.py b/tests/handlers/test_abstract_handler.py
deleted file mode 100644
index b38366e3..00000000
--- a/tests/handlers/test_abstract_handler.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import pytest
-
-from mangum.handlers.abstract_handler import AbstractHandler
-
-
-def test_abstract_handler_unkown_event():
- """
- Test an unknown event, ensure it fails in a consistent way
- """
- example_event = {"hello": "world", "foo": "bar"}
- example_context = {}
- with pytest.raises(TypeError):
- AbstractHandler.from_trigger(example_event, example_context)
diff --git a/tests/handlers/test_aws_alb.py b/tests/handlers/test_alb.py
similarity index 98%
rename from tests/handlers/test_aws_alb.py
rename to tests/handlers/test_alb.py
index 105ae81c..6213088d 100644
--- a/tests/handlers/test_aws_alb.py
+++ b/tests/handlers/test_alb.py
@@ -8,7 +8,7 @@
import pytest
from mangum import Mangum
-from mangum.handlers.aws_alb import AwsAlb
+from mangum.handlers.alb import ALB
def get_mock_aws_alb_event(
@@ -199,15 +199,15 @@ def test_aws_alb_scope_real(
multi_value_headers,
)
example_context = {}
- handler = AwsAlb(event, example_context)
+ handler = ALB(event, example_context, {"api_gateway_base_path": "/"})
scope_path = path
if scope_path == "":
scope_path = "/"
assert type(handler.body) == bytes
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": event,
"client": ("72.12.164.125", 0),
diff --git a/tests/handlers/test_aws_api_gateway.py b/tests/handlers/test_api_gateway.py
similarity index 96%
rename from tests/handlers/test_aws_api_gateway.py
rename to tests/handlers/test_api_gateway.py
index d87b669b..4c7cf0a6 100644
--- a/tests/handlers/test_aws_api_gateway.py
+++ b/tests/handlers/test_api_gateway.py
@@ -3,7 +3,7 @@
import pytest
from mangum import Mangum
-from mangum.handlers.aws_api_gateway import AwsApiGateway
+from mangum.handlers.api_gateway import APIGateway
def get_mock_aws_api_gateway_event(
@@ -103,11 +103,11 @@ def test_aws_api_gateway_scope_basic():
"isBase64Encoded": False,
}
example_context = {}
- handler = AwsApiGateway(example_event, example_context, "/")
+ handler = APIGateway(example_event, example_context, {"api_gateway_base_path": "/"})
assert type(handler.body) == bytes
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": example_event,
"client": (None, 0),
@@ -200,14 +200,14 @@ def test_aws_api_gateway_scope_real(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)
example_context = {}
- handler = AwsApiGateway(event, example_context, "/")
+ handler = APIGateway(event, example_context, {"api_gateway_base_path": "/"})
scope_path = path
if scope_path == "":
scope_path = "/"
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": event,
"client": ("192.168.100.1", 0),
diff --git a/tests/handlers/test_custom.py b/tests/handlers/test_custom.py
new file mode 100644
index 00000000..24286ce3
--- /dev/null
+++ b/tests/handlers/test_custom.py
@@ -0,0 +1,77 @@
+from typing import Optional
+
+from mangum.types import (
+ HTTPScope,
+ Headers,
+ LambdaConfig,
+ LambdaContext,
+ LambdaEvent,
+ LambdaHandler,
+)
+
+
+class CustomHandler:
+ @classmethod
+ def infer(
+ cls, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> Optional[LambdaHandler]:
+ if "my-custom-key" in event:
+ return cls(event, context, config)
+
+ return None
+
+ def __init__(
+ self, event: LambdaEvent, context: LambdaContext, config: LambdaConfig
+ ) -> None:
+ self.event = event
+ self.context = context
+ self.config = config
+
+ @property
+ def body(self) -> bytes:
+ return b"My request body"
+
+ @property
+ def scope(self) -> HTTPScope:
+ headers = {}
+ return {
+ "type": "http",
+ "http_version": "1.1",
+ "method": "GET",
+ "headers": [[k.encode(), v.encode()] for k, v in headers.items()],
+ "path": "/",
+ "raw_path": None,
+ "root_path": "",
+ "scheme": "https",
+ "query_string": b"",
+ "server": ("mangum", 8080),
+ "client": ("127.0.0.1", 0),
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
+ "aws.event": self.event,
+ "aws.context": self.context,
+ }
+
+ def __call__(self, *, status: int, headers: Headers, body: bytes) -> dict:
+ return {"statusCode": status, "headers": {}, "body": body.decode()}
+
+
+def test_custom_handler():
+ event = {"my-custom-key": 1}
+ handler = CustomHandler(event, {}, {"api_gateway_base_path": "/"})
+ assert type(handler.body) == bytes
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
+ "aws.context": {},
+ "aws.event": event,
+ "client": ("127.0.0.1", 0),
+ "headers": [],
+ "http_version": "1.1",
+ "method": "GET",
+ "path": "/",
+ "query_string": b"",
+ "raw_path": None,
+ "root_path": "",
+ "scheme": "https",
+ "server": ("mangum", 8080),
+ "type": "http",
+ }
diff --git a/tests/handlers/test_aws_http_gateway.py b/tests/handlers/test_http_gateway.py
similarity index 93%
rename from tests/handlers/test_aws_http_gateway.py
rename to tests/handlers/test_http_gateway.py
index 2475014d..df428351 100644
--- a/tests/handlers/test_aws_http_gateway.py
+++ b/tests/handlers/test_http_gateway.py
@@ -3,7 +3,7 @@
import pytest
from mangum import Mangum
-from mangum.handlers.aws_http_gateway import AwsHttpGateway
+from mangum.handlers.api_gateway import HTTPGateway
def get_mock_aws_http_gateway_event_v1(
@@ -193,12 +193,15 @@ def test_aws_http_gateway_scope_basic_v1():
"body": "Hello from Lambda!",
"isBase64Encoded": False,
}
+
example_context = {}
- handler = AwsHttpGateway(example_event, example_context, "/")
+ handler = HTTPGateway(
+ example_event, example_context, {"api_gateway_base_path": "/"}
+ )
assert type(handler.body) == bytes
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": example_event,
"client": ("IP", 0),
@@ -225,8 +228,10 @@ def test_aws_http_gateway_scope_v1_only_non_multi_headers():
)
del example_event["multiValueQueryStringParameters"]
example_context = {}
- handler = AwsHttpGateway(example_event, example_context, "/")
- assert handler.request.scope["query_string"] == b"hello=world"
+ handler = HTTPGateway(
+ example_event, example_context, {"api_gateway_base_path": "/"}
+ )
+ assert handler.scope["query_string"] == b"hello=world"
def test_aws_http_gateway_scope_v1_no_headers():
@@ -239,8 +244,10 @@ def test_aws_http_gateway_scope_v1_no_headers():
del example_event["multiValueQueryStringParameters"]
del example_event["queryStringParameters"]
example_context = {}
- handler = AwsHttpGateway(example_event, example_context, "/")
- assert handler.request.scope["query_string"] == b""
+ handler = HTTPGateway(
+ example_event, example_context, {"api_gateway_base_path": "/"}
+ )
+ assert handler.scope["query_string"] == b""
def test_aws_http_gateway_scope_basic_v2():
@@ -297,11 +304,13 @@ def test_aws_http_gateway_scope_basic_v2():
"stageVariables": {"stageVariable1": "value1", "stageVariable2": "value2"},
}
example_context = {}
- handler = AwsHttpGateway(example_event, example_context, "/")
+ handler = HTTPGateway(
+ example_event, example_context, {"api_gateway_base_path": "/"}
+ )
assert type(handler.body) == bytes
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": example_event,
"client": ("IP", 0),
@@ -322,21 +331,6 @@ def test_aws_http_gateway_scope_basic_v2():
}
-def test_aws_http_gateway_scope_bad_version():
- """
- Set a version we don't support
-
- Version is the only thing that is different here, we should be checking that
- specifically
- """
- example_event = get_mock_aws_http_gateway_event_v2("GET", "/test", {}, None, False)
- example_event["version"] = "9001.1"
- example_context = {}
- handler = AwsHttpGateway(example_event, example_context, "/")
- with pytest.raises(RuntimeError):
- handler.request.scope
-
-
@pytest.mark.parametrize(
"method,path,query_parameters,req_body,body_base64_encoded,query_string,scope_body",
[
@@ -369,14 +363,14 @@ def test_aws_http_gateway_scope_real_v1(
method, path, query_parameters, req_body, body_base64_encoded
)
example_context = {}
- handler = AwsHttpGateway(event, example_context, "/")
+ handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"})
scope_path = path
if scope_path == "":
scope_path = "/"
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": event,
"client": ("192.168.100.1", 0),
@@ -435,14 +429,14 @@ def test_aws_http_gateway_scope_real_v2(
method, path, query_parameters, req_body, body_base64_encoded
)
example_context = {}
- handler = AwsHttpGateway(event, example_context, "/")
+ handler = HTTPGateway(event, example_context, {"api_gateway_base_path": "/"})
scope_path = path
if scope_path == "":
scope_path = "/"
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": event,
"client": ("192.168.100.1", 0),
diff --git a/tests/handlers/test_aws_cf_lambda_at_edge.py b/tests/handlers/test_lambda_at_edge.py
similarity index 95%
rename from tests/handlers/test_aws_cf_lambda_at_edge.py
rename to tests/handlers/test_lambda_at_edge.py
index f1910ed1..47a53f4e 100644
--- a/tests/handlers/test_aws_cf_lambda_at_edge.py
+++ b/tests/handlers/test_lambda_at_edge.py
@@ -3,7 +3,7 @@
import pytest
from mangum import Mangum
-from mangum.handlers.aws_cf_lambda_at_edge import AwsCfLambdaAtEdge
+from mangum.handlers.lambda_at_edge import LambdaAtEdge
def mock_lambda_at_edge_event(
@@ -134,11 +134,13 @@ def test_aws_cf_lambda_at_edge_scope_basic():
]
}
example_context = {}
- handler = AwsCfLambdaAtEdge(example_event, example_context)
+ handler = LambdaAtEdge(
+ example_event, example_context, {"api_gateway_base_path": "/"}
+ )
assert type(handler.body) == bytes
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": example_event,
"client": ("203.0.113.178", 0),
@@ -223,10 +225,10 @@ def test_aws_api_gateway_scope_real(
method, path, multi_value_query_parameters, req_body, body_base64_encoded
)
example_context = {}
- handler = AwsCfLambdaAtEdge(event, example_context)
+ handler = LambdaAtEdge(event, example_context, {"api_gateway_base_path": "/"})
- assert handler.request.scope == {
- "asgi": {"version": "3.0"},
+ assert handler.scope == {
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": event,
"client": ("192.168.100.1", 0),
diff --git a/tests/test_http.py b/tests/test_http.py
index 825b630e..41798055 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -22,7 +22,7 @@
def test_http_response(mock_aws_api_gateway_event) -> None:
async def app(scope, receive, send):
assert scope == {
- "asgi": {"version": "3.0"},
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": {
"body": None,
@@ -274,7 +274,7 @@ async def app(scope, receive, send):
def test_set_cookies_v2(mock_http_api_event_v2) -> None:
async def app(scope, receive, send):
assert scope == {
- "asgi": {"version": "3.0"},
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": {
"version": "2.0",
@@ -391,7 +391,7 @@ async def app(scope, receive, send):
def test_set_cookies_v1(mock_http_api_event_v1) -> None:
async def app(scope, receive, send):
assert scope == {
- "asgi": {"version": "3.0"},
+ "asgi": {"version": "3.0", "spec_version": "2.0"},
"aws.context": {},
"aws.event": {
"version": "1.0",