Skip to content

Commit 0e37674

Browse files
authored
Merge pull request #127 from ivankorobkov/auto-discover-attr-type-2
Implement class member type auto-discovery for inject.attr
2 parents fa3c62c + 29714bb commit 0e37674

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

src/inject/__init__.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def my_config(binder):
9494
else:
9595
_HAS_PEP560_SUPPORT = sys.version_info[:3] >= (3, 7, 0) # PEP 560
9696
_RETURN = 'return'
97+
_MISSING = object()
9798

9899
if _HAS_PEP604_SUPPORT:
99100
from types import UnionType
@@ -325,6 +326,10 @@ def __init__(self, cls: Type[T] | Hashable) -> None:
325326
doc="Return an attribute injection",
326327
)
327328

329+
def __set_name__(self, owner: Type[T], name: str) -> None:
330+
if self._cls is _MISSING:
331+
self._cls = _unwrap_cls_annotation(owner, name)
332+
328333

329334
class _ParameterInjection(Generic[T]):
330335
__slots__ = ('_name', '_cls')
@@ -522,13 +527,16 @@ def instance(cls: Binding) -> Injectable:
522527
"""Inject an instance of a class."""
523528
return get_injector_or_die().get_instance(cls)
524529

530+
@overload
531+
def attr() -> Injectable: ...
532+
525533
@overload
526534
def attr(cls: Hashable) -> Injectable: ...
527535

528536
@overload
529537
def attr(cls: Type[T]) -> T: ...
530538

531-
def attr(cls):
539+
def attr(cls=_MISSING):
532540
"""Return an attribute injection (descriptor)."""
533541
return _AttributeInjection(cls)
534542

@@ -653,3 +661,14 @@ def _is_union_type(typ):
653661
return (typ is Union or
654662
isinstance(typ, _GenericAlias) and typ.__origin__ is Union)
655663
return type(typ) is _Union
664+
665+
666+
def _unwrap_cls_annotation(cls: Type, attr_name: str):
667+
types = get_type_hints(cls)
668+
try:
669+
attr_type = types[attr_name]
670+
except KeyError:
671+
msg = f"Couldn't find type annotation for {attr_name}"
672+
raise InjectorException(msg)
673+
674+
return _unwrap_union_arg(attr_type)

test/test_attr.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@ class MyDataClass:
1212
class MyClass:
1313
field = inject.attr(int)
1414
field2: int = inject.attr(int)
15+
auto_typed_field: int = inject.attr()
1516

1617
inject.configure(lambda binder: binder.bind(int, 123))
1718
my = MyClass()
1819
my_dc = MyDataClass()
19-
value0 = my.field
20-
value1 = my.field
21-
value2 = my_dc.field
22-
value3 = my_dc.field
2320

24-
assert value0 == 123
25-
assert value1 == 123
26-
assert value2 == 123
27-
assert value3 == 123
21+
assert my.field == 123
22+
assert my.field == 123
23+
assert my.field2 == 123
24+
assert my.field2 == 123
25+
assert my.auto_typed_field == 123
26+
assert my.auto_typed_field == 123
27+
assert my_dc.field == 123
28+
assert my_dc.field == 123
2829

2930
def test_invalid_attachment_to_dataclass(self):
3031
@dataclass
@@ -36,21 +37,24 @@ class MyDataClass:
3637

3738
def test_class_attr(self):
3839
descriptor = inject.attr(int)
40+
auto_descriptor = inject.attr()
3941

4042
@dataclass
4143
class MyDataClass:
4244
field = descriptor
4345

4446
class MyClass(object):
4547
field = descriptor
48+
field2: int = descriptor
49+
auto_typed_field: int = auto_descriptor
4650

4751
inject.configure(lambda binder: binder.bind(int, 123))
48-
value0 = MyClass.field
49-
value1 = MyClass.field
50-
value2 = MyDataClass.field
51-
value3 = MyDataClass.field
52-
53-
assert value0 is descriptor
54-
assert value1 is descriptor
55-
assert value2 is descriptor
56-
assert value3 is descriptor
52+
53+
assert MyClass.field is descriptor
54+
assert MyClass.field is descriptor
55+
assert MyClass.field2 is descriptor
56+
assert MyClass.field2 is descriptor
57+
assert MyClass.auto_typed_field is auto_descriptor
58+
assert MyClass.auto_typed_field is auto_descriptor
59+
assert MyDataClass.field is descriptor
60+
assert MyDataClass.field is descriptor

0 commit comments

Comments
 (0)