diff --git a/magicbot/inject.py b/magicbot/inject.py index 6f3c8e5..3af0ce8 100644 --- a/magicbot/inject.py +++ b/magicbot/inject.py @@ -23,10 +23,13 @@ def get_injection_requests( for n, inject_type in type_hints.items(): # If the variable is private ignore it if n.startswith("_"): + if component is None: + message = f"Cannot inject into component {cname} __init__ param {n}" + raise MagicInjectError(message) continue # If the variable has been set, skip it - if hasattr(component, n): + if component is not None and hasattr(component, n): continue # Check for generic types from the typing module diff --git a/magicbot/magicrobot.py b/magicbot/magicrobot.py index 69e34e2..9dcfdf2 100644 --- a/magicbot/magicrobot.py +++ b/magicbot/magicrobot.py @@ -595,6 +595,8 @@ def _create_components(self) -> None: # .. this hack is necessary for pybind11 based modules sys.modules["pybind11_builtins"] = types.SimpleNamespace() # type: ignore + injectables = self._collect_injectables() + for m, ctyp in typing.get_type_hints(cls).items(): # Ignore private variables if m.startswith("_"): @@ -611,17 +613,16 @@ def _create_components(self) -> None: % (cls.__name__, m, ctyp) ) - component = self._create_component(m, ctyp) + component = self._create_component(m, ctyp, injectables) # Store for later components.append((m, component)) - - self._injectables = self._collect_injectables() + injectables[m] = component # For each new component, perform magic injection for cname, component in components: setup_tunables(component, cname, "components") - self._setup_vars(cname, component) + self._setup_vars(cname, component, injectables) self._setup_reset_vars(component) # Do it for autonomous modes too @@ -672,9 +673,18 @@ def _collect_injectables(self) -> Dict[str, Any]: return injectables - def _create_component(self, name: str, ctyp: type): + def _create_component(self, name: str, ctyp: type, injectables: Dict[str, Any]): + type_hints = typing.get_type_hints(ctyp.__init__) + NoneType = type(None) + init_return_type = type_hints.pop("return", NoneType) + assert ( + init_return_type is NoneType + ), f"{ctyp!r} __init__ had an unexpected non-None return type hint" + requests = get_injection_requests(type_hints, name) + injections = find_injections(requests, injectables, name) + # Create instance, set it on self - component = ctyp() + component = ctyp(**injections) setattr(self, name, component) # Ensure that mandatory methods are there @@ -691,12 +701,12 @@ def _create_component(self, name: str, ctyp: type): return component - def _setup_vars(self, cname: str, component) -> None: + def _setup_vars(self, cname: str, component, injectables: Dict[str, Any]) -> None: self.logger.debug("Injecting magic variables into %s", cname) type_hints = typing.get_type_hints(type(component)) requests = get_injection_requests(type_hints, cname, component) - injections = find_injections(requests, self._injectables, cname) + injections = find_injections(requests, injectables, cname) component.__dict__.update(injections) def _setup_reset_vars(self, component) -> None: diff --git a/tests/test_magicbot_injection.py b/tests/test_magicbot_injection.py index 9dfa45c..f5cc692 100644 --- a/tests/test_magicbot_injection.py +++ b/tests/test_magicbot_injection.py @@ -1,4 +1,4 @@ -from typing import List, Type, TypeVar +from typing import List, Tuple, Type, TypeVar from unittest.mock import Mock import magicbot @@ -33,12 +33,30 @@ def execute(self): pass +class Component3: + intvar: int + + def __init__( + self, + tupvar: Tuple[int, int], + injectable: Injectable, + component2: Component2, + ) -> None: + self.tuple_ = tupvar + self.injectable_ = injectable + self.component_2 = component2 + + def execute(self): + pass + + class SimpleBot(magicbot.MagicRobot): intvar = 1 tupvar = 1, 2 component1: Component1 component2: Component2 + component3: Component3 def createObjects(self): self.injectable = Injectable(42) @@ -158,6 +176,15 @@ def test_simple_annotation_inject(): assert bot.component2.tupvar == (1, 2) assert bot.component2.component1 is bot.component1 + assert bot.component3.intvar == 1 + assert bot.component3.tuple_ == (1, 2) + assert isinstance(bot.component3.injectable_, Injectable) + assert bot.component3.injectable_.num == 42 + assert bot.component3.component_2 is bot.component2 + + # Check the method hasn't been mutated + assert str(Component3.__init__.__annotations__["return"]) == "None" + def test_multilevel_annotation_inject(): bot = _make_bot(MultilevelBot)