Skip to content

Commit

Permalink
- Added support for dataclasses nested in a type.
Browse files Browse the repository at this point in the history
- Fixed add_dataclass_arguments not forwarding sub_configs parameter.
  • Loading branch information
mauvilsa committed Apr 14, 2023
1 parent 17fcf0a commit 927ec02
Show file tree
Hide file tree
Showing 13 changed files with 167 additions and 81 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ v4.21.0 (2023-04-??)

Added
^^^^^
- Support for dataclasses nested in a type.
- Support for pydantic models and attr defines similar to dataclasses.

Fixed
Expand All @@ -27,6 +28,7 @@ Fixed
<https://github.com/Lightning-AI/lightning/issues/17254>`__).
- ``dataclass`` from pydantic not working (`#100 (comment)
<https://github.com/omni-us/jsonargparse/issues/100#issuecomment-1408413796>`__).
- ``add_dataclass_arguments`` not forwarding ``sub_configs`` parameter.

Changed
^^^^^^^
Expand Down
13 changes: 6 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -430,19 +430,18 @@ Some notes about this support are:
config files and environment variables, tuples and sets are represented as an
array.

- ``dataclasses`` are supported as a type but only for pure data classes and not
nested in a type. By pure it is meant that the class only inherits from data
classes. Not a mixture of normal classes and data classes. Data classes as
fields of other data classes is supported. Pydantic's ``dataclass`` decorator
and ``BaseModel`` classes, and attrs' ``define`` decorator are supported
like standard dataclasses. Though, this support is currently experimental.

- To set a value to ``None`` it is required to use ``null`` since this is how
json/yaml defines it. To avoid confusion in the help, ``NoneType`` is
displayed as ``null``. For example a function argument with type and default
``Optional[str] = None`` would be shown in the help as ``type: Union[str,
null], default: null``.

- ``dataclasses`` are supported even when nested. Final classes, attrs'
``define`` decorator, and pydantic's ``dataclass`` decorator and ``BaseModel``
classes are supported and behave like standard dataclasses. If a dataclass
inherits from a normal class, the type is considered a subclass instead of a
dataclass, see :ref:`sub-classes`.

- Normal classes can be used as a type, which are specified with a dict
containing ``class_path`` and optionally ``init_args``.
:py:meth:`.ArgumentParser.instantiate_classes` can be used to instantiate all
Expand Down
34 changes: 34 additions & 0 deletions jsonargparse/_common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import dataclasses
import inspect
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional, Union

from .namespace import Namespace
from .optionals import attrs_support, import_attrs, import_pydantic, pydantic_support
from .type_checking import ArgumentParser

parent_parser: ContextVar['ArgumentParser'] = ContextVar('parent_parser')
Expand Down Expand Up @@ -33,3 +36,34 @@ def parser_context(**kwargs):
finally:
for context_var, token in context_var_tokens:
context_var.reset(token)


def is_subclass(cls, class_or_tuple) -> bool:
"""Extension of issubclass that supports non-class arguments."""
try:
return inspect.isclass(cls) and issubclass(cls, class_or_tuple)
except TypeError:
return False


def is_final_class(cls) -> bool:
"""Checks whether a class is final, i.e. decorated with ``typing.final``."""
return getattr(cls, '__final__', False)


def is_dataclass_like(cls) -> bool:
if not inspect.isclass(cls):
return False
if is_final_class(cls):
return True
classes = [c for c in inspect.getmro(cls) if c != object]
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
if not all_dataclasses and pydantic_support:
pydantic = import_pydantic('is_dataclass_like')
classes = [c for c in classes if c != pydantic.utils.Representation]
all_dataclasses = all(is_subclass(c, pydantic.BaseModel) for c in classes)
if not all_dataclasses and attrs_support:
attrs = import_attrs('is_dataclass_like')
if attrs.has(cls):
return True
return all_dataclasses
3 changes: 1 addition & 2 deletions jsonargparse/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from contextvars import ContextVar
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from ._common import parser_context
from ._common import is_subclass, parser_context
from .loaders_dumpers import get_loader_exceptions, load_value
from .namespace import Namespace, split_key, split_key_root
from .optionals import FilesCompleterMethod, get_config_read_mode
Expand All @@ -25,7 +25,6 @@
get_typehint_origin,
import_object,
indent_text,
is_subclass,
iter_to_set_str,
parse_value_or_config,
)
Expand Down
16 changes: 5 additions & 11 deletions jsonargparse/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Union,
)

