Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for non-stdlib dataclass-like types #480

Merged
merged 44 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
2b212a5
Add new tests
nkrishnaswami Mar 22, 2024
863b6ee
Add new tests
nkrishnaswami Mar 22, 2024
1423c82
Fix new testcases
nkrishnaswami Mar 22, 2024
87c8493
Fix Pydantic 2 dataclasses init field detection for `dataclasses.Fiel…
nkrishnaswami Mar 28, 2024
542e73f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
60e65b1
change "should_init_field" to "is_init_field" and make its return val…
nkrishnaswami Mar 28, 2024
5f07506
Fix redundant assingment
nkrishnaswami Mar 28, 2024
0d87871
Remove `print` call
nkrishnaswami Mar 28, 2024
ed07c23
Remove unused variable assignment in test
nkrishnaswami Mar 28, 2024
fcde2c2
Fix CI error
nkrishnaswami Mar 28, 2024
a6441b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
40623b2
Make FieldInfo init field check more robust
nkrishnaswami Mar 28, 2024
af3042c
Make it clear default is assigned before use
nkrishnaswami Mar 28, 2024
8cc32d3
Look through aliases for adding class and method arguments
nkrishnaswami Mar 28, 2024
0d84421
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
721b336
Fix mypy errors
nkrishnaswami Mar 28, 2024
2c5ae11
Compat fixes for all pydantic 2 versions
nkrishnaswami Mar 29, 2024
41c096e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
3b6095b
remove leftover print statement
nkrishnaswami Mar 29, 2024
a2e1ef1
Fix mypy error
nkrishnaswami Mar 29, 2024
208edf5
Rename is_init_field_default to is_init_field_attrs since it's only u…
nkrishnaswami Mar 29, 2024
d961387
Optionalize pydantic Field(init) test
Mar 31, 2024
9adf98b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2024
f01a024
Fix optional support check when pydantic not installed
Mar 31, 2024
e82fbf9
Add pydantic dataclass tests with dataclasses.field init to restore c…
Apr 1, 2024
7bce4fa
Fix CodeQL warning
Apr 1, 2024
8c75405
Remove init=False from new testcases
Apr 1, 2024
c161dbe
Merge branch 'main' into fixes-for-dataclass-likes
mauvilsa Apr 5, 2024
1ccbb81
Move CHANGELOG entry to top
nkrishnaswami Apr 10, 2024
d876396
Merge branch 'main' into fixes-for-dataclass-likes
mauvilsa Apr 11, 2024
a722b76
Address review comments
nkrishnaswami Apr 11, 2024
519be54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
46e029f
fix test comment
nkrishnaswami Apr 11, 2024
37cc99a
Fix CI failure
nkrishnaswami Apr 11, 2024
edc8ed9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
4167311
Fix CI failures
nkrishnaswami Apr 11, 2024
0eb9d30
Merge branch 'main' into fixes-for-dataclass-likes
mauvilsa Apr 12, 2024
958b4ee
Merge branch 'main' into fixes-for-dataclass-likes
nkrishnaswami Apr 15, 2024
7c78d6b
Fix pydantic version error
nkrishnaswami Apr 15, 2024
7332328
Actually fix pydantic version error
nkrishnaswami Apr 15, 2024
773a2c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
326b90d
Remove debug logging
nkrishnaswami Apr 15, 2024
03ebbb2
Fix 3.12 failure without future annotations import
Apr 16, 2024
7ab0c80
Merge branch 'main' into fixes-for-dataclass-likes
nkrishnaswami Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ The semantic versioning only considers the public API as described in
:ref:`api-ref`. Components not mentioned in :ref:`api-ref` or different import
paths are considered internals and can change in minor and patch releases.

v4.28.1 (2024-0?-??)
--------------------

Fixed
^^^^^
- Failure to process ``Annotated`` dataclass members, and inclusion of
non-init fields in `attrs` and Pydantic dataclass-like instantiation.
mauvilsa marked this conversation as resolved.
Show resolved Hide resolved

v4.28.0 (2024-03-??)
--------------------

Expand Down
21 changes: 20 additions & 1 deletion jsonargparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
)

from ._namespace import Namespace
from ._optionals import import_reconplogger, reconplogger_support
from ._optionals import (
get_alias_target,
import_reconplogger,
is_alias_type,
is_annotated,
reconplogger_support,
)
from ._type_checking import ArgumentParser

__all__ = [
Expand Down Expand Up @@ -99,6 +105,19 @@ def get_generic_origin(cls):
return cls.__origin__ if is_generic_class(cls) else cls


def get_unaliased_type(cls):
new_cls = cls
while True:
cur_cls = new_cls
if is_annotated(new_cls):
new_cls = new_cls.__origin__
if is_alias_type(new_cls):
new_cls = get_alias_target(new_cls)
if new_cls == cur_cls:
break
return cur_cls


def is_dataclass_like(cls) -> bool:
if is_generic_class(cls):
return is_dataclass_like(cls.__origin__)
Expand Down
31 changes: 30 additions & 1 deletion jsonargparse/_optionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

def typing_extensions_import(name):
if typing_extensions_support:
return getattr(__import__("typing_extensions"), name)
return getattr(__import__("typing_extensions"), name, False)
mauvilsa marked this conversation as resolved.
Show resolved Hide resolved
else:
return getattr(__import__("typing"), name, False)

Expand Down Expand Up @@ -313,6 +313,17 @@
return annotated_alias and isinstance(typehint, annotated_alias)


type_alias_type = typing_extensions_import("TypeAliasType")


def is_alias_type(typehint: type) -> bool:
return type_alias_type and isinstance(typehint, type_alias_type)


def get_alias_target(typehint: type) -> bool:
return typehint.__value__ # type: ignore[attr-defined]


def get_pydantic_support() -> int:
support = "0"
if find_spec("pydantic"):
Expand All @@ -331,6 +342,24 @@
pydantic_support = get_pydantic_support()


def get_pydantic_supports_field_init() -> bool:
if find_spec("pydantic"):
try:
from importlib.metadata import version

support = version("pydantic")
except ImportError:
import pydantic

support = pydantic.version.VERSION
major, minor = tuple(int(x) for x in support.split(".")[:2])
return major > 2 or (major == 2 and minor >= 4)
return False


pydantic_supports_field_init = get_pydantic_supports_field_init()
Dismissed Show dismissed Hide dismissed


def is_pydantic_model(class_type) -> int:
classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else []
for cls in classes:
Expand Down
71 changes: 59 additions & 12 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from ._common import LoggerProperty, get_generic_origin, is_dataclass_like, is_generic_class, is_subclass, parse_logger
from ._common import (
LoggerProperty,
get_generic_origin,
get_unaliased_type,
is_dataclass_like,
is_generic_class,
is_subclass,
parse_logger,
)
from ._optionals import is_pydantic_model, parse_docs
from ._postponed_annotations import evaluate_postponed_annotations
from ._stubs_resolver import get_stub_types
Expand Down Expand Up @@ -861,9 +869,26 @@ def get_field_data_pydantic1_model(field, name, doc_params):


def get_field_data_pydantic2_dataclass(field, name, doc_params):
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

default = inspect._empty
if isinstance(field.default, FieldInfo):
# Pydantic 2 dataclasses stuff their FieldInfo into a
# stdlib dataclasses.field's `default`; this is where the
# actual default and default_factory live.
if field.default.default is not PydanticUndefined:
default = field.default.default
elif field.default.default_factory is not PydanticUndefined:
default = field.default.default_factory()
elif field.default is not dataclasses.MISSING:
default = field.default
elif field.default_factory is not dataclasses.MISSING:
default = field.default_factory()

return dict(
annotation=field.type,
default=field.default,
default=default,
Fixed Show fixed Hide fixed
doc=doc_params.get(name),
)

Expand Down Expand Up @@ -898,6 +923,23 @@ def get_field_data_attrs(field, name, doc_params):
)


def is_init_field_pydantic_model(field) -> bool:
return True
mauvilsa marked this conversation as resolved.
Show resolved Hide resolved


def is_init_field_pydantic2_dataclass(field) -> bool:
from pydantic.fields import FieldInfo

if isinstance(field.default, FieldInfo):
# FieldInfo.init is new in pydantic 2.6
return getattr(field.default, "init", None) is not False
return field.init is not False


def is_init_field_attrs(field) -> bool:
return field.init is not False


def get_parameters_from_pydantic_or_attrs(
function_or_class: Union[Callable, Type],
method_or_property: Optional[str],
Expand All @@ -907,42 +949,47 @@ def get_parameters_from_pydantic_or_attrs(

if method_or_property or not (pydantic_support or attrs_support):
return None

function_or_class = get_unaliased_type(function_or_class)
mauvilsa marked this conversation as resolved.
Show resolved Hide resolved
fields_iterator = get_field_data = None
if pydantic_support:
pydantic_model = is_pydantic_model(function_or_class)
if pydantic_model == 1:
fields_iterator = function_or_class.__fields__.items() # type: ignore[union-attr]
fields_iterator = function_or_class.__fields__.items()
get_field_data = get_field_data_pydantic1_model
is_init_field = is_init_field_pydantic_model
elif pydantic_model > 1:
fields_iterator = function_or_class.model_fields.items() # type: ignore[union-attr]
fields_iterator = function_or_class.model_fields.items()
get_field_data = get_field_data_pydantic2_model
is_init_field = is_init_field_pydantic_model
elif dataclasses.is_dataclass(function_or_class) and hasattr(function_or_class, "__pydantic_fields__"):
fields_iterator = dataclasses.fields(function_or_class)
fields_iterator = {v.name: v for v in fields_iterator}.items()
get_field_data = get_field_data_pydantic2_dataclass
is_init_field = is_init_field_pydantic2_dataclass

if not fields_iterator and attrs_support:
import attrs

if attrs.has(function_or_class):
fields_iterator = {f.name: f for f in attrs.fields(function_or_class)}.items()
get_field_data = get_field_data_attrs
is_init_field = is_init_field_attrs

if not fields_iterator or not get_field_data:
return None

params = []
doc_params = parse_docs(function_or_class, None, logger)
for name, field in fields_iterator:
params.append(
ParamData(
name=name,
kind=kinds.KEYWORD_ONLY,
component=function_or_class,
**get_field_data(field, name, doc_params),
if is_init_field(field):
params.append(
ParamData(
name=name,
kind=kinds.KEYWORD_ONLY,
component=function_or_class,
**get_field_data(field, name, doc_params),
)
)
)
evaluate_postponed_annotations(params, function_or_class, None, logger)
return params

Expand Down
18 changes: 13 additions & 5 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union

from ._actions import _ActionConfigLoad
from ._common import LoggerProperty, get_class_instantiator, get_generic_origin, is_dataclass_like, is_subclass
from ._common import (
LoggerProperty,
get_class_instantiator,
get_generic_origin,
get_unaliased_type,
is_dataclass_like,
is_subclass,
)
from ._optionals import get_doc_short_description, is_pydantic_model, pydantic_support
from ._parameter_resolvers import (
ParamData,
Expand Down Expand Up @@ -72,7 +79,7 @@ def add_class_arguments(
ValueError: When not given a class.
ValueError: When there are required parameters without at least one valid type.
"""
if not inspect.isclass(get_generic_origin(theclass)):
if not inspect.isclass(get_generic_origin(get_unaliased_type(theclass))):
raise ValueError(f'Expected "theclass" parameter to be a class type, got: {theclass}.')
if default and not (isinstance(default, LazyInitBaseClass) and isinstance(default, theclass)):
raise ValueError(f'Expected "default" parameter to be a lazy instance of the class, got: {default}.')
Expand Down Expand Up @@ -133,9 +140,10 @@ def add_method_arguments(
ValueError: When not given a class or the name of a method of the class.
ValueError: When there are required parameters without at least one valid type.
"""
if not inspect.isclass(get_generic_origin(theclass)):
unaliased_type = get_unaliased_type(theclass)
if not inspect.isclass(get_generic_origin(unaliased_type)):
raise ValueError('Expected "theclass" argument to be a class object.')
if not hasattr(theclass, themethod) or not callable(getattr(theclass, themethod)):
if not hasattr(unaliased_type, themethod) or not callable(getattr(unaliased_type, themethod)):
raise ValueError('Expected "themethod" argument to be a callable member of the class.')

return self._add_signature_arguments(
Expand Down Expand Up @@ -440,7 +448,7 @@ def add_dataclass_arguments(
if isinstance(default, dict):
with suppress(TypeError):
default = theclass(**default)
if not isinstance(default, theclass):
if not isinstance(default, get_unaliased_type(theclass)):
raise ValueError(
f'Expected "default" argument to be an instance of "{theclass.__name__}" '
f"or its kwargs dict, given {default}"
Expand Down
16 changes: 13 additions & 3 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from ._common import (
get_class_instantiator,
get_unaliased_type,
is_dataclass_like,
is_subclass,
nested_links,
Expand All @@ -58,7 +59,9 @@
from ._namespace import Namespace
from ._optionals import (
argcomplete_warn_redraw_prompt,
get_alias_target,
get_files_completer,
is_alias_type,
is_annotated,
is_annotated_validator,
typing_extensions_import,
Expand Down Expand Up @@ -234,8 +237,8 @@ def prepare_add_argument(args, kwargs, enable_path, container, logger, sub_add_k
@staticmethod
def is_supported_typehint(typehint, full=False):
"""Whether the given type hint is supported."""
if is_annotated(typehint):
typehint = get_typehint_origin(typehint)
typehint = get_unaliased_type(typehint)

supported = (
typehint in root_types
or get_typehint_origin(typehint) in root_types
Expand Down Expand Up @@ -269,6 +272,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
typehint = typehint_from_action(typehint)
if typehint is None:
return False
typehint = get_unaliased_type(typehint)
typehint_origin = get_typehint_origin(typehint)
if typehint_origin == Union or (also_lists and typehint_origin in sequence_origin_types):
subtypes = [a for a in typehint.__args__ if a != NoneType]
Expand All @@ -287,7 +291,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):

@staticmethod
def is_return_subclass_typehint(typehint):
typehint = get_optional_arg(typehint)
typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint)))
typehint_origin = get_typehint_origin(typehint)
if typehint_origin in callable_origin_types:
return_type = get_callable_return_type(typehint)
Expand All @@ -297,6 +301,7 @@ def is_return_subclass_typehint(typehint):

@staticmethod
def is_mapping_typehint(typehint):
typehint = get_unaliased_type(typehint)
typehint_origin = get_typehint_origin(typehint) or typehint
if (
typehint in mapping_origin_types
Expand All @@ -308,6 +313,7 @@ def is_mapping_typehint(typehint):

@staticmethod
def is_callable_typehint(typehint, all_subtypes=True):
typehint = get_unaliased_type(typehint)
typehint_origin = get_typehint_origin(typehint)
if typehint_origin == Union:
subtypes = [a for a in typehint.__args__ if a != NoneType]
Expand Down Expand Up @@ -908,6 +914,10 @@ def adapt_typehints(
error = indent_text(str(ex))
raise_unexpected_value(f"Problem with given class_path {class_path!r}:\n{error}", exception=ex)

# TypeAliasType -- 3.12 `type x = y` or manually via typing_extensions
elif is_alias_type(typehint):
return adapt_typehints(val, get_alias_target(typehint), **adapt_kwargs)

return val


Expand Down
Loading
Loading