Skip to content

Commit

Permalink
Fix unable link two deep level arguments sharing the same root class (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa committed Oct 25, 2023
1 parent b571fc1 commit 7aa5e5c
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Fixed
- Failure to parse subclass added via add_argument and required arg as link
target.
- ``choices`` working incorrectly when ``nargs`` is ``+``, ``*`` or number.
- Unable link two deep level arguments sharing the same root class (`#297
<https://github.com/omni-us/jsonargparse/issues/297>`__).


v4.26.1 (2023-10-23)
Expand Down
4 changes: 3 additions & 1 deletion jsonargparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Dict, Generic, Optional, Tuple, Type, TypeVar, Union, _GenericAlias # type: ignore
from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, _GenericAlias # type: ignore

from ._namespace import Namespace
from ._type_checking import ArgumentParser
Expand Down Expand Up @@ -31,6 +31,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False)
load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None)
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators")
nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[])


parser_context_vars = dict(
Expand All @@ -40,6 +41,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
lenient_check=lenient_check,
load_value_mode=load_value_mode,
class_instantiators=class_instantiators,
nested_links=nested_links,
)


Expand Down
8 changes: 6 additions & 2 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ def get_defaults(self, skip_check: bool = False) -> Namespace:
):
cfg[action.dest] = recreate_branches(action.default)

self._logger.debug("Loaded default values from parser")
self._logger.debug("Loaded parser defaults: %s", cfg)

