Skip to content

Commit b8337fa

Browse files
Merging from develop
2 parents b104888 + 7c9589c commit b8337fa

File tree

11 files changed

+652
-79
lines changed

11 files changed

+652
-79
lines changed

aws_lambda_powertools/event_handler/http_resolver.py

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import base64
55
import inspect
6+
import threading
67
import warnings
78
from typing import TYPE_CHECKING, Any, Callable
89
from urllib.parse import parse_qs
@@ -324,36 +325,65 @@ async def final_handler(app):
324325
return await next_handler(self)
325326

326327
def _wrap_middleware_async(self, middleware: Callable, next_handler: Callable) -> Callable:
327-
"""Wrap a middleware to work in async context."""
328+
"""Wrap a middleware to work in async context.
329+
330+
For sync middlewares, we split execution into pre/post phases around the
331+
call to next(). The sync middleware runs its pre-processing (e.g. request
332+
validation), then we intercept the next() call, await the async handler,
333+
and resume the middleware with the real response so post-processing
334+
(e.g. response validation) sees the actual data.
335+
"""
328336

329337
async def wrapped(app):
330-
# Create a next_middleware that the sync middleware can call
331-
def sync_next(app):
332-
# This will be called by sync middleware
333-
# We need to run the async next_handler
334-
loop = asyncio.get_event_loop()
335-
if loop.is_running():
336-
# We're in an async context, create a task
337-
future = asyncio.ensure_future(next_handler(app))
338-
# Store for later await
339-
app.context["_async_next_result"] = future
340-
return Response(status_code=200, body="") # Placeholder
341-
else: # pragma: no cover
342-
return loop.run_until_complete(next_handler(app))
343-
344-
# Check if middleware is async
345338
if inspect.iscoroutinefunction(middleware):
346-
result = await middleware(app, next_handler)
347-
else:
348-
# Sync middleware - need special handling
349-
result = middleware(app, sync_next)
339+
return await middleware(app, next_handler)
350340

351-
# Check if we stored an async result
352-
if "_async_next_result" in app.context:
353-
future = app.context.pop("_async_next_result")
354-
result = await future
341+
# We use an Event to coordinate: the sync middleware runs in a thread,
342+
# calls sync_next which signals us to resolve the async handler,
343+
# then waits for the real response.
344+
middleware_called_next = asyncio.Event()
345+
next_app_holder: list = []
346+
real_response_holder: list = []
347+
middleware_result_holder: list = []
348+
middleware_error_holder: list = []
355349

356-
return result
350+
def sync_next(app):
351+
next_app_holder.append(app)
352+
middleware_called_next.set()
353+
# Block this thread until the real response is available
354+
event = threading.Event()
355+
next_app_holder.append(event)
356+
event.wait()
357+
return real_response_holder[0]
358+
359+
def run_middleware():
360+
try:
361+
result = middleware(app, sync_next)
362+
middleware_result_holder.append(result)
363+
except Exception as e:
364+
middleware_error_holder.append(e)
365+
366+
thread = threading.Thread(target=run_middleware, daemon=True)
367+
thread.start()
368+
369+
# Wait for the middleware to call next()
370+
await middleware_called_next.wait()
371+
372+
# Now resolve the async next_handler
373+
real_response = await next_handler(next_app_holder[0])
374+
real_response_holder.append(real_response)
375+
376+
# Signal the thread that the response is ready
377+
threading_event = next_app_holder[1]
378+
threading_event.set()
379+
380+
# Wait for the middleware thread to finish
381+
thread.join()
382+
383+
if middleware_error_holder:
384+
raise middleware_error_holder[0]
385+
386+
return middleware_result_holder[0]
357387

358388
return wrapped
359389

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,17 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
9999
headers,
100100
)
101101

102+
# Process cookie values
103+
cookie_values, cookie_errors = _request_params_to_args(
104+
route.dependant.cookie_params,
105+
app.current_event.resolved_cookies_field,
106+
)
107+
102108
values.update(path_values)
103109
values.update(query_values)
104110
values.update(header_values)
105-
errors += path_errors + query_errors + header_errors
111+
values.update(cookie_values)
112+
errors += path_errors + query_errors + header_errors + cookie_errors
106113

107114
# Process the request body, if it exists
108115
if route.dependant.body_params:

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,79 @@ def alias(self, value: str | None = None):
658658
self._alias = value.lower()
659659

660660

