Skip to content

Commit de0d230

Browse files
authored
fix: support generic dataclass (#525)
1 parent 7d76570 commit de0d230

File tree

3 files changed

+95
-2
lines changed

3 files changed

+95
-2
lines changed

dataclasses_json/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
_get_type_arg_param,
2525
_get_type_args, _is_counter,
2626
_NO_ARGS,
27-
_issubclass_safe, _is_tuple)
27+
_issubclass_safe, _is_tuple,
28+
_is_generic_dataclass)
2829

2930
Json = Union[dict, list, str, int, float, bool, None]
3031

@@ -269,8 +270,9 @@ def _is_supported_generic(type_):
269270
return False
270271
not_str = not _issubclass_safe(type_, str)
271272
is_enum = _issubclass_safe(type_, Enum)
273+
is_generic_dataclass = _is_generic_dataclass(type_)
272274
return (not_str and _is_collection(type_)) or _is_optional(
273-
type_) or is_union_type(type_) or is_enum
275+
type_) or is_union_type(type_) or is_enum or is_generic_dataclass
274276

275277

276278
def _decode_generic(type_, value, infer_missing):
@@ -308,6 +310,9 @@ def _decode_generic(type_, value, infer_missing):
308310
except (TypeError, AttributeError):
309311
pass
310312
res = materialize_type(xs)
313+
elif _is_generic_dataclass(type_):
314+
origin = _get_type_origin(type_)
315+
res = _decode_dataclass(origin, value, infer_missing)
311316
else: # Optional or Union
312317
_args = _get_type_args(type_)
313318
if _args is _NO_ARGS:

dataclasses_json/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
from datetime import datetime, timezone
44
from collections import Counter
5+
from dataclasses import is_dataclass # type: ignore
56
from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple,
67
Union, cast)
78

@@ -164,6 +165,10 @@ def _is_nonstr_collection(type_):
164165
and not _issubclass_safe(type_, str))
165166

166167

168+
def _is_generic_dataclass(type_):
169+
return is_dataclass(_get_type_origin(type_))
170+
171+
167172
def _timestamp_to_dt_aware(timestamp: float):
168173
tz = datetime.now(timezone.utc).astimezone().tzinfo
169174
dt = datetime.fromtimestamp(timestamp, tz=tz)

tests/test_generic_dataclass.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from dataclasses import dataclass
2+
from datetime import datetime
3+
from typing import Generic, TypeVar
4+
5+
import pytest
6+
7+
from dataclasses_json import dataclass_json
8+
9+
S = TypeVar("S")
10+
T = TypeVar("T")
11+
12+
13+
@dataclass_json
14+
@dataclass
15+
class Bar:
16+
value: int
17+
18+
19+
@dataclass_json
20+
@dataclass
21+
class Foo(Generic[T]):
22+
bar: T
23+
24+
25+
@dataclass_json
26+
@dataclass
27+
class Baz(Generic[T]):
28+
foo: Foo[T]
29+
30+
31+
@pytest.mark.parametrize(
32+
"instance_of_t, decodes_successfully",
33+
[
34+
pytest.param(1, True, id="literal"),
35+
pytest.param([1], True, id="literal_list"),
36+
pytest.param({"a": 1}, True, id="map_of_literal"),
37+
pytest.param(datetime(2021, 1, 1), False, id="extended_type"),
38+
pytest.param(Bar(1), False, id="object"),
39+
]
40+
)
41+
def test_dataclass_with_generic_dataclass_field(instance_of_t, decodes_successfully):
42+
foo = Foo(bar=instance_of_t)
43+
baz = Baz(foo=foo)
44+
decoded = Baz[type(instance_of_t)].from_json(baz.to_json())
45+
assert decoded.foo == Foo.from_json(foo.to_json())
46+
if decodes_successfully:
47+
assert decoded == baz
48+
else:
49+
assert decoded != baz
50+
51+
52+
@dataclass_json
53+
@dataclass
54+
class Foo2(Generic[T, S]):
55+
bar1: T
56+
bar2: S
57+
58+
59+
@dataclass_json
60+
@dataclass
61+
class Baz2(Generic[T, S]):
62+
foo2: Foo2[T, S]
63+
64+
65+
@pytest.mark.parametrize(
66+
"instance_of_t, decodes_successfully",
67+
[
68+
pytest.param(1, True, id="literal"),
69+
pytest.param([1], True, id="literal_list"),
70+
pytest.param({"a": 1}, True, id="map_of_literal"),
71+
pytest.param(datetime(2021, 1, 1), False, id="extended_type"),
72+
pytest.param(Bar(1), False, id="object"),
73+
]
74+
)
75+
def test_dataclass_with_multiple_generic_dataclass_fields(instance_of_t, decodes_successfully):
76+
foo2 = Foo2(bar1=instance_of_t, bar2=instance_of_t)
77+
baz = Baz2(foo2=foo2)
78+
decoded = Baz2[type(instance_of_t), type(instance_of_t)].from_json(baz.to_json())
79+
assert decoded.foo2 == Foo2.from_json(foo2.to_json())
80+
if decodes_successfully:
81+
assert decoded == baz
82+
else:
83+
assert decoded != baz

0 commit comments

Comments
 (0)