Skip to content

Commit

Permalink
Subclass types no longer allow class instance to be set as default (l…
Browse files Browse the repository at this point in the history
…ightning#18731).
  • Loading branch information
mauvilsa committed Oct 16, 2023
1 parent 2b20bc8 commit 46325b4
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 39 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ Fixed
- Failures with subcommands and default_config_files when keys are repeated
(`#160 <https://github.com/omni-us/jsonargparse/issues/160>`__).

Changed
^^^^^^^
- Subclass types no longer allow class instance to be set as default
(`lightning#18731
<https://github.com/Lightning-AI/lightning/issues/18731>`__).


v4.25.0 (2023-09-25)
--------------------
Expand Down
5 changes: 2 additions & 3 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,10 +884,9 @@ def set_defaults(self, *args: Dict[str, Any], **kwargs: Any) -> None:
raise KeyError(f'No action for destination key "{dest}" to set its default.')
elif isinstance(action, ActionConfigFile):
ActionConfigFile.set_default_error()
action.default = default
if isinstance(action, ActionTypeHint):
action.normalize_default()
self._defaults[dest] = action.default
default = action.normalize_default(default)
self._defaults[dest] = action.default = default
if kwargs:
self.set_defaults(kwargs)

Expand Down
3 changes: 2 additions & 1 deletion jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ def _add_signature_parameter(
}
if is_dataclass_like_typehint:
kwargs.update(sub_add_kwargs)
action = group.add_argument(*args, **kwargs)
with ActionTypeHint.allow_default_instance_context():
action = group.add_argument(*args, **kwargs)
action.sub_add_kwargs = sub_add_kwargs
if is_subclass_typehint and len(subclass_skip) > 0:
action.sub_add_kwargs["skip"] = subclass_skip
Expand Down
50 changes: 28 additions & 22 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
literal_types.add(__import__("typing").Literal)

subclass_arg_parser: ContextVar = ContextVar("subclass_arg_parser")
allow_default_instance: ContextVar = ContextVar("allow_default_instance", default=False)
sub_defaults: ContextVar = ContextVar("sub_defaults", default=False)


Expand Down Expand Up @@ -187,23 +188,26 @@ def __init__(self, typehint: Optional[Type] = None, enable_path: bool = False, *
kwargs["metavar"] = typehint_metavar(self._typehint)
super().__init__(**kwargs)
self._supports_append = self.supports_append(self._typehint)
self.normalize_default()
self.default = self.normalize_default(self.default)

def normalize_default(self):
default = self.default
def normalize_default(self, default):
is_subclass_type = self.is_subclass_typehint(self._typehint, all_subtypes=False)
if isinstance(default, LazyInitBaseClass):
self.default = default.lazy_get_init_data()
elif (
self.is_subclass_typehint(self._typehint, all_subtypes=False)
and isinstance(default, dict)
and "class_path" in default
):
self.default = subclass_spec_as_namespace(default)
self.default.class_path = normalize_import_path(self.default.class_path, self._typehint)
default = default.lazy_get_init_data()
elif is_subclass_type and isinstance(default, dict) and "class_path" in default:
default = subclass_spec_as_namespace(default)
default.class_path = normalize_import_path(default.class_path, self._typehint)
elif is_enum_type(self._typehint) and isinstance(default, Enum):
self.default = default.name
default = default.name
elif is_callable_type(self._typehint) and callable(default) and not inspect.isclass(default):
self.default = get_import_path(default)
default = get_import_path(default)
elif is_subclass_type and not allow_default_instance.get():
from ._parameter_resolvers import UnknownDefault

default_type = type(default)
if not is_subclass(default_type, UnknownDefault) and self.is_subclass_typehint(default_type):
raise ValueError("Subclass types require as default either a dict with class_path or a lazy instance.")
return default

@staticmethod
def prepare_add_argument(args, kwargs, enable_path, container, logger, sub_add_kwargs=None):
Expand Down Expand Up @@ -356,6 +360,15 @@ def subclass_arg_context(parser):
subclass_arg_parser.set(parser)
yield

@staticmethod
@contextmanager
def allow_default_instance_context():
token = allow_default_instance.set(True)
try:
yield
finally:
allow_default_instance.reset(token)

@staticmethod
@contextmanager
def sub_defaults_context():
Expand Down Expand Up @@ -1200,15 +1213,8 @@ def typehint_metavar(typehint):


def serialize_class_instance(val):
type_val = type(val)
val = str(val)
warning(
f"""
Not possible to serialize an instance of {type_val}. It will be
represented as the string {val}. If this was set as a default, consider
using lazy_instance.
"""
)
val = f"Unable to serialize instance {val}"
warning(val)
return val


Expand Down
48 changes: 35 additions & 13 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_parse_args_stderr,
get_parse_args_stdout,
get_parser_help,
source_unavailable,
)


Expand Down Expand Up @@ -1209,19 +1210,6 @@ def test_add_subclass_lazy_default(parser):
assert "'init_args': {'firstweekday': 5}" in help_str


def test_add_subclass_instance_default(parser):
parser.add_subclass_arguments(Calendar, "cal")
parser.set_defaults({"cal": Calendar(firstweekday=2)})
cfg = parser.parse_args([])
assert isinstance(cfg["cal"], Calendar)
init = parser.instantiate_classes(cfg)
assert init["cal"] is cfg["cal"]
with warnings.catch_warnings(record=True) as w:
dump = parser.dump(cfg)
assert "Not possible to serialize an instance of" in str(w[0].message)
assert "cal: <calendar.Calendar object at " in dump


class TupleBaseA:
def __init__(self, a1: int = 1, a2: float = 2.3):
self.a1 = a1
Expand Down Expand Up @@ -1267,6 +1255,40 @@ def test_add_subclass_not_required_group(parser):
assert init == Namespace()


# instance defaults tests


def test_add_subclass_set_defaults_instance_default(parser):
parser.add_subclass_arguments(Calendar, "cal")
with pytest.raises(ValueError) as ctx:
parser.set_defaults({"cal": Calendar(firstweekday=2)})
ctx.match("Subclass types require as default either a dict with class_path or a lazy instance")


def test_add_argument_subclass_instance_default(parser):
with pytest.raises(ValueError) as ctx:
parser.add_argument("--cal", type=Calendar, default=Calendar(firstweekday=2))
ctx.match("Subclass types require as default either a dict with class_path or a lazy instance")


class InstanceDefault:
def __init__(self, cal: Calendar = Calendar(firstweekday=2)):
pass


def test_subclass_signature_instance_default(parser):
with source_unavailable():
parser.add_class_arguments(InstanceDefault)
cfg = parser.parse_args([])
assert isinstance(cfg["cal"], Calendar)
init = parser.instantiate_classes(cfg)
assert init["cal"] is cfg["cal"]
with warnings.catch_warnings(record=True) as w:
dump = parser.dump(cfg)
assert "Unable to serialize instance" in str(w[0].message)
assert "cal: Unable to serialize instance <calendar.Calendar " in dump


# parameter skip tests


Expand Down

0 comments on commit 46325b4

Please sign in to comment.