default_config_files = self._get_default_config_files()
for key, default_config_file in default_config_files:
Expand Down Expand Up @@ -1165,7 +1165,11 @@ def instantiate_classes(
pass
else:
if value is not None:
with parser_context(parent_parser=self, class_instantiators=self._get_instantiators()):
with parser_context(
parent_parser=self,
nested_links=ActionLink.get_nested_links(self, component),
class_instantiators=self._get_instantiators(),
):
parent[key] = component.instantiate_classes(value)
else:
with parser_context(load_value_mode=self.parser_mode, class_instantiators=self._get_instantiators()):
Expand Down
82 changes: 62 additions & 20 deletions jsonargparse/_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,14 @@ def __init__(
if not hasattr(parser, "_links_group"):
parser._links_group = parser.add_argument_group("Linked arguments")
self.parser = parser
self._target = target
self._source = source = (source,) if isinstance(source, str) else source
self.apply_on = apply_on
self.compute_fn = compute_fn
self._initial_input_checks(source, target)

# Set and check source actions or group
exclude = (ActionLink, _ActionConfigLoad, _ActionSubCommands, ActionConfigFile)
source = (source,) if isinstance(source, str) else source
if apply_on == "instantiate":
self.source = [(s, find_subclass_action_or_class_group(parser, s, exclude=exclude)) for s in source]
for key, action in self.source:
Expand Down Expand Up @@ -219,13 +220,21 @@ def __init__(
help=help_str,
)

def get_kwargs(self) -> dict:
return {
"source": self._source,
"target": self._target,
"apply_on": self.apply_on,
"compute_fn": self.compute_fn,
}

def _initial_input_checks(self, source, target):
# Check apply_on
if self.apply_on not in {"parse", "instantiate"}:
raise ValueError("apply_on must be 'parse' or 'instantiate'.")

# Check compute function
if self.compute_fn is None and not isinstance(source, str):
if self.compute_fn is None and not (isinstance(source, str) or len(source) == 1):
raise ValueError("Multiple source keys requires a compute function.")

if self.apply_on == "parse":
Expand Down Expand Up @@ -266,21 +275,19 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None:
ActionLink.apply_parsing_links(subparser, cfg[subcommand]) # type: ignore
if not hasattr(parser, "_links_group"):
return
for action in parser._links_group._group_actions:
if action.apply_on != "parse":
continue
for action in get_link_actions(parser, "parse"):
from ._typehints import ActionTypeHint

args = []
skip_link = False
for source_key, source_action in action.source:
if ActionTypeHint.is_subclass_typehint(source_action[0]) and source_key not in cfg:
if ActionTypeHint.is_subclass_typehint(source_action[0]) and source_key not in cfg: # type: ignore
parser.logger.debug(
f"Link '{action.option_strings[0]}' ignored since source '{source_key}' not found in namespace."
)
skip_link = True
break
for source_action_n in [a for a in source_action if a.dest in cfg]:
for source_action_n in [a for a in source_action if a.dest in cfg]: # type: ignore
parser._check_value_key(source_action_n, cfg[source_action_n.dest], source_action_n.dest, None)
args.append(cfg[source_key])
if skip_link:
Expand Down Expand Up @@ -309,6 +316,7 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None:
# Compute value
value = action.call_compute_fn(args)
ActionLink.set_target_value(action, value, cfg, parser.logger)
parser.logger.debug(f"Applied link '{action.option_strings[0]}'.")

@staticmethod
def apply_instantiation_links(parser, cfg, target=None, order=None):
Expand All @@ -317,14 +325,15 @@ def apply_instantiation_links(parser, cfg, target=None, order=None):

applied_key = "__applied_instantiation_links__"
applied_links = cfg.pop(applied_key) if applied_key in cfg else set()
link_actions = [
a for a in parser._links_group._group_actions if a.apply_on == "instantiate" and a not in applied_links
]
link_actions = get_link_actions(parser, "instantiate", skip=applied_links)
if order and link_actions:
link_actions = ActionLink.reorder(order, link_actions)

for action in link_actions:
if not (order or action.target[0] == target or action.target[0].startswith(target + ".")):
target_key = action.target[0]
if not (
order or target_key == target or target_key.startswith(f"{target}.")
) or is_nested_instantiation_link(action):
continue
source_objects = []
for source_key, source_action in action.source:
Expand All @@ -350,10 +359,25 @@ def apply_instantiation_links(parser, cfg, target=None, order=None):
value = action.call_compute_fn(source_objects)
ActionLink.set_target_value(action, value, cfg, parser.logger)
applied_links.add(action)
parser.logger.debug(f"Applied link '{action.option_strings[0]}'.")

if target:
cfg[applied_key] = applied_links

@staticmethod
def get_nested_links(parser, action):
def trim_param_keys(params: dict):
params = params.copy()
params["source"] = tuple(k[len(f"{action.dest}.") :] for k in params["source"])
params["target"] = params["target"][len(f"{action.dest}.init_args.") :]
return params

links = []
for link in get_link_actions(parser, "instantiate"):
if link.target[1] is action and is_nested_instantiation_link(link):
links.append(trim_param_keys(link.get_kwargs()))
return links

@staticmethod
def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -> None:
target_key, target_action = action.target
Expand All @@ -378,15 +402,14 @@ def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -

@staticmethod
def instantiation_order(parser):
if hasattr(parser, "_links_group"):
actions = [a for a in parser._links_group._group_actions if a.apply_on == "instantiate"]
if len(actions) > 0:
graph = DirectedGraph()
for action in actions:
target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0])
for _, source_action in action.source:
graph.add_edge(source_action.dest, target)
return graph.get_topological_order()
actions = get_link_actions(parser, "instantiate")
if actions:
graph = DirectedGraph()
for action in actions:
target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0])
for _, source_action in action.source:
graph.add_edge(source_action.dest, target)
return graph.get_topological_order()
return []

@staticmethod
Expand Down Expand Up @@ -428,6 +451,25 @@ def del_target_key(target_key):
ActionLink.strip_link_target_keys(subparsers[num], cfg[subcommand])


def get_link_actions(parser: "ArgumentParser", apply_on: str, skip=set()) -> List[ActionLink]:
if not hasattr(parser, "_links_group"):
return []
return [a for a in parser._links_group._group_actions if a.apply_on == apply_on and a not in skip]


def is_nested_instantiation_link(action: ActionLink) -> bool:
from ._typehints import ActionTypeHint