661+
class Cookie(Param): # type: ignore[misc]
662+
"""
663+
A class used internally to represent a cookie parameter in a path operation.
664+
"""
665+
666+
in_ = ParamTypes.cookie
667+
668+
def __init__(
669+
self,
670+
default: Any = Undefined,
671+
*,
672+
default_factory: Callable[[], Any] | None = _Unset,
673+
annotation: Any | None = None,
674+
alias: str | None = None,
675+
alias_priority: int | None = _Unset,
676+
# MAINTENANCE: update when deprecating Pydantic v1, import these types
677+
# str | AliasPath | AliasChoices | None
678+
validation_alias: str | None = _Unset,
679+
serialization_alias: str | None = None,
680+
title: str | None = None,
681+
description: str | None = None,
682+
gt: float | None = None,
683+
ge: float | None = None,
684+
lt: float | None = None,
685+
le: float | None = None,
686+
min_length: int | None = None,
687+
max_length: int | None = None,
688+
pattern: str | None = None,
689+
discriminator: str | None = None,
690+
strict: bool | None = _Unset,
691+
multiple_of: float | None = _Unset,
692+
allow_inf_nan: bool | None = _Unset,
693+
max_digits: int | None = _Unset,
694+
decimal_places: int | None = _Unset,
695+
examples: list[Any] | None = None,
696+
openapi_examples: dict[str, Example] | None = None,
697+
deprecated: bool | None = None,
698+
include_in_schema: bool = True,
699+
json_schema_extra: dict[str, Any] | None = None,
700+
**extra: Any,
701+
):
702+
super().__init__(
703+
default=default,
704+
default_factory=default_factory,
705+
annotation=annotation,
706+
alias=alias,
707+
alias_priority=alias_priority,
708+
validation_alias=validation_alias,
709+
serialization_alias=serialization_alias,
710+
title=title,
711+
description=description,
712+
gt=gt,
713+
ge=ge,
714+
lt=lt,
715+
le=le,
716+
min_length=min_length,
717+
max_length=max_length,
718+
pattern=pattern,
719+
discriminator=discriminator,
720+
strict=strict,
721+
multiple_of=multiple_of,
722+
allow_inf_nan=allow_inf_nan,
723+
max_digits=max_digits,
724+
decimal_places=decimal_places,
725+
deprecated=deprecated,
726+
examples=examples,
727+
openapi_examples=openapi_examples,
728+
include_in_schema=include_in_schema,
729+
json_schema_extra=json_schema_extra,
730+
**extra,
731+
)
732+
733+
661734
class Body(FieldInfo): # type: ignore[misc]
662735
"""
663736
A class used internally to represent a body parameter in a path operation.

aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ class APIGatewayAuthorizerResponse:
437437
- https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-lambda-authorizer-output.html
438438
"""
439439

440-
path_regex = r"^[/.a-zA-Z0-9-_\*]+$"
440+
path_regex = r"^[/.a-zA-Z0-9\-_\*\{\}\+]+$"
441441
"""The regular expression used to validate resource paths for the policy"""
442442

443443
def __init__(

aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,17 @@ def raw_query_string(self) -> str:
280280
def cookies(self) -> list[str]:
281281
return self.get("cookies") or []
282282

283+
@property
284+
def resolved_cookies_field(self) -> dict[str, str]:
285+
"""
286+
Parse cookies from the dedicated ``cookies`` field in API Gateway HTTP API v2 format.
287+
288+
The ``cookies`` field contains a list of strings like ``["session=abc", "theme=dark"]``.
289+
"""
290+
from aws_lambda_powertools.utilities.data_classes.common import _parse_cookie_string
291+
292+
return _parse_cookie_string("; ".join(self.cookies))
293+
283294
@property
284295
def request_context(self) -> RequestContextV2:
285296
return RequestContextV2(self["requestContext"])

aws_lambda_powertools/utilities/data_classes/common.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@
2929
)
3030

3131

32+
def _parse_cookie_string(cookie_string: str) -> dict[str, str]:
33+
"""Parse a cookie string (``key=value; key2=value2``) into a dict."""
34+
cookies: dict[str, str] = {}
35+
for segment in cookie_string.split(";"):
36+
stripped = segment.strip()
37+
if "=" in stripped:
38+
name, _, value = stripped.partition("=")
39+
cookies[name.strip()] = value.strip()
40+
return cookies
41+
42+
3243
class CaseInsensitiveDict(dict):
3344
"""Case insensitive dict implementation. Assumes string keys only."""
3445

@@ -203,6 +214,36 @@ def resolved_headers_field(self) -> dict[str, str]:
203214
"""
204215
return self.headers
205216

217+
@property
218+
def resolved_cookies_field(self) -> dict[str, str]:
219+
"""
220+
This property extracts cookies from the request as a dict of name-value pairs.
221+
222+
By default, cookies are parsed from the ``Cookie`` header.
223+
Uses ``self.headers`` (CaseInsensitiveDict) first for reliable case-insensitive
224+
lookup, then falls back to ``resolved_headers_field`` for proxies that only
225+
populate multi-value headers (e.g., ALB without single-value headers).
226+
Subclasses may override this for event formats that provide cookies
227+
in a dedicated field (e.g., API Gateway HTTP API v2).
228+
"""
229+
# Primary: self.headers is CaseInsensitiveDict — case-insensitive lookup
230+
cookie_value: str | list[str] = self.headers.get("cookie") or ""
231+
232+
# Fallback: resolved_headers_field covers ALB/REST v1 multi-value headers
233+
# where the event may not have a single-value 'headers' dict at all
234+
if not cookie_value:
235+
headers = self.resolved_headers_field or {}
236+
cookie_value = headers.get("cookie") or headers.get("Cookie") or ""
237+
238+
# Multi-value headers (ALB, REST v1) may return a list
239+
if isinstance(cookie_value, list):
240+
cookie_value = "; ".join(cookie_value)
241+
242+
if not cookie_value:
243+
return {}
244+
245+
return _parse_cookie_string(cookie_value)
246+
206247
@property
207248
def is_base64_encoded(self) -> bool | None:
208249
return self.get("isBase64Encoded")

0 commit comments

Comments
 (0)