diff --git a/dataclasses_json/core.py b/dataclasses_json/core.py index 7da0d150..efb46ed3 100644 --- a/dataclasses_json/core.py +++ b/dataclasses_json/core.py @@ -17,7 +17,7 @@ Tuple, TypeVar, Type) from uuid import UUID -from typing_inspect import is_union_type # type: ignore +from typing_inspect import is_union_type, is_literal_type # type: ignore from dataclasses_json import cfg from dataclasses_json.utils import (_get_type_cons, _get_type_origin, @@ -358,7 +358,8 @@ def _decode_dict_keys(key_type, xs, infer_missing): # This is a special case for Python 3.7 and Python 3.8. # By some reason, "unbound" dicts are counted # as having key type parameter to be TypeVar('KT') - if key_type is None or key_type == Any or isinstance(key_type, TypeVar): + # Literal types are also passed through without any decoding. + if key_type is None or key_type == Any or isinstance(key_type, TypeVar) or is_literal_type(key_type): decode_function = key_type = (lambda x: x) # handle a nested python dict that has tuples for keys. E.g. for # Dict[Tuple[int], int], key_type will be typing.Tuple[int], but diff --git a/dataclasses_json/mm.py b/dataclasses_json/mm.py index 9cfacf1d..d7c0f28c 100644 --- a/dataclasses_json/mm.py +++ b/dataclasses_json/mm.py @@ -11,14 +11,14 @@ from uuid import UUID from enum import Enum -from typing_inspect import is_union_type # type: ignore +from typing_inspect import is_union_type, is_literal_type # type: ignore from marshmallow import fields, Schema, post_load # type: ignore from marshmallow.exceptions import ValidationError # type: ignore from dataclasses_json.core import (_is_supported_generic, _decode_dataclass, _ExtendedEncoder, _user_overrides_or_exts) -from dataclasses_json.utils import (_is_collection, _is_optional, +from dataclasses_json.utils import (_get_type_args, _is_collection, _is_optional, _issubclass_safe, _timestamp_to_dt_aware, _is_new_type, _get_type_origin, _handle_undefined_parameters_safe, @@ -130,6 +130,46 @@ def _deserialize(self, value, attr, data, **kwargs): return None if optional_list is None else tuple(optional_list) +class _LiteralField(fields.Field): + def __init__(self, literal_values, cls, field, *args, **kwargs): + """Create a new Literal field. + + Literals allow you to specify the set of valid _values_ for a field. The field + implementation validates against these values on deserialization. + + Example: + >>> @dataclass + ... class DataClassWithLiteral(DataClassJsonMixin): + ... read_mode: Literal["r", "w", "a"] + + Args: + literal_values: A sequence of possible values for the field. + cls: The dataclass that the field belongs to. + field: The field that the schema describes. + """ + self.literal_values = literal_values + self.cls = cls + self.field = field + super().__init__(*args, **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + if self.allow_none and value is None: + return None + if value not in self.literal_values: + warnings.warn( + f'The value "{value}" is not one of the values of typing.Literal ' + f'(dataclass: {self.cls.__name__}, field: {self.field.name}). ' + f'Value will not be de-serialized properly.') + return super()._serialize(value, attr, obj, **kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + if value not in self.literal_values: + raise ValidationError( + f'Value "{value}" is not one in typing.Literal{self.literal_values} ' + f'(dataclass: {self.cls.__name__}, field: {self.field.name}).') + return super()._deserialize(value, attr, data, **kwargs) + + TYPES = { typing.Mapping: fields.Mapping, typing.MutableMapping: fields.Mapping, @@ -259,9 +299,14 @@ def inner(type_, options): f"`dataclass_json` decorator or mixin.") return fields.Field(**options) - origin = getattr(type_, '__origin__', type_) - args = [inner(a, {}) for a in getattr(type_, '__args__', []) if - a is not type(None)] + origin = _get_type_origin(type_) + + # Type arguments are typically types (e.g. int in list[int]) except for Literal + # types, where they are values. + if is_literal_type(type_): + args = [] + else: + args = [inner(a, {}) for a in _get_type_args(type_) if a is not type(None)] if type_ == Ellipsis: return type_ @@ -279,6 +324,10 @@ def inner(type_, options): if _issubclass_safe(origin, Enum): return fields.Enum(enum=origin, by_value=True, *args, **options) + if is_literal_type(type_): + literal_values = _get_type_args(type_) + return _LiteralField(literal_values, cls, field, **options) + if is_union_type(type_): union_types = [a for a in getattr(type_, '__args__', []) if a is not type(None)] diff --git a/tests/test_literals.py b/tests/test_literals.py new file mode 100644 index 00000000..cd47bdbb --- /dev/null +++ b/tests/test_literals.py @@ -0,0 +1,99 @@ +"""Test dataclasses_json handling of Literal types.""" +import sys +import pytest + +if sys.version_info < (3, 8): + pytest.skip("Literal types are only supported in Python 3.8+", allow_module_level=True) + +import json +from typing import Literal, Optional, List, Dict + +from dataclasses import dataclass + +from dataclasses_json import dataclass_json, DataClassJsonMixin +from marshmallow.exceptions import ValidationError # type: ignore + + +@dataclass_json +@dataclass +class DataClassWithLiteral(DataClassJsonMixin): + numeric_literals: Literal[0, 1] + string_literals: Literal["one", "two", "three"] + mixed_literals: Literal[0, "one", 2] + + +with_valid_literal_json = '{"numeric_literals": 0, "string_literals": "one", "mixed_literals": 2}' +with_valid_literal_data = DataClassWithLiteral(numeric_literals=0, string_literals="one", mixed_literals=2) +with_invalid_literal_json = '{"numeric_literals": 9, "string_literals": "four", "mixed_literals": []}' +with_invalid_literal_data = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[]) # type: ignore + +@dataclass_json +@dataclass +class DataClassWithNestedLiteral(DataClassJsonMixin): + list_of_literals: List[Literal[0, 1]] + dict_of_literals: Dict[Literal["one", "two", "three"], Literal[0, 1]] + optional_literal: Optional[Literal[0, 1]] + +with_valid_nested_literal_json = '{"list_of_literals": [0, 1], "dict_of_literals": {"one": 0, "two": 1}, "optional_literal": 1}' +with_valid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 1], dict_of_literals={"one": 0, "two": 1}, optional_literal=1) +with_invalid_nested_literal_json = '{"list_of_literals": [0, 2], "dict_of_literals": {"one": 0, "four": 2}, "optional_literal": 2}' +with_invalid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 2], dict_of_literals={"one": 0, "four": 2}, optional_literal=2) # type: ignore + +class TestEncoder: + def test_valid_literal(self): + assert with_valid_literal_data.to_dict(encode_json=True) == json.loads(with_valid_literal_json) + + def test_invalid_literal(self): + assert with_invalid_literal_data.to_dict(encode_json=True) == json.loads(with_invalid_literal_json) + + def test_valid_nested_literal(self): + assert with_valid_nested_literal_data.to_dict(encode_json=True) == json.loads(with_valid_nested_literal_json) + + def test_invalid_nested_literal(self): + assert with_invalid_nested_literal_data.to_dict(encode_json=True) == json.loads(with_invalid_nested_literal_json) + + +class TestSchemaEncoder: + def test_valid_literal(self): + actual = DataClassWithLiteral.schema().dumps(with_valid_literal_data) + assert json.loads(actual) == json.loads(with_valid_literal_json) + + def test_invalid_literal(self): + actual = DataClassWithLiteral.schema().dumps(with_invalid_literal_data) + assert json.loads(actual) == json.loads(with_invalid_literal_json) + + def test_valid_nested_literal(self): + actual = DataClassWithNestedLiteral.schema().dumps(with_valid_nested_literal_data) + assert json.loads(actual) == json.loads(with_valid_nested_literal_json) + + def test_invalid_nested_literal(self): + actual = DataClassWithNestedLiteral.schema().dumps(with_invalid_nested_literal_data) + assert json.loads(actual) == json.loads(with_invalid_nested_literal_json) + +class TestDecoder: + def test_valid_literal(self): + actual = DataClassWithLiteral.from_json(with_valid_literal_json) + assert actual == with_valid_literal_data + + def test_invalid_literal(self): + expected = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[]) # type: ignore + actual = DataClassWithLiteral.from_json(with_invalid_literal_json) + assert actual == expected + + +class TestSchemaDecoder: + def test_valid_literal(self): + actual = DataClassWithLiteral.schema().loads(with_valid_literal_json) + assert actual == with_valid_literal_data + + def test_invalid_literal(self): + with pytest.raises(ValidationError): + DataClassWithLiteral.schema().loads(with_invalid_literal_json) + + def test_valid_nested_literal(self): + actual = DataClassWithNestedLiteral.schema().loads(with_valid_nested_literal_json) + assert actual == with_valid_nested_literal_data + + def test_invalid_nested_literal(self): + with pytest.raises(ValidationError): + DataClassWithNestedLiteral.schema().loads(with_invalid_nested_literal_json)