diff --git a/mangum/adapter.py b/mangum/adapter.py index efb2b16d..cabfc954 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, InitVar from contextlib import ExitStack -from mangum.types import ASGIApp +from mangum.types import ASGIApp, Scope from mangum.protocols.lifespan import LifespanCycle from mangum.protocols.http import HTTPCycle from mangum.exceptions import ConfigurationError @@ -91,81 +91,6 @@ def __call__(self, event: dict, context: "LambdaContext") -> dict: ) stack.enter_context(lifespan_cycle) - request_context = event["requestContext"] - - if event.get("multiValueHeaders"): - headers = { - k.lower(): ", ".join(v) if isinstance(v, list) else "" - for k, v in event.get("multiValueHeaders", {}).items() - } - elif event.get("headers"): - headers = {k.lower(): v for k, v in event.get("headers", {}).items()} - else: - headers = {} - - # API Gateway v2 - if event.get("version") == "2.0": - source_ip = request_context["http"]["sourceIp"] - path = request_context["http"]["path"] - http_method = request_context["http"]["method"] - query_string = event.get("rawQueryString", "").encode() - - if event.get("cookies"): - headers["cookie"] = "; ".join(event.get("cookies", [])) - - # API Gateway v1 / ELB - else: - if "elb" in request_context: - # NOTE: trust only the most right side value - source_ip = headers.get("x-forwarded-for", "").split(", ")[-1] - else: - source_ip = request_context.get("identity", {}).get("sourceIp") - - path = event["path"] - http_method = event["httpMethod"] - - if event.get("multiValueQueryStringParameters"): - query_string = urllib.parse.urlencode( - event.get("multiValueQueryStringParameters", {}), doseq=True - ).encode() - elif event.get("queryStringParameters"): - query_string = urllib.parse.urlencode( - event.get("queryStringParameters", {}) - ).encode() - else: - query_string = b"" - - server_name = headers.get("host", "mangum") - if ":" not in server_name: - server_port = headers.get("x-forwarded-port", 80) - else: - server_name, server_port = server_name.split(":") # pragma: no cover - server = (server_name, int(server_port)) - client = (source_ip, 0) - - if not path: # pragma: no cover - path = "/" - elif self.api_gateway_base_path: - if path.startswith(self.api_gateway_base_path): - path = path[len(self.api_gateway_base_path) :] - - scope = { - "type": "http", - "http_version": "1.1", - "method": http_method, - "headers": [[k.encode(), v.encode()] for k, v in headers.items()], - "path": urllib.parse.unquote(path), - "raw_path": None, - "root_path": "", - "scheme": headers.get("x-forwarded-proto", "https"), - "query_string": query_string, - "server": server, - "client": client, - "asgi": {"version": "3.0"}, - "aws.event": event, - "aws.context": context, - } - is_binary = event.get("isBase64Encoded", False) initial_body = event.get("body") or b"" if is_binary: @@ -173,7 +98,95 @@ def __call__(self, event: dict, context: "LambdaContext") -> dict: elif not isinstance(initial_body, bytes): initial_body = initial_body.encode() + scope = self._create_scope(event, context) http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types) response = http_cycle(self.app, initial_body) return response + + def _create_scope(self, event: dict, context: "LambdaContext") -> Scope: + """Creates a scope object according to ASGI specification from a Lambda Event. + + https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope + + The event comes from various sources: AWS ALB, AWS API Gateway of different + versions and configurations(multivalue header, etc). + Thus, some heuristics is applied to guess an event type. + + """ + request_context = event["requestContext"] + + if event.get("multiValueHeaders"): + headers = { + k.lower(): ", ".join(v) if isinstance(v, list) else "" + for k, v in event.get("multiValueHeaders", {}).items() + } + elif event.get("headers"): + headers = {k.lower(): v for k, v in event.get("headers", {}).items()} + else: + headers = {} + + # API Gateway v2 + if event.get("version") == "2.0": + source_ip = request_context["http"]["sourceIp"] + path = request_context["http"]["path"] + http_method = request_context["http"]["method"] + query_string = event.get("rawQueryString", "").encode() + + if event.get("cookies"): + headers["cookie"] = "; ".join(event.get("cookies", [])) + + # API Gateway v1 / ELB + else: + if "elb" in request_context: + # NOTE: trust only the most right side value + source_ip = headers.get("x-forwarded-for", "").split(", ")[-1] + else: + source_ip = request_context.get("identity", {}).get("sourceIp") + + path = event["path"] + http_method = event["httpMethod"] + + if event.get("multiValueQueryStringParameters"): + query_string = urllib.parse.urlencode( + event.get("multiValueQueryStringParameters", {}), doseq=True + ).encode() + elif event.get("queryStringParameters"): + query_string = urllib.parse.urlencode( + event.get("queryStringParameters", {}) + ).encode() + else: + query_string = b"" + + server_name = headers.get("host", "mangum") + if ":" not in server_name: + server_port = headers.get("x-forwarded-port", 80) + else: + server_name, server_port = server_name.split(":") # pragma: no cover + server = (server_name, int(server_port)) + client = (source_ip, 0) + + if not path: # pragma: no cover + path = "/" + elif self.api_gateway_base_path: + if path.startswith(self.api_gateway_base_path): + path = path[len(self.api_gateway_base_path) :] + + scope = { + "type": "http", + "http_version": "1.1", + "method": http_method, + "headers": [[k.encode(), v.encode()] for k, v in headers.items()], + "path": urllib.parse.unquote(path), + "raw_path": None, + "root_path": "", + "scheme": headers.get("x-forwarded-proto", "https"), + "query_string": query_string, + "server": server, + "client": client, + "asgi": {"version": "3.0"}, + "aws.event": event, + "aws.context": context, + } + + return scope