From 2d89014ad7b5921a94eff1d1b56e2a5837035e74 Mon Sep 17 00:00:00 2001 From: Vichym Date: Mon, 12 May 2025 15:00:21 -0700 Subject: [PATCH] Bug fix binary decode in request for local start api --- samcli/local/apigw/event_constructor.py | 48 ++- .../local/apigw/test_event_constructor.py | 357 +++++++++++------- 2 files changed, 257 insertions(+), 148 deletions(-) diff --git a/samcli/local/apigw/event_constructor.py b/samcli/local/apigw/event_constructor.py index b9e07ecb8d..8ed0033e00 100644 --- a/samcli/local/apigw/event_constructor.py +++ b/samcli/local/apigw/event_constructor.py @@ -54,12 +54,14 @@ def construct_v1_event( if is_base_64: LOG.debug("Incoming Request seems to be binary. Base64 encoding the request data before sending to Lambda.") request_data = base64.b64encode(request_data) - - if request_data: - # Flask does not parse/decode the request data. We should do it ourselves - # Note(xinhol): here we change request_data's type from bytes to str and confused mypy - # We might want to consider to use a new variable here. - request_data = request_data.decode("utf-8") + request_data = request_data.decode("utf-8") if request_data else "" + elif isinstance(request_data, bytes): + try: + request_data = request_data.decode("utf-8") + except UnicodeDecodeError: + LOG.debug("Failed to decode request data as UTF-8, falling back to base64 encoding") + request_data = base64.b64encode(request_data).decode("utf-8") + is_base_64 = True query_string_dict, multi_value_query_string_dict = _query_string_params(flask_request) @@ -131,10 +133,14 @@ def construct_v2_event_http( if is_base_64: LOG.debug("Incoming Request seems to be binary. Base64 encoding the request data before sending to Lambda.") request_data = base64.b64encode(request_data) - - if request_data is not None: - # Flask does not parse/decode the request data. 
We should do it ourselves - request_data = request_data.decode("utf-8") + request_data = request_data.decode("utf-8") if request_data else "" + elif isinstance(request_data, bytes): + try: + request_data = request_data.decode("utf-8") + except UnicodeDecodeError: + LOG.debug("Failed to decode request data as UTF-8, falling back to base64 encoding") + request_data = base64.b64encode(request_data).decode("utf-8") + is_base_64 = True query_string_dict = _query_string_params_v_2_0(flask_request) @@ -328,4 +334,24 @@ def _should_base64_encode(binary_types, request_mimetype): True if the data should be encoded to Base64 otherwise False """ - return request_mimetype in binary_types or "*/*" in binary_types + # 1. Handle multipart form data with potential binary content + if request_mimetype and request_mimetype.startswith("multipart/form-data"): + return True + + # 2. Check for wildcard match - this should work even if request_mimetype is None + if "*/*" in binary_types: + return True + + # 3. If no MIME type, we can't do specific matching + if not request_mimetype: + return False + + # 4. Strip parameters from MIME type (e.g., "text/plain; charset=utf-8" -> "text/plain") + # and check for exact match + base_mimetype = request_mimetype.split(";")[0].strip() + if base_mimetype in binary_types: + return True + + # 5. 
Check for type/* wildcard (e.g., "image/*" matches "image/png") + type_prefix = base_mimetype.split("/")[0] + "/*" + return type_prefix in binary_types diff --git a/tests/unit/local/apigw/test_event_constructor.py b/tests/unit/local/apigw/test_event_constructor.py index 6035a6efc9..0f084d9f8a 100644 --- a/tests/unit/local/apigw/test_event_constructor.py +++ b/tests/unit/local/apigw/test_event_constructor.py @@ -1,24 +1,127 @@ import base64 from datetime import datetime -import json from time import time +from typing import Any from unittest import TestCase -from unittest.mock import Mock, patch -from parameterized import parameterized, param +from unittest.mock import Mock +from parameterized import parameterized_class from samcli.local.apigw.event_constructor import ( _event_headers, _event_http_headers, _query_string_params, _query_string_params_v_2_0, - _should_base64_encode, construct_v1_event, construct_v2_event_http, ) from samcli.local.apigw.local_apigw_service import LocalApigwService +input_scenarios = [ + # No Data + (None, None, [], None, False), + # Standard text formats + ("application/json", b'{"key": "value"}', [], '{"key": "value"}', False), + ("text/plain", b"Hello, world!", [], "Hello, world!", False), + ("text/html", b"Hello", [], "Hello", False), + ("application/xml", b"value", [], "value", False), + # Binary formats with matching binary_types + ( + "image/gif", + b"\x47\x49\x46\x38\x39\x61 binary data", + ["image/gif"], + base64.b64encode(b"\x47\x49\x46\x38\x39\x61 binary data").decode("ascii"), + True, + ), + ( + "image/png", + b"\x89PNG\r\n\x1a\n binary data", + ["image/png"], + base64.b64encode(b"\x89PNG\r\n\x1a\n binary data").decode("ascii"), + True, + ), + ( + "image/jpeg", + b"\xff\xd8\xff\xe0 binary data", + ["image/jpeg"], + base64.b64encode(b"\xff\xd8\xff\xe0 binary data").decode("ascii"), + True, + ), + ( + "application/pdf", + b"%PDF-1.5 binary data", + ["application/pdf"], + base64.b64encode(b"%PDF-1.5 binary 
data").decode("ascii"), + True, + ), + ( + "application/octet-stream", + b"\x00\x01\x02\x03 binary data", + ["application/octet-stream"], + base64.b64encode(b"\x00\x01\x02\x03 binary data").decode("ascii"), + True, + ), + # Binary format without matching binary_type (should be treated as text) + # This might fail UTF-8 decoding and fall back to base64, so we need to handle both cases + ( + "image/png", + b"\x89PNG\r\n\x1a\n binary data", + [], + base64.b64encode(b"\x89PNG\r\n\x1a\n binary data").decode("ascii"), + True, + ), + # Text format with invalid UTF-8 (should fall back to base64) + ("text/plain", b"\xff\xfe invalid utf-8", [], base64.b64encode(b"\xff\xfe invalid utf-8").decode("ascii"), True), + # Multipart form data (should always be base64 encoded) + ( + "multipart/form-data", + b'--boundary\r\nContent-Disposition: form-data; name="file"\r\n\r\nbinary data\r\n--boundary--', + [], + base64.b64encode( + b'--boundary\r\nContent-Disposition: form-data; name="file"\r\n\r\nbinary data\r\n--boundary--' + ).decode("ascii"), + True, + ), + # Special cases + ( + None, + b"some binary data", + ["*/*"], + base64.b64encode(b"some binary data").decode("utf-8"), + True, + ), # No MIME type but */* in binary_types + ("application/json", None, [], None, False), # No data + ( + "application/x-www-form-urlencoded", + b"param1=value1&param2=value2", + [], + "param1=value1&param2=value2", + False, + ), # Form data + # Edge cases + ("text/csv", b"id,name\n1,test", ["text/csv"], base64.b64encode(b"id,name\n1,test").decode("ascii"), True), + ( + "application/zip", + b"PK\x03\x04 zip data", + ["*/*"], + base64.b64encode(b"PK\x03\x04 zip data").decode("ascii"), + True, + ), + ("application/javascript", b"function test() { return true; }", [], "function test() { return true; }", False), +] + + +@parameterized_class( + ("request_mimetype", "request_get_data_return", "binary_types", "expected_body", "expected_is_base64"), + input_scenarios, +) class TestService_construct_event(TestCase): + 
request_mimetype: str + request_get_data_return: bytes + binary_types: list + expected_body: Any + expected_is_base64: bool + def setUp(self): self.request_mock = Mock() self.request_mock.endpoint = "endpoint" @@ -26,7 +129,8 @@ def setUp(self): self.request_mock.method = "GET" self.request_mock.remote_addr = "190.0.0.0" self.request_mock.host = "190.0.0.1" - self.request_mock.get_data.return_value = b"DATA!!!!" + self.request_mock.get_data.return_value = self.request_get_data_return + self.request_mock.mimetype = self.request_mimetype query_param_args_mock = Mock() query_param_args_mock.lists.return_value = {"query": ["params"]}.items() self.request_mock.args = query_param_args_mock @@ -40,65 +144,77 @@ def setUp(self): environ_dict = {"SERVER_PROTOCOL": "HTTP/1.1"} self.request_mock.environ = environ_dict - expected = ( - '{"body": "DATA!!!!", "httpMethod": "GET", ' - '"multiValueQueryStringParameters": {"query": ["params"]}, ' - '"queryStringParameters": {"query": "params"}, "resource": ' - '"endpoint", "requestContext": {"httpMethod": "GET", "requestId": ' - '"c6af9ac6-7b61-11e6-9a41-93e8deadbeef", "path": "endpoint", "extendedRequestId": null, ' - '"resourceId": "123456", "apiId": "1234567890", "stage": null, "resourcePath": "endpoint", ' - '"identity": {"accountId": null, "apiKey": null, "userArn": null, ' - '"cognitoAuthenticationProvider": null, "cognitoIdentityPoolId": null, "userAgent": ' - '"Custom User Agent String", "caller": null, "cognitoAuthenticationType": null, "sourceIp": ' - '"190.0.0.0", "user": null}, "accountId": "123456789012", "domainName": "190.0.0.1", ' - '"protocol": "HTTP/1.1"}, "headers": {"Content-Type": ' - '"application/json", "X-Test": "Value", "X-Forwarded-Port": "3000", "X-Forwarded-Proto": "http"}, ' - '"multiValueHeaders": {"Content-Type": ["application/json"], "X-Test": ["Value"], ' - '"X-Forwarded-Port": ["3000"], "X-Forwarded-Proto": ["http"]}, ' - '"stageVariables": null, "path": "path", "pathParameters": {"path": 
"params"}, ' - '"isBase64Encoded": false}' - ) - - self.expected_dict = json.loads(expected) + self.expected_dict = { + "body": self.expected_body, + "httpMethod": "GET", + "multiValueQueryStringParameters": {"query": ["params"]}, + "queryStringParameters": {"query": "params"}, + "resource": "endpoint", + "requestContext": { + "httpMethod": "GET", + "requestId": "c6af9ac6-7b61-11e6-9a41-93e8deadbeef", + "path": "endpoint", + "extendedRequestId": None, + "resourceId": "123456", + "apiId": "1234567890", + "stage": None, + "resourcePath": "endpoint", + "identity": { + "accountId": None, + "apiKey": None, + "userArn": None, + "cognitoAuthenticationProvider": None, + "cognitoIdentityPoolId": None, + "userAgent": "Custom User Agent String", + "caller": None, + "cognitoAuthenticationType": None, + "sourceIp": "190.0.0.0", + "user": None, + }, + "accountId": "123456789012", + "domainName": "190.0.0.1", + "protocol": "HTTP/1.1", + }, + "headers": { + "Content-Type": "application/json", + "X-Test": "Value", + "X-Forwarded-Port": "3000", + "X-Forwarded-Proto": "http", + }, + "multiValueHeaders": { + "Content-Type": ["application/json"], + "X-Test": ["Value"], + "X-Forwarded-Port": ["3000"], + "X-Forwarded-Proto": ["http"], + }, + "stageVariables": None, + "path": "path", + "pathParameters": {"path": "params"}, + "isBase64Encoded": self.expected_is_base64, + } + + def test_construct_event(self): + actual_event = construct_v1_event(self.request_mock, 3000, self.binary_types) + self.maxDiff = None - def validate_request_context_and_remove_request_time_data(self, event_json): - request_time = event_json["requestContext"].pop("requestTime", None) - request_time_epoch = event_json["requestContext"].pop("requestTimeEpoch", None) + # Remove dynamic fields from requestContext + request_id = actual_event["requestContext"].pop("requestId", None) + request_time = actual_event["requestContext"].pop("requestTime", None) + request_time_epoch = 
actual_event["requestContext"].pop("requestTimeEpoch", None) + self.assertEqual(len(request_id), 36) self.assertIsInstance(request_time, str) + parsed_request_time = datetime.strptime(request_time, "%d/%b/%Y:%H:%M:%S +0000") self.assertIsInstance(parsed_request_time, datetime) self.assertIsInstance(request_time_epoch, int) - def test_construct_event_with_data(self): - actual_event_json = construct_v1_event(self.request_mock, 3000, binary_types=[]) - self.validate_request_context_and_remove_request_time_data(actual_event_json) - - self.assertEqual(actual_event_json["body"], self.expected_dict["body"]) - - def test_construct_event_no_data(self): - self.request_mock.get_data.return_value = None - - actual_event_json = construct_v1_event(self.request_mock, 3000, binary_types=[]) - self.validate_request_context_and_remove_request_time_data(actual_event_json) - - self.assertEqual(actual_event_json["body"], None) - - @patch("samcli.local.apigw.event_constructor._should_base64_encode") - def test_construct_event_with_binary_data(self, should_base64_encode_patch): - should_base64_encode_patch.return_value = True + self.expected_dict["requestContext"].pop("requestId", None) + self.expected_dict["requestContext"].pop("requestTime", None) + self.expected_dict["requestContext"].pop("requestTimeEpoch", None) - binary_body = b"011000100110100101101110011000010111001001111001" # binary in binary - base64_body = base64.b64encode(binary_body).decode("utf-8") - - self.request_mock.get_data.return_value = binary_body - - actual_event_json = construct_v1_event(self.request_mock, 3000, binary_types=[]) - self.validate_request_context_and_remove_request_time_data(actual_event_json) - - self.assertEqual(actual_event_json["body"], base64_body) - self.assertEqual(actual_event_json["isBase64Encoded"], True) + self.assertEqual(actual_event, self.expected_dict) def test_event_headers_with_empty_list(self): request_mock = Mock() @@ -181,14 +297,24 @@ def 
test_query_string_params_v_2_0_with_param_value_being_non_empty_list(self): self.assertEqual(actual_query_string, {"param": "a,b"}) +@parameterized_class( + ("request_mimetype", "request_get_data_return", "binary_types", "expected_body", "expected_is_base64"), + input_scenarios, +) class TestService_construct_event_http(TestCase): + request_mimetype: str + request_get_data_return: bytes + binary_types: list + expected_body: Any + expected_is_base64: bool + def setUp(self): self.request_mock = Mock() self.request_mock.endpoint = "endpoint" self.request_mock.method = "GET" self.request_mock.path = "/endpoint" - self.request_mock.get_data.return_value = b"DATA!!!!" - self.request_mock.mimetype = "application/json" + self.request_mock.get_data.return_value = self.request_get_data_return + self.request_mock.mimetype = self.request_mimetype query_param_args_mock = Mock() query_param_args_mock.lists.return_value = {"query": ["param1", "param2"]}.items() self.request_mock.args = query_param_args_mock @@ -208,75 +334,72 @@ def setUp(self): self.request_time_epoch = int(time()) self.request_time = datetime.utcnow().strftime("%d/%b/%Y:%H:%M:%S +0000") - expected = f""" - {{ + self.expected_dict = { "version": "2.0", "routeKey": "GET /endpoint", "rawPath": "/endpoint", "rawQueryString": "query=params", "cookies": ["cookie1=test", "cookie2=test"], - "headers": {{ + "headers": { "Content-Type": "application/json", "X-Test": "Value", "X-Forwarded-Proto": "http", - "X-Forwarded-Port": "3000" - }}, - "queryStringParameters": {{"query": "param1,param2"}}, - "requestContext": {{ + "X-Forwarded-Port": "3000", + }, + "queryStringParameters": {"query": "param1,param2"}, + "requestContext": { "accountId": "123456789012", "apiId": "1234567890", "domainName": "localhost", "domainPrefix": "localhost", - "http": {{ + "http": { "method": "GET", "path": "/endpoint", "protocol": "HTTP/1.1", "sourceIp": "190.0.0.0", - "userAgent": "Custom User Agent String" - }}, + "userAgent": "Custom User 
Agent String", + }, "requestId": "", "routeKey": "GET /endpoint", "stage": "$default", - "time": \"{self.request_time}\", - "timeEpoch": {self.request_time_epoch} - }}, - "body": "DATA!!!!", - "pathParameters": {{"path": "params"}}, - "stageVariables": null, - "isBase64Encoded": false - }} - """ - - self.expected_dict = json.loads(expected) + "time": self.request_time, + "timeEpoch": self.request_time_epoch, + }, + "body": self.expected_body, + "pathParameters": {"path": "params"}, + "stageVariables": None, + "isBase64Encoded": self.expected_is_base64, + } def test_construct_event_with_data(self): - actual_event_dict = construct_v2_event_http( + actual_event = construct_v2_event_http( self.request_mock, 3000, - binary_types=[], + binary_types=self.binary_types, route_key="GET /endpoint", request_time_epoch=self.request_time_epoch, request_time=self.request_time, ) - self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) - actual_event_dict["requestContext"]["requestId"] = "" - self.assertEqual(actual_event_dict, self.expected_dict) + self.maxDiff = None - def test_construct_event_no_data(self): - self.request_mock.get_data.return_value = None - self.expected_dict["body"] = None + # Remove dynamic fields from requestContext + request_id = actual_event["requestContext"].pop("requestId", None) + request_time = actual_event["requestContext"].pop("time", None) + request_time_epoch = actual_event["requestContext"].pop("timeEpoch", None) - actual_event_dict = construct_v2_event_http( - self.request_mock, - 3000, - binary_types=[], - route_key="GET /endpoint", - request_time_epoch=self.request_time_epoch, - request_time=self.request_time, - ) - self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) - actual_event_dict["requestContext"]["requestId"] = "" - self.assertEqual(actual_event_dict, self.expected_dict) + self.assertEqual(len(request_id), 36) + self.assertIsInstance(request_time, str) + + parsed_request_time = 
datetime.strptime(request_time, "%d/%b/%Y:%H:%M:%S +0000") + self.assertIsInstance(parsed_request_time, datetime) + + self.assertIsInstance(request_time_epoch, int) + + self.expected_dict["requestContext"].pop("requestId", None) + self.expected_dict["requestContext"].pop("time", None) + self.expected_dict["requestContext"].pop("timeEpoch", None) + + self.assertEqual(actual_event, self.expected_dict) def test_v2_route_key(self): route_key = LocalApigwService._v2_route_key("GET", "/path", False) @@ -286,30 +409,6 @@ def test_v2_default_route_key(self): route_key = LocalApigwService._v2_route_key("GET", "/path", True) self.assertEqual(route_key, "$default") - @patch("samcli.local.apigw.event_constructor._should_base64_encode") - def test_construct_event_with_binary_data(self, should_base64_encode_patch): - should_base64_encode_patch.return_value = True - - binary_body = b"011000100110100101101110011000010111001001111001" # binary in binary - base64_body = base64.b64encode(binary_body).decode("utf-8") - - self.request_mock.get_data.return_value = binary_body - self.expected_dict["body"] = base64_body - self.expected_dict["isBase64Encoded"] = True - self.maxDiff = None - - actual_event_dict = construct_v2_event_http( - self.request_mock, - 3000, - binary_types=[], - route_key="GET /endpoint", - request_time_epoch=self.request_time_epoch, - request_time=self.request_time, - ) - self.assertEqual(len(actual_event_dict["requestContext"]["requestId"]), 36) - actual_event_dict["requestContext"]["requestId"] = "" - self.assertEqual(actual_event_dict, self.expected_dict) - def test_event_headers_with_empty_list(self): request_mock = Mock() headers_mock = Mock() @@ -339,19 +438,3 @@ def test_event_headers_with_non_empty_list(self): "X-Forwarded-Port": "3000", }, ) - - -class TestService_should_base64_encode(TestCase): - @parameterized.expand( - [ - param("Mimeyype is in binary types", ["image/gif"], "image/gif"), - param("Mimetype defined and binary types has */*", ["*/*"], 
"image/gif"), - param("*/* is in binary types with no mimetype defined", ["*/*"], None), - ] - ) - def test_should_base64_encode_returns_true(self, test_case_name, binary_types, mimetype): - self.assertTrue(_should_base64_encode(binary_types, mimetype)) - - @parameterized.expand([param("Mimetype is not in binary types", ["image/gif"], "application/octet-stream")]) - def test_should_base64_encode_returns_false(self, test_case_name, binary_types, mimetype): - self.assertFalse(_should_base64_encode(binary_types, mimetype))