Skip to content

Commit

Permalink
Extract ASGI scope creation into function (#162)
Browse files Browse the repository at this point in the history
* πŸ— simplify magnum call function

extract scope and request body creation into a dedicated functions

* πŸ“ add documentation to create scope function
  • Loading branch information
ediskandarov authored Feb 28, 2021
1 parent 70f0750 commit 18574b8
Showing 1 changed file with 89 additions and 76 deletions.
165 changes: 89 additions & 76 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, InitVar
from contextlib import ExitStack

from mangum.types import ASGIApp
from mangum.types import ASGIApp, Scope
from mangum.protocols.lifespan import LifespanCycle
from mangum.protocols.http import HTTPCycle
from mangum.exceptions import ConfigurationError
Expand Down Expand Up @@ -91,89 +91,102 @@ def __call__(self, event: dict, context: "LambdaContext") -> dict:
)
stack.enter_context(lifespan_cycle)

request_context = event["requestContext"]

if event.get("multiValueHeaders"):
headers = {
k.lower(): ", ".join(v) if isinstance(v, list) else ""
for k, v in event.get("multiValueHeaders", {}).items()
}
elif event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
else:
headers = {}

# API Gateway v2
if event.get("version") == "2.0":
source_ip = request_context["http"]["sourceIp"]
path = request_context["http"]["path"]
http_method = request_context["http"]["method"]
query_string = event.get("rawQueryString", "").encode()

if event.get("cookies"):
headers["cookie"] = "; ".join(event.get("cookies", []))

# API Gateway v1 / ELB
else:
if "elb" in request_context:
# NOTE: trust only the most right side value
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
else:
source_ip = request_context.get("identity", {}).get("sourceIp")

path = event["path"]
http_method = event["httpMethod"]

if event.get("multiValueQueryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("multiValueQueryStringParameters", {}), doseq=True
).encode()
elif event.get("queryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("queryStringParameters", {})
).encode()
else:
query_string = b""

server_name = headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
else:
server_name, server_port = server_name.split(":") # pragma: no cover
server = (server_name, int(server_port))
client = (source_ip, 0)

if not path: # pragma: no cover
path = "/"
elif self.api_gateway_base_path:
if path.startswith(self.api_gateway_base_path):
path = path[len(self.api_gateway_base_path) :]

scope = {
"type": "http",
"http_version": "1.1",
"method": http_method,
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": urllib.parse.unquote(path),
"raw_path": None,
"root_path": "",
"scheme": headers.get("x-forwarded-proto", "https"),
"query_string": query_string,
"server": server,
"client": client,
"asgi": {"version": "3.0"},
"aws.event": event,
"aws.context": context,
}

is_binary = event.get("isBase64Encoded", False)
initial_body = event.get("body") or b""
if is_binary:
initial_body = base64.b64decode(initial_body)
elif not isinstance(initial_body, bytes):
initial_body = initial_body.encode()

scope = self._create_scope(event, context)
http_cycle = HTTPCycle(scope, text_mime_types=self.text_mime_types)
response = http_cycle(self.app, initial_body)

return response

def _create_scope(self, event: dict, context: "LambdaContext") -> Scope:
"""Creates a scope object according to ASGI specification from a Lambda Event.
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
The event comes from various sources: AWS ALB, AWS API Gateway of different
versions and configurations(multivalue header, etc).
Thus, some heuristics is applied to guess an event type.
"""
request_context = event["requestContext"]

if event.get("multiValueHeaders"):
headers = {
k.lower(): ", ".join(v) if isinstance(v, list) else ""
for k, v in event.get("multiValueHeaders", {}).items()
}
elif event.get("headers"):
headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
else:
headers = {}

# API Gateway v2
if event.get("version") == "2.0":
source_ip = request_context["http"]["sourceIp"]
path = request_context["http"]["path"]
http_method = request_context["http"]["method"]
query_string = event.get("rawQueryString", "").encode()

if event.get("cookies"):
headers["cookie"] = "; ".join(event.get("cookies", []))

# API Gateway v1 / ELB
else:
if "elb" in request_context:
# NOTE: trust only the most right side value
source_ip = headers.get("x-forwarded-for", "").split(", ")[-1]
else:
source_ip = request_context.get("identity", {}).get("sourceIp")

path = event["path"]
http_method = event["httpMethod"]

if event.get("multiValueQueryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("multiValueQueryStringParameters", {}), doseq=True
).encode()
elif event.get("queryStringParameters"):
query_string = urllib.parse.urlencode(
event.get("queryStringParameters", {})
).encode()
else:
query_string = b""

server_name = headers.get("host", "mangum")
if ":" not in server_name:
server_port = headers.get("x-forwarded-port", 80)
else:
server_name, server_port = server_name.split(":") # pragma: no cover
server = (server_name, int(server_port))
client = (source_ip, 0)

if not path: # pragma: no cover
path = "/"
elif self.api_gateway_base_path:
if path.startswith(self.api_gateway_base_path):
path = path[len(self.api_gateway_base_path) :]

scope = {
"type": "http",
"http_version": "1.1",
"method": http_method,
"headers": [[k.encode(), v.encode()] for k, v in headers.items()],
"path": urllib.parse.unquote(path),
"raw_path": None,
"root_path": "",
"scheme": headers.get("x-forwarded-proto", "https"),
"query_string": query_string,
"server": server,
"client": client,
"asgi": {"version": "3.0"},
"aws.event": event,
"aws.context": context,
}

return scope

0 comments on commit 18574b8

Please sign in to comment.