Skip to content

Commit 587d2fa

Browse files
small changes
1 parent 5cc265e commit 587d2fa

File tree

7 files changed

+70
-56
lines changed

7 files changed

+70
-56
lines changed

aws_lambda_powertools/event_handler/openapi/compat.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from collections import deque
55
from collections.abc import Mapping, Sequence
66
from copy import copy
7-
8-
# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different
9-
# versions of a module, so we need to ignore errors here.
107
from dataclasses import dataclass, is_dataclass
118
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union
129

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ def _has_discriminator(field_info: FieldInfo) -> bool:
10421042
return hasattr(field_info, "discriminator") and field_info.discriminator is not None
10431043

10441044

1045-
def _handle_discriminator_with_body(
1045+
def _handle_discriminator_with_param(
10461046
annotations: list[FieldInfo],
10471047
annotation: Any,
10481048
) -> tuple[FieldInfo | None, Any, bool]:
@@ -1111,10 +1111,10 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
11111111

11121112
# Determine which annotation to use
11131113
powertools_annotation: FieldInfo | None = None
1114-
has_discriminator_with_body = False
1114+
has_discriminator_with_param = False
11151115

11161116
if len(powertools_annotations) == 2:
1117-
powertools_annotation, type_annotation, has_discriminator_with_body = _handle_discriminator_with_body(
1117+
powertools_annotation, type_annotation, has_discriminator_with_param = _handle_discriminator_with_param(
11181118
powertools_annotations,
11191119
annotation,
11201120
)
@@ -1126,7 +1126,7 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
11261126
# Process the annotation if it exists
11271127
field_info: FieldInfo | None = None
11281128
if isinstance(powertools_annotation, FieldInfo):
1129-
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_body)
1129+
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param)
11301130
_set_field_default(field_info, value, is_path_param)
11311131

11321132
# Preserve full annotated type for discriminated unions

docs/core/event_handler/api_gateway.md

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,17 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
428428
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
429429
```
430430

431+
##### Discriminated unions
432+
433+
You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body.
434+
435+
```python hl_lines="3 4 8 31 36" title="discriminated_unions.py"
436+
--8<-- "examples/event_handler_rest/src/discriminated_unions.py"
437+
```
438+
439+
1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate
440+
2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union
441+
431442
#### Validating responses
432443

433444
You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`.
@@ -568,23 +579,6 @@ You can use the `Form` type to tell the Event Handler that a parameter expects f
568579
--8<-- "examples/event_handler_rest/src/working_with_form_data.py"
569580
```
570581

571-
#### Discriminated unions
572-
573-
!!! info "You must set `enable_validation=True` to use discriminated unions via type annotation."
574-
575-
You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body.
576-
577-
In the following example, we define two action types (`FooAction` and `BarAction`) that share a common discriminator field `action`. The Event Handler will automatically parse the request body and instantiate the correct model based on the `action` field value:
578-
579-
```python hl_lines="3 4 8 31 36" title="discriminated_unions.py"
580-
--8<-- "examples/event_handler_rest/src/discriminated_unions.py"
581-
```
582-
583-
1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate
584-
2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union
585-
586-
When you send a request with `{"action": "foo", "foo_data": "example"}`, the Event Handler will automatically create a `FooAction` instance. Similarly, `{"action": "bar", "bar_data": 42}` will create a `BarAction` instance.
587-
588582
#### Supported types for response serialization
589583

590584
With data validation enabled, we natively support serializing the following data types to JSON:
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import annotations
2+
3+
from typing import Annotated, Literal
4+
5+
from pydantic import BaseModel, Field
6+
7+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
8+
from aws_lambda_powertools.event_handler.openapi.params import Body
9+
10+
app = APIGatewayRestResolver(enable_validation=True)
11+
app.enable_swagger()
12+
13+
14+
class FooAction(BaseModel):
15+
action: Literal["foo"]
16+
foo_data: str
17+
18+
19+
class BarAction(BaseModel):
20+
action: Literal["bar"]
21+
bar_data: int
22+
23+
24+
Action = Annotated[FooAction | BarAction, Field(discriminator="action")]
25+
26+
27+
@app.post("/data_validation_with_fields")
28+
def create_action(action: Annotated[Action, Body(discriminator="action")]):
29+
return {"message": "Powertools e2e API"}
30+
31+
32+
def lambda_handler(event, context):
33+
print(event)
34+
return app.resolve(event, context)

tests/e2e/event_handler/infrastructure.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def create_resources(self):
2424
functions["OpenapiHandler"],
2525
functions["OpenapiHandlerWithPep563"],
2626
functions["DataValidationAndMiddleware"],
27+
functions["DataValidationWithFields"],
2728
],
2829
)
2930
self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"])
@@ -105,6 +106,9 @@ def _create_api_gateway_rest(self, function: list[Function]):
105106
openapi_schema = apigw.root.add_resource("data_validation_middleware")
106107
openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[3], proxy=True))
107108

109+
openapi_schema = apigw.root.add_resource("data_validation_with_fields")
110+
openapi_schema.add_method("POST", apigwv1.LambdaIntegration(function[4], proxy=True))
111+
108112
CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url)
109113

110114
def _create_lambda_function_url(self, function: Function):

tests/e2e/event_handler/test_openapi.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,20 @@ def test_get_openapi_validation_and_middleware(apigw_rest_endpoint):
5959
)
6060

6161
assert response.status_code == 202
62+
63+
64+
def test_openapi_with_fields_discriminator(apigw_rest_endpoint):
65+
# GIVEN
66+
url = f"{apigw_rest_endpoint}data_validation_with_fields"
67+
68+
# WHEN
69+
response = data_fetcher.get_http_response(
70+
Request(
71+
method="POST",
72+
url=url,
73+
json={"action": "foo", "foo_data": "foo data working"},
74+
),
75+
)
76+
77+
assert "Powertools e2e API" in response.text
78+
assert response.status_code == 200

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,35 +2030,3 @@ def create_action(action: Annotated[action_type, Body()]):
20302030

20312031
result = app(gw_event, {})
20322032
assert result["statusCode"] == 422
2033-
2034-
2035-
def test_field_other_features_still_work(gw_event):
2036-
"""Test that other Field features still work after discriminator fix"""
2037-
app = APIGatewayRestResolver(enable_validation=True)
2038-
2039-
class UserInput(BaseModel):
2040-
name: Annotated[str, Field(min_length=2, max_length=50, description="User name")]
2041-
age: Annotated[int, Field(ge=18, le=120, description="User age")]
2042-
email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")]
2043-
2044-
@app.post("/users")
2045-
def create_user(user: UserInput):
2046-
return {"created": user.model_dump()}
2047-
2048-
gw_event["path"] = "/users"
2049-
gw_event["httpMethod"] = "POST"
2050-
gw_event["headers"]["content-type"] = "application/json"
2051-
gw_event["body"] = '{"name": "John", "age": 25, "email": "[email protected]"}'
2052-
2053-
result = app(gw_event, {})
2054-
assert result["statusCode"] == 200
2055-
2056-
response_body = json.loads(result["body"])
2057-
assert response_body["created"]["name"] == "John"
2058-
assert response_body["created"]["age"] == 25
2059-
assert response_body["created"]["email"] == "[email protected]"
2060-
2061-
gw_event["body"] = '{"name": "John", "age": 16, "email": "[email protected]"}'
2062-
2063-
result = app(gw_event, {})
2064-
assert result["statusCode"] == 422

0 commit comments

Comments
 (0)