Skip to content

Commit

Permalink
magicbot: Inject into component __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
auscompgeek committed Feb 8, 2023
1 parent cd5b663 commit 97c4b99
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
5 changes: 4 additions & 1 deletion magicbot/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ def get_injection_requests(
for n, inject_type in type_hints.items():
# If the variable is private ignore it
if n.startswith("_"):
if component is None:
message = f"Cannot inject into component {cname} __init__ param {n}"
raise MagicInjectError(message)
continue

# If the variable has been set, skip it
if hasattr(component, n):
if component is not None and hasattr(component, n):
continue

# Check for generic types from the typing module
Expand Down
26 changes: 18 additions & 8 deletions magicbot/magicrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ def _create_components(self) -> None:
# .. this hack is necessary for pybind11 based modules
sys.modules["pybind11_builtins"] = types.SimpleNamespace() # type: ignore

injectables = self._collect_injectables()

for m, ctyp in typing.get_type_hints(cls).items():
# Ignore private variables
if m.startswith("_"):
Expand All @@ -611,17 +613,16 @@ def _create_components(self) -> None:
% (cls.__name__, m, ctyp)
)

component = self._create_component(m, ctyp)
component = self._create_component(m, ctyp, injectables)

# Store for later
components.append((m, component))

self._injectables = self._collect_injectables()
injectables[m] = component

# For each new component, perform magic injection
for cname, component in components:
setup_tunables(component, cname, "components")
self._setup_vars(cname, component)
self._setup_vars(cname, component, injectables)
self._setup_reset_vars(component)

# Do it for autonomous modes too
Expand Down Expand Up @@ -672,9 +673,18 @@ def _collect_injectables(self) -> Dict[str, Any]:

return injectables

def _create_component(self, name: str, ctyp: type):
def _create_component(self, name: str, ctyp: type, injectables: Dict[str, Any]):
type_hints = typing.get_type_hints(ctyp.__init__)
NoneType = type(None)
init_return_type = type_hints.pop("return", NoneType)
assert (
init_return_type is NoneType
), f"{ctyp!r} __init__ had an unexpected non-None return type hint"
requests = get_injection_requests(type_hints, name)
injections = find_injections(requests, injectables, name)

# Create instance, set it on self
component = ctyp()
component = ctyp(**injections)
setattr(self, name, component)

# Ensure that mandatory methods are there
Expand All @@ -691,12 +701,12 @@ def _create_component(self, name: str, ctyp: type):

return component

def _setup_vars(self, cname: str, component) -> None:
def _setup_vars(self, cname: str, component, injectables: Dict[str, Any]) -> None:
self.logger.debug("Injecting magic variables into %s", cname)

type_hints = typing.get_type_hints(type(component))
requests = get_injection_requests(type_hints, cname, component)
injections = find_injections(requests, self._injectables, cname)
injections = find_injections(requests, injectables, cname)
component.__dict__.update(injections)

def _setup_reset_vars(self, component) -> None:
Expand Down
29 changes: 28 additions & 1 deletion tests/test_magicbot_injection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Type, TypeVar
from typing import List, Tuple, Type, TypeVar
from unittest.mock import Mock

import magicbot
Expand Down Expand Up @@ -33,12 +33,30 @@ def execute(self):
pass


class Component3:
intvar: int

def __init__(
self,
tupvar: Tuple[int, int],
injectable: Injectable,
component2: Component2,
) -> None:
self.tuple_ = tupvar
self.injectable_ = injectable
self.component_2 = component2

def execute(self):
pass


class SimpleBot(magicbot.MagicRobot):
intvar = 1
tupvar = 1, 2

component1: Component1
component2: Component2
component3: Component3

def createObjects(self):
self.injectable = Injectable(42)
Expand Down Expand Up @@ -158,6 +176,15 @@ def test_simple_annotation_inject():
assert bot.component2.tupvar == (1, 2)
assert bot.component2.component1 is bot.component1

assert bot.component3.intvar == 1
assert bot.component3.tuple_ == (1, 2)
assert isinstance(bot.component3.injectable_, Injectable)
assert bot.component3.injectable_.num == 42
assert bot.component3.component_2 is bot.component2

# Check the method hasn't been mutated
assert str(Component3.__init__.__annotations__["return"]) == "None"


def test_multilevel_annotation_inject():
bot = _make_bot(MultilevelBot)
Expand Down

0 comments on commit 97c4b99

Please sign in to comment.