diff --git a/injector/__init__.py b/injector/__init__.py index 5c9cf9c..7cac61d 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -1239,8 +1239,11 @@ def _is_new_union_type(instance: Any) -> bool: for k, v in list(bindings.items()): if _is_specialization(v, Annotated): - v, metadata = v.__origin__, v.__metadata__ - bindings[k] = v + origin, metadata = v.__origin__, v.__metadata__ + if ( + _inject_marker in metadata or _noinject_marker in metadata + ): # replace original annotated type with its origin if annotation is injection marker + bindings[k] = origin else: metadata = tuple() @@ -1324,7 +1327,7 @@ def provide_strs_also(self) -> List[str]: def _mark_provider_function(function: Callable, *, allow_multi: bool) -> None: scope_ = getattr(function, '__scope__', None) try: - annotations = get_type_hints(function) + annotations = get_type_hints(function, include_extras=True) except NameError: return_type = '__deferred__' else: diff --git a/injector_test.py b/injector_test.py index 3d98254..2cfc473 100644 --- a/injector_test.py +++ b/injector_test.py @@ -10,13 +10,11 @@ """Functional tests for the "Injector" dependency injection framework.""" -from contextlib import contextmanager -from typing import Any, NewType, Optional, Union import abc import sys import threading -import traceback -import warnings +from contextlib import contextmanager +from typing import Any, Optional, Union if sys.version_info >= (3, 9): from typing import Annotated @@ -1754,3 +1752,34 @@ def configure(binder): injector = Injector([configure]) assert injector.get(foo) == 123 assert injector.get(bar) == 456 + + +def test_annotated_injection_with_attribute(): + + foo = Annotated[str, "foo"] + bar = Annotated[str, "bar"] + + # noinspection PyUnusedLocal + @inject + def target(val_foo: foo, val_bar: bar): + pass + + assert get_bindings(target) == {'val_foo': foo, 'val_bar': bar} + + +def test_annotated_injection_from_provider_to_attribute(): + foo = Annotated[str, "foo"] + bar = Annotated[str, "bar"] + + class TestModule(Module): + @provider + def provide_foo(self) -> foo: + return "foo" + + @multiprovider + def provide_bars(self) -> List[bar]: + return ["bar"] + + injector = Injector([TestModule]) + assert injector.binder.has_explicit_binding_for(foo) + assert injector.binder.has_explicit_binding_for(List[bar])