Skip to content

Commit 38411f4

Browse files
committed
Added _is_sequence() and tests for new cases
1 parent 62297ba commit 38411f4

File tree

4 files changed

+100
-17
lines changed

4 files changed

+100
-17
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
fail-fast: false
1414
matrix:
15-
python-version: ["3.7", "3.8", "3.9", "3.10"]
15+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
1616
os: [ubuntu-latest, macOS-latest]
1717
# Python 3.7 is not supported on Apple ARM64,
1818
# or the latest Ubuntu 2404

flask_pydantic/converters.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Type, Union
1+
import types
2+
from collections import deque
3+
from typing import Deque, FrozenSet, List, Sequence, Set, Tuple, Type, Union
24

35
try:
46
from typing import get_args, get_origin
@@ -10,15 +12,29 @@
1012
from werkzeug.datastructures import ImmutableMultiDict
1113

1214
V1OrV2BaseModel = Union[BaseModel, V1BaseModel]
15+
UnionType = getattr(types, "UnionType", Union)
1316

17+
sequence_types = {
18+
Sequence,
19+
List,
20+
list,
21+
Tuple,
22+
tuple,
23+
Set,
24+
set,
25+
FrozenSet,
26+
frozenset,
27+
Deque,
28+
deque,
29+
}
1430

15-
def _is_list(type_: Type) -> bool:
16-
origin = get_origin(type_)
17-
if origin is list:
18-
return True
19-
if origin is Union:
20-
return any(_is_list(t) for t in get_args(type_))
21-
return False
31+
32+
def _is_sequence(type_: Type) -> bool:
33+
origin = get_origin(type_) or type_
34+
if origin is Union or origin is UnionType:
35+
return any(_is_sequence(t) for t in get_args(type_))
36+
37+
return origin in sequence_types and origin not in (str, bytes)
2238

2339

2440
def convert_query_params(
@@ -38,7 +54,7 @@ def convert_query_params(
3854
key: value
3955
for key, value in query_params.to_dict(flat=False).items()
4056
if key in model.model_fields
41-
and _is_list(model.model_fields[key].annotation)
57+
and _is_sequence(model.model_fields[key].annotation)
4258
},
4359
}
4460
else:

tests/unit/test_core.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
2-
from typing import Any, List, NamedTuple, Optional, Type, Union
2+
import sys
3+
from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union
34

45
import pytest
56
from flask import jsonify
@@ -15,11 +16,15 @@
1516
from ..util import assert_matches
1617

1718

19+
class EmptyModel(BaseModel):
20+
pass
21+
22+
1823
class ValidateParams(NamedTuple):
19-
body_model: Optional[Type[BaseModel]] = None
20-
query_model: Optional[Type[BaseModel]] = None
21-
form_model: Optional[Type[BaseModel]] = None
22-
response_model: Type[BaseModel] = None
24+
body_model: Type[BaseModel] = EmptyModel
25+
query_model: Type[BaseModel] = EmptyModel
26+
form_model: Type[BaseModel] = EmptyModel
27+
response_model: Type[BaseModel] = EmptyModel
2328
on_success_status: int = 200
2429
request_query: ImmutableMultiDict = ImmutableMultiDict({})
2530
request_body: Union[dict, List[dict]] = {}
@@ -50,7 +55,23 @@ class RequestBodyModel(BaseModel):
5055

5156
class FormModel(BaseModel):
5257
f1: int
53-
f2: str = None
58+
f2: Optional[str] = None
59+
60+
61+
class RequestBodyWithIterableModel(BaseModel):
62+
b1: List[str]
63+
b2: Tuple[str, int]
64+
b3: Optional[List[int]] = None
65+
b4: Union[Tuple[str, int], None] = None
66+
67+
68+
if sys.version_info >= (3, 10):
69+
# New Python(>=3.10) syntax tests
70+
class RequestBodyWithIterableModelPy310(BaseModel):
71+
b1: list[str]
72+
b2: tuple[str, int]
73+
b3: list[int] | None = None
74+
b4: tuple[str, int] | None = None
5475

5576

5677
class RequestBodyModelRoot(RootModel):
@@ -195,8 +216,54 @@ class RequestBodyModelRoot(RootModel):
195216
),
196217
id="invalid form param",
197218
),
219+
pytest.param(
220+
ValidateParams(
221+
body_model=RequestBodyWithIterableModel,
222+
request_body={
223+
"b1": ["str1", "str1"],
224+
"b2": ("str", 123),
225+
"b3": [1, 2, 3],
226+
"b4": ("str", 321),
227+
},
228+
expected_response_body={
229+
"b1": ["str1", "str1"],
230+
"b2": ("str", 123),
231+
"b3": [1, 2, 3],
232+
"b4": ("str", 321),
233+
},
234+
response_model=RequestBodyWithIterableModel,
235+
expected_status_code=200,
236+
),
237+
id="ASASD",
238+
),
198239
]
199240

241+
if sys.version_info >= (3, 10):
242+
validate_test_cases.extend(
243+
[
244+
pytest.param(
245+
ValidateParams(
246+
body_model=RequestBodyWithIterableModelPy310,
247+
request_body={
248+
"b1": ["str1", "str1"],
249+
"b2": ("str", 123),
250+
"b3": [1, 2, 3],
251+
"b4": ("str", 321),
252+
},
253+
expected_response_body={
254+
"b1": ["str1", "str1"],
255+
"b2": ("str", 123),
256+
"b3": [1, 2, 3],
257+
"b4": ("str", 321),
258+
},
259+
response_model=RequestBodyWithIterableModelPy310,
260+
expected_status_code=200,
261+
),
262+
id="ASASD",
263+
),
264+
]
265+
)
266+
200267

201268
class TestValidate:
202269
@pytest.mark.parametrize("parameters", validate_test_cases)

tests/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def assert_matches(expected: ExpectedType, actual: ActualType):
2121
assert set(expected.keys()) == set(actual.keys())
2222
for key, value in expected.items():
2323
assert_matches(value, actual[key])
24-
elif isinstance(expected, list):
24+
elif isinstance(expected, (list, tuple)):
2525
assert len(expected) == len(actual)
2626
for a, b in zip(expected, actual):
2727
assert_matches(a, b)

0 commit comments

Comments
 (0)