from ._common import lenient_check, parser_context
from ._common import is_dataclass_like, is_subclass, lenient_check, parser_context
from .actions import (
ActionConfigFile,
ActionParser,
Expand Down Expand Up @@ -71,16 +71,14 @@
import_jsonnet,
)
from .parameter_resolvers import UnknownDefault
from .signatures import SignatureArguments, is_pure_dataclass
from .signatures import SignatureArguments
from .typehints import ActionTypeHint, is_subclass_spec
from .typing import is_final_class
from .util import (
Path,
argument_error,
change_to_path_dir,
get_private_kwargs,
identity,
is_subclass,
return_parser_if_captured,
)

Expand Down Expand Up @@ -118,14 +116,10 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs):
if is_subclass(kwargs['action'], ActionConfigFile) and any(isinstance(a, ActionConfigFile) for a in self._actions):
raise ValueError('A parser is only allowed to have a single ActionConfigFile argument.')
if 'type' in kwargs:
if is_final_class(kwargs['type']) or is_pure_dataclass(kwargs['type']):
if is_dataclass_like(kwargs['type']):
theclass = kwargs.pop('type')
nested_key = re.sub('^--', '', args[0])
if is_final_class(theclass):
kwargs.pop('help', None)
self.add_class_arguments(theclass, nested_key, **kwargs)
else:
self.add_dataclass_arguments(theclass, nested_key, **kwargs)
self.add_dataclass_arguments(theclass, nested_key, **kwargs)
return _find_action(parser, nested_key)
if ActionTypeHint.is_supported_typehint(kwargs['type']):
args = ActionTypeHint.prepare_add_argument(
Expand Down Expand Up @@ -1109,7 +1103,7 @@ def instantiate_classes(
components: List[Union[ActionTypeHint, _ActionConfigLoad, _ArgumentGroup]] = []
for action in filter_default_actions(self._actions):
if isinstance(action, ActionTypeHint) or \
(isinstance(action, _ActionConfigLoad) and is_pure_dataclass(action.basetype)):
(isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype)):
components.append(action)

if instantiate_groups:
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class ActionEnum:

def __init__(self, **kwargs):
if 'enum' in kwargs:
from .util import is_subclass
from ._common import is_subclass
if not is_subclass(kwargs['enum'], Enum):
raise ValueError('Expected enum to be an subclass of Enum.')
self._type = kwargs['enum']
Expand Down
5 changes: 2 additions & 3 deletions jsonargparse/parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from ._common import is_dataclass_like, is_subclass
from ._stubs_resolver import get_stub_types
from .optionals import parse_docs
from .util import (
ClassFromFunctionBase,
LoggerProperty,
get_import_path,
is_subclass,
iter_to_set_str,
parse_logger,
unique,
Expand Down Expand Up @@ -325,8 +325,7 @@ def get_kwargs_pop_or_get_parameter(node, component, parent, doc_params, log_deb


def is_param_subclass_instance_default(param: ParamData) -> bool:
from .signatures import is_pure_dataclass
if is_pure_dataclass(type(param.default)):
if is_dataclass_like(type(param.default)):
return False
from .typehints import ActionTypeHint, get_subclass_types
class_types = get_subclass_types(param.annotation)
Expand Down
54 changes: 17 additions & 37 deletions jsonargparse/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,16 @@
from contextlib import suppress
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union

from ._common import is_dataclass_like, is_subclass
from .actions import _ActionConfigLoad
from .optionals import (
attrs_support,
get_doc_short_description,
import_attrs,
import_pydantic,
pydantic_support,
)
from .optionals import get_doc_short_description, import_pydantic, pydantic_support
from .parameter_resolvers import (
ParamData,
get_parameter_origins,
get_signature_parameters,
)
from .typehints import ActionTypeHint, LazyInitBaseClass, is_optional
from .typing import is_final_class
from .util import LoggerProperty, get_import_path, is_subclass, iter_to_set_str
from .util import LoggerProperty, get_import_path, iter_to_set_str

__all__ = [
'compose_dataclasses',
Expand Down Expand Up @@ -93,7 +87,7 @@ def add_class_arguments(
if default:
skip = skip or set()
prefix = nested_key+'.' if nested_key else ''
defaults = default.lazy_get_init_data().as_dict()
defaults = default.lazy_get_init_args()
if defaults:
defaults = {prefix+k: v for k, v in defaults.items() if k not in skip}
self.set_defaults(**defaults) # type: ignore
Expand Down Expand Up @@ -317,8 +311,7 @@ def _add_signature_parameter(
elif not as_positional:
kwargs['required'] = True
is_subclass_typehint = False
is_final_class_typehint = is_final_class(annotation)
is_pure_dataclass_typehint = is_pure_dataclass(annotation)
is_dataclass_like_typehint = is_dataclass_like(annotation)
dest = (nested_key+'.' if nested_key else '') + name
args = [dest if is_required and as_positional else '--'+dest]
if param.origin:
Expand All @@ -332,8 +325,7 @@ def _add_signature_parameter(
)
if annotation in {str, int, float, bool} or \
is_subclass(annotation, (str, int, float)) or \
is_final_class_typehint or \
is_pure_dataclass_typehint:
is_dataclass_like_typehint:
kwargs['type'] = annotation
elif annotation != inspect_empty:
try:
Expand All @@ -360,7 +352,7 @@ def _add_signature_parameter(
'sub_configs': sub_configs,
'instantiate': instantiate,
}
if is_final_class_typehint or is_pure_dataclass_typehint:
if is_dataclass_like_typehint:
kwargs.update(sub_add_kwargs)
action = group.add_argument(*args, **kwargs)
action.sub_add_kwargs = sub_add_kwargs
Expand Down Expand Up @@ -401,8 +393,8 @@ def add_dataclass_arguments(
ValueError: When not given a dataclass.
ValueError: When default is not instance of or kwargs for theclass.
"""
if not is_pure_dataclass(theclass):
raise ValueError(f'Expected "theclass" argument to be a pure dataclass, given {theclass}')
if not is_dataclass_like(theclass):
raise ValueError(f'Expected "theclass" argument to be a dataclass-like, given {theclass}')

doc_group = get_doc_short_description(theclass, logger=self.logger)
for key in ['help', 'title']:
Expand All @@ -420,6 +412,7 @@ def add_dataclass_arguments(
defaults = dataclass_to_dict(default)

added_args: List[str] = []
param_kwargs = {k: v for k, v in kwargs.items() if k == 'sub_configs'}
for param in get_signature_parameters(theclass, None, logger=self.logger):
self._add_signature_parameter(
group,
Expand All @@ -428,6 +421,7 @@ def add_dataclass_arguments(
added_args,
fail_untyped=fail_untyped,
default=defaults.get(param.name, inspect_empty),
**param_kwargs,
)

return added_args
Expand Down Expand Up @@ -467,8 +461,8 @@ def add_subclass_arguments(
Raises:
ValueError: When given an invalid base class.
"""
if is_final_class(baseclass):
raise ValueError("Not allowed for classes that are final.")
if is_dataclass_like(baseclass):
raise ValueError("Not allowed for dataclass-like classes.")
if type(baseclass) is not tuple:
baseclass = (baseclass,) # type: ignore
if not all(inspect.isclass(c) for c in baseclass):
Expand Down Expand Up @@ -550,32 +544,18 @@ def is_factory_class(value):
return value.__class__ == dataclasses._HAS_DEFAULT_FACTORY_CLASS


def is_pure_dataclass(value):
if not inspect.isclass(value):
return False
classes = [c for c in inspect.getmro(value) if c != object]
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
if not all_dataclasses and pydantic_support:
pydantic = import_pydantic('is_pure_dataclass')
classes = [c for c in classes if c != pydantic.utils.Representation]
all_dataclasses = all(is_subclass(c, pydantic.BaseModel) for c in classes)
if not all_dataclasses and attrs_support:
attrs = import_attrs('is_pure_dataclass')
if attrs.has(value):
return True
return all_dataclasses


def dataclass_to_dict(value):
def dataclass_to_dict(value) -> dict:
if pydantic_support:
pydantic = import_pydantic('dataclass_to_dict')
if isinstance(value, pydantic.BaseModel):
return value.dict()
if isinstance(value, LazyInitBaseClass):
return value.lazy_get_init_data().as_dict()
return dataclasses.asdict(value)


def compose_dataclasses(*args):
"""Returns a pure dataclass inheriting all given dataclasses and properly handling __post_init__."""
"""Returns a dataclass inheriting all given dataclasses and properly handling __post_init__."""

@dataclasses.dataclass
class ComposedDataclass(*args):
Expand Down
Loading

0 comments on commit 927ec02

Please sign in to comment.