target_key, target_action = action.target
assert target_action
return (
target_key.startswith(f"{target_action.dest}.init_args.")
and ActionTypeHint.is_subclass_typehint(target_action)
and all(a is target_action for _, a in action.source)
and all(k.startswith(f"{target_action.dest}.") for k, _ in action.source)
)


class ArgumentLinking:
"""Method for linking arguments."""

Expand Down
7 changes: 6 additions & 1 deletion jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
_is_action_value_list,
remove_actions,
)
from ._common import get_class_instantiator, is_dataclass_like, is_subclass, parent_parser, parser_context
from ._common import get_class_instantiator, is_dataclass_like, is_subclass, nested_links, parent_parser, parser_context
from ._loaders_dumpers import (
get_loader_exceptions,
load_value,
Expand Down Expand Up @@ -528,10 +528,15 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
parser = type(parser)(exit_on_error=False, logger=parser.logger)
remove_actions(parser, (ActionConfigFile, _ActionPrintConfig))
parser.add_class_arguments(val_class, **kwargs)

if "linked_targets" in kwargs and parser.required_args:
for key in kwargs["linked_targets"]:
if key in parser.required_args:
parser.required_args.remove(key)

for link_kwargs in nested_links.get():
parser.link_arguments(**link_kwargs)

return parser

def extra_help(self):
Expand Down
97 changes: 97 additions & 0 deletions jsonargparse_tests/test_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,103 @@ def test_on_instantiate_add_argument_subclass_required_params(parser):
assert init.cls2.a == 1


class WithinDeepSource:
def __init__(self, model_name: str):
self.output_channels = dict(
modelA=16,
modelB=32,
)[model_name]


class WithinDeepTarget:
def __init__(self, input_channels: int):
self.input_channels = input_channels


class WithinDeepModel:
def __init__(
self,
encoder: WithinDeepSource,
decoder: WithinDeepTarget,
):
self.encoder = encoder
self.decoder = decoder


within_deep_config = {
"model": {
"class_path": f"{__name__}.WithinDeepModel",
"init_args": {
"encoder": {
"class_path": f"{__name__}.WithinDeepSource",
"init_args": {
"model_name": "modelA",
},
},
"decoder": {
"class_path": f"{__name__}.WithinDeepTarget",
},
},
},
}


def test_on_instantiate_within_deep_subclass(parser, caplog):
parser.logger = {"level": "DEBUG"}
parser.logger.handlers = [caplog.handler]

parser.add_argument("--cfg", action=ActionConfigFile)
parser.add_argument("--model", type=WithinDeepModel)
parser.link_arguments(
"model.encoder.output_channels",
"model.init_args.decoder.init_args.input_channels",
apply_on="instantiate",
)

cfg = parser.parse_args([f"--cfg={within_deep_config}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.model, WithinDeepModel)
assert isinstance(init.model.encoder, WithinDeepSource)
assert isinstance(init.model.decoder, WithinDeepTarget)
assert init.model.decoder.input_channels == 16
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text


class WithinDeeperSystem:
def __init__(self, model: WithinDeepModel):
self.model = model


within_deeper_config = {
"system": {
"class_path": f"{__name__}.WithinDeeperSystem",
"init_args": within_deep_config,
},
}


def test_on_instantiate_within_deeper_subclass(parser, caplog):
parser.logger = {"level": "DEBUG"}
parser.logger.handlers = [caplog.handler]

parser.add_argument("--cfg", action=ActionConfigFile)
parser.add_subclass_arguments(WithinDeeperSystem, "system")
parser.link_arguments(
"system.model.encoder.output_channels",
"system.init_args.model.init_args.decoder.init_args.input_channels",
apply_on="instantiate",
)

cfg = parser.parse_args([f"--cfg={within_deeper_config}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.system, WithinDeeperSystem)
assert isinstance(init.system.model, WithinDeepModel)
assert isinstance(init.system.model.encoder, WithinDeepSource)
assert isinstance(init.system.model.decoder, WithinDeepTarget)
assert init.system.model.decoder.input_channels == 16
assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text


# link creation failures


Expand Down

0 comments on commit 7aa5e5c

Please sign in to comment.