diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f1f38b399a9..8e3ce1e7e4d 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1575,6 +1575,7 @@ def __init__( strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, response_validation_error_http_code: HTTPStatus | int | None = None, + deserializer: Callable[[str], dict] | None = None, ): """ Parameters @@ -1596,6 +1597,9 @@ def __init__( Enables validation of the request body against the route schema, by default False. response_validation_error_http_code Sets the returned status code if response is not validated. enable_validation must be True. + deserializer: Callable[[str], dict], optional + function to deserialize `str`, `bytes`, `bytearray` containing a JSON document to a Python `dict`, + by default json.loads """ self._proxy_type = proxy_type self._dynamic_routes: list[Route] = [] @@ -1621,6 +1625,7 @@ def __init__( # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) + self._deserializer = deserializer if self._enable_validation: from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware @@ -2431,24 +2436,24 @@ def _to_proxy_event(self, event: dict) -> BaseProxyEvent: # noqa: PLR0911 # ig """Convert the event dict to the corresponding data class""" if self._proxy_type == ProxyEventType.APIGatewayProxyEvent: logger.debug("Converting event to API Gateway REST API contract") - return APIGatewayProxyEvent(event) + return APIGatewayProxyEvent(event, self._deserializer) if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2: logger.debug("Converting event to API Gateway HTTP API contract") - return APIGatewayProxyEventV2(event) + return APIGatewayProxyEventV2(event, self._deserializer) if self._proxy_type == ProxyEventType.BedrockAgentEvent: logger.debug("Converting event to Bedrock Agent contract") - return BedrockAgentEvent(event) + return BedrockAgentEvent(event, self._deserializer) if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent: logger.debug("Converting event to Lambda Function URL contract") - return LambdaFunctionUrlEvent(event) + return LambdaFunctionUrlEvent(event, self._deserializer) if self._proxy_type == ProxyEventType.VPCLatticeEvent: logger.debug("Converting event to VPC Lattice contract") - return VPCLatticeEvent(event) + return VPCLatticeEvent(event, self._deserializer) if self._proxy_type == ProxyEventType.VPCLatticeEventV2: logger.debug("Converting event to VPC LatticeV2 contract") - return VPCLatticeEventV2(event) + return VPCLatticeEventV2(event, self._deserializer) logger.debug("Converting event to ALB contract") - return ALBEvent(event) + return ALBEvent(event, self._deserializer) def _resolve(self) -> ResponseBuilder: """Resolves the response or return the not found response""" @@ -2865,6 +2870,7 @@ def __init__( strip_prefixes: list[str | Pattern] | None = None, enable_validation: bool = False, response_validation_error_http_code: HTTPStatus | int | None = None, + deserializer: Callable[[str], dict] | None = None, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" super().__init__( @@ -2875,6 +2881,7 @@ def __init__( strip_prefixes, enable_validation, response_validation_error_http_code, + deserializer, ) def _get_base_path(self) -> str: diff --git a/tests/functional/event_handler/_pydantic/test_api_gateway.py b/tests/functional/event_handler/_pydantic/test_api_gateway.py index ce3fd89e864..bbd13082841 100644 --- a/tests/functional/event_handler/_pydantic/test_api_gateway.py +++ b/tests/functional/event_handler/_pydantic/test_api_gateway.py @@ -1,5 +1,9 @@ from __future__ import annotations +import json +from decimal import Decimal +from functools import partial + from pydantic import BaseModel from aws_lambda_powertools.event_handler import content_types @@ -80,3 +84,24 @@ def get_lambda(param: int): ... assert result["statusCode"] == 422 assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] assert "missing" in result["body"] + + +def test_api_gateway_resolver_numeric_value(): + # GIVEN a basic API Gateway resolver + app = ApiGatewayResolver(deserializer=partial(json.loads, parse_float=Decimal)) + + @app.post("/my/path") + def test_handler(): + return app.current_event.json_body + + # WHEN calling the event handler + event = {} + event.update(LOAD_GW_EVENT) + event["body"] = '{"amount": 2.2999999999999998}' + event["httpMethod"] = "POST" + + result = app(event, {}) + # THEN process event correctly + assert result["statusCode"] == 200 + assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] + assert result["body"] == '{"amount":"2.2999999999999998"}'