Skip to content

Commit

Permalink
Replace awslambdaric-stubs with new types, refactor import style. (#235)
Browse files Browse the repository at this point in the history
* 🎨 Replace awslambdaric-stubs with new types, refactor import style.
  • Loading branch information
jordaneremieff authored Feb 12, 2022
1 parent 8505788 commit f6e2211
Show file tree
Hide file tree
Showing 17 changed files with 123 additions and 78 deletions.
26 changes: 10 additions & 16 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import logging
from contextlib import ExitStack
from typing import Any, Dict, TYPE_CHECKING

from .exceptions import ConfigurationError
from .handlers import AbstractHandler
from .protocols import HTTPCycle, LifespanCycle
from .types import ASGIApp


if TYPE_CHECKING: # pragma: no cover
from awslambdaric.lambda_context import LambdaContext
from mangum.exceptions import ConfigurationError
from mangum.handlers.abstract_handler import AbstractHandler
from mangum.protocols import HTTPCycle, LifespanCycle
from mangum.types import ASGIApp, LambdaEvent, LambdaContext


DEFAULT_TEXT_MIME_TYPES = [
Expand Down Expand Up @@ -53,20 +49,18 @@ def __init__(
"Invalid argument supplied for `lifespan`. Choices are: auto|on|off"
)

def __call__(self, event: Dict[str, Any], context: "LambdaContext") -> dict:
def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
logger.debug("Event received.")

with ExitStack() as stack:
if self.lifespan != "off":
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)

handler = AbstractHandler.from_trigger(
event, context, self.api_gateway_base_path
)
request = handler.request

http_cycle = HTTPCycle(request)
response = http_cycle(self.app, handler.body)
handler = AbstractHandler.from_trigger(
event, context, self.api_gateway_base_path
)
http_cycle = HTTPCycle(handler.request)
response = http_cycle(self.app, handler.body)

return handler.transform_response(response)
13 changes: 0 additions & 13 deletions mangum/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +0,0 @@
from .abstract_handler import AbstractHandler
from .aws_alb import AwsAlb
from .aws_api_gateway import AwsApiGateway
from .aws_cf_lambda_at_edge import AwsCfLambdaAtEdge
from .aws_http_gateway import AwsHttpGateway

__all__ = [
"AbstractHandler",
"AwsAlb",
"AwsApiGateway",
"AwsCfLambdaAtEdge",
"AwsHttpGateway",
]
23 changes: 10 additions & 13 deletions mangum/handlers/abstract_handler.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import base64
from abc import ABCMeta, abstractmethod
from typing import Dict, Any, TYPE_CHECKING, Tuple, List
from typing import Dict, Any, Tuple, List

from ..types import Response, Request

if TYPE_CHECKING: # pragma: no cover
from awslambdaric.lambda_context import LambdaContext
from mangum.types import Response, Request, LambdaEvent, LambdaContext


class AbstractHandler(metaclass=ABCMeta):
def __init__(
self,
trigger_event: Dict[str, Any],
trigger_context: "LambdaContext",
trigger_event: LambdaEvent,
trigger_context: LambdaContext,
):
self.trigger_event = trigger_event
self.trigger_context = trigger_context
Expand Down Expand Up @@ -40,8 +37,8 @@ def transform_response(self, response: Response) -> Dict[str, Any]:

@staticmethod
def from_trigger(
trigger_event: Dict[str, Any],
trigger_context: "LambdaContext",
trigger_event: LambdaEvent,
trigger_context: LambdaContext,
api_gateway_base_path: str = "/",
) -> "AbstractHandler":
"""
Expand All @@ -55,20 +52,20 @@ def from_trigger(
"requestContext" in trigger_event
and "elb" in trigger_event["requestContext"]
):
from . import AwsAlb
from mangum.handlers.aws_alb import AwsAlb

return AwsAlb(trigger_event, trigger_context)
if (
"Records" in trigger_event
and len(trigger_event["Records"]) > 0
and "cf" in trigger_event["Records"][0]
):
from . import AwsCfLambdaAtEdge
from mangum.handlers.aws_cf_lambda_at_edge import AwsCfLambdaAtEdge

return AwsCfLambdaAtEdge(trigger_event, trigger_context)

if "version" in trigger_event and "requestContext" in trigger_event:
from . import AwsHttpGateway
from mangum.handlers.aws_http_gateway import AwsHttpGateway

return AwsHttpGateway(
trigger_event,
Expand All @@ -77,7 +74,7 @@ def from_trigger(
)

if "resource" in trigger_event:
from . import AwsApiGateway
from mangum.handlers.aws_api_gateway import AwsApiGateway

return AwsApiGateway(
trigger_event,
Expand Down
6 changes: 2 additions & 4 deletions mangum/handlers/aws_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from typing import Any, Dict, Generator, List, Tuple
from itertools import islice

from mangum.types import QueryParams

from .abstract_handler import AbstractHandler
from .. import Response, Request
from mangum.types import Response, Request, QueryParams
from mangum.handlers.abstract_handler import AbstractHandler


def all_casings(input_string: str) -> Generator[str, None, None]:
Expand Down
16 changes: 5 additions & 11 deletions mangum/handlers/aws_api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import base64
from urllib.parse import urlencode, unquote
from typing import Dict, Any, TYPE_CHECKING
from typing import Dict, Any

from mangum.types import QueryParams

from .abstract_handler import AbstractHandler
from .. import Response, Request


if TYPE_CHECKING: # pragma: no cover
from awslambdaric.lambda_context import LambdaContext
from mangum.handlers.abstract_handler import AbstractHandler
from mangum.types import Response, Request, LambdaEvent, LambdaContext, QueryParams


class AwsApiGateway(AbstractHandler):
def __init__(
self,
trigger_event: Dict[str, Any],
trigger_context: "LambdaContext",
trigger_event: LambdaEvent,
trigger_context: LambdaContext,
api_gateway_base_path: str,
):
super().__init__(trigger_event, trigger_context)
Expand Down
4 changes: 2 additions & 2 deletions mangum/handlers/aws_cf_lambda_at_edge.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import base64
from typing import Dict, Any, List

from .abstract_handler import AbstractHandler
from .. import Response, Request
from mangum.handlers.abstract_handler import AbstractHandler
from mangum.types import Response, Request


class AwsCfLambdaAtEdge(AbstractHandler):
Expand Down
4 changes: 2 additions & 2 deletions mangum/handlers/aws_http_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import urllib.parse
from typing import Dict, Any, List, Tuple

from . import AwsApiGateway
from .. import Response, Request
from mangum.handlers.aws_api_gateway import AwsApiGateway
from mangum.types import Response, Request


class AwsHttpGateway(AwsApiGateway):
Expand Down
4 changes: 1 addition & 3 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def __call__(self, app: ASGIApp, initial_body: bytes) -> Response:
asgi_task = self.loop.create_task(asgi_instance)
self.loop.run_until_complete(asgi_task)

if self.response is None:
# Something really bad happened and we puked before we could get a
# response out
if self.response is None: # pragma: nocover
self.response = Response(
status=500,
body=b"Internal Server Error",
Expand Down
83 changes: 78 additions & 5 deletions mangum/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
MutableMapping,
Awaitable,
Callable,
TYPE_CHECKING,
)
from typing_extensions import Protocol, TypeAlias

Expand All @@ -21,15 +20,89 @@
Send: TypeAlias = Callable[[Message], Awaitable[None]]


if TYPE_CHECKING: # pragma: no cover
from awslambdaric.lambda_context import LambdaContext


class ASGIApp(Protocol):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
... # pragma: no cover


LambdaEvent = Dict[str, Any]


class LambdaCognitoIdentity(Protocol):
"""Information about the Amazon Cognito identity that authorized the request.
**cognito_identity_id** - The authenticated Amazon Cognito identity.
**cognito_identity_pool_id** - The Amazon Cognito identity pool that authorized the
invocation.
"""

cognito_identity_id: str
cognito_identity_pool_id: str


class LambdaMobileClient(Protocol):
"""Mobile client information for the application and the device.
**installation_id** - A unique identifier for an installation instance of an
application.
**app_title** - The title of the application. For example, "My App".
**app_version_code** - The version of the application. For example, "V2.0".
**app_version_name** - The version code for the application. For example, 3.
**app_package_name** - The name of the package. For example, "com.example.my_app".
"""

installation_id: str
app_title: str
app_version_name: str
app_version_code: str
app_package_name: str


class LambdaMobileClientContext(Protocol):
"""Information about client application and device when invoked via AWS Mobile SDK.
**client** - A dict of name-value pairs that describe the mobile client application.
**custom** - A dict of custom values set by the mobile client application.
**env** - A dict of environment information provided by the AWS SDK.
"""

client: LambdaMobileClient
custom: Dict[str, Any]
env: Dict[str, Any]


class LambdaContext(Protocol):
"""The context object passed to the handler function.
**function_name** - The name of the Lambda function.
**function_version** - The version of the function.
**invoked_function_arn** - The Amazon Resource Name (ARN) that's used to invoke the
function. Indicates if the invoker specified a version number or alias.
**memory_limit_in_mb** - The amount of memory that's allocated for the function.
**aws_request_id** - The identifier of the invocation request.
**log_group_name** - The log group for the function.
**log_stream_name** - The log stream for the function instance.
**identity** - (mobile apps) Information about the Amazon Cognito identity that
authorized the request.
**client_context** - (mobile apps) Client context that's provided to Lambda by the
client application.
"""

function_name: str
function_version: str
invoked_function_arn: str
memory_limit_in_mb: int
aws_request_id: str
log_group_name: str
log_stream_name: str
identity: Optional[LambdaCognitoIdentity]
client_context: Optional[LambdaMobileClientContext]

def get_remaining_time_in_millis(self) -> int:
"""Returns the number of milliseconds left before the execution times out."""
... # pragma: no cover


@dataclass
class BaseRequest:
"""
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ quart; python_version >= '3.7'
mypy
brotli
brotli-asgi
awslambdaric-stubs
types-dataclasses; python_version < '3.7' # For mypy
# Docs
mkdocs
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_abstract_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from mangum.handlers import AbstractHandler
from mangum.handlers.abstract_handler import AbstractHandler


def test_abstract_handler_unkown_event():
Expand Down
5 changes: 3 additions & 2 deletions tests/handlers/test_aws_alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
"""
from typing import Dict, List, Optional

import pytest

from mangum import Mangum
from mangum.handlers import AwsAlb
from typing import Dict, List, Optional
from mangum.handlers.aws_alb import AwsAlb


def get_mock_aws_alb_event(
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_aws_api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest

import urllib.parse

import pytest

from mangum import Mangum
from mangum.handlers import AwsApiGateway
from mangum.handlers.aws_api_gateway import AwsApiGateway


def get_mock_aws_api_gateway_event(
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_aws_cf_lambda_at_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from mangum import Mangum
from mangum.handlers import AwsCfLambdaAtEdge
from mangum.handlers.aws_cf_lambda_at_edge import AwsCfLambdaAtEdge


def mock_lambda_at_edge_event(
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_aws_http_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from mangum import Mangum
from mangum.handlers import AwsHttpGateway
from mangum.handlers.aws_http_gateway import AwsHttpGateway


def get_mock_aws_http_gateway_event_v1(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import json

import pytest

import brotli
from brotli_asgi import BrotliMiddleware

from starlette.applications import Starlette
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import PlainTextResponse

from mangum import Mangum


Expand Down
1 change: 1 addition & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging

import pytest

from starlette.applications import Starlette
from starlette.responses import PlainTextResponse

Expand Down

0 comments on commit f6e2211

Please sign in to comment.