Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/mrpro/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from mrpro.utils.split_idx import split_idx
from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted, ravel_multi_index, unsqueeze_tensors_left, unsqueeze_tensors_right
from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin
from mrpro.utils.getnested import getnestedattr, getnesteditem

__all__ = [
"TensorAttributeMixin",
"broadcast_right",
"fill_range_",
"getnestedattr",
"getnesteditem",
"ravel_multi_index",
"reduce_view",
"remove_repeat",
Expand Down
76 changes: 76 additions & 0 deletions src/mrpro/utils/getnested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Get a nested attribute."""

from collections.abc import Mapping
from typing import TypeVar, cast, overload

T = TypeVar('T')


@overload
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: None = ...) -> object | None: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: None = ...) -> T: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: type[T] = ...) -> T | None: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: type[T] = ...) -> T: ...


def getnestedattr(obj: object, *attrs: str, default: T | None = None, return_type: type[T] | None = None) -> T | None:
"""
Get a nested attribute, or return a default if any step fails.

Parameters
----------
obj
object to get attribute from
attrs
attribute names to get
default
value to return if any step fails
return_type
type to cast the result to (only for type hinting)
"""
if return_type is not None and default is not None and not isinstance(default, return_type):
raise TypeError('default must be of the same type as return_type')
for attr in attrs:
try:
obj = getattr(obj, attr)
except AttributeError:
return default
return cast(T, obj)


@overload
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: None = ...) -> object | None: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: None = ...) -> T: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: type[T] = ...) -> T | None: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: type[T] = ...) -> T: ...


def getnesteditem(obj: Mapping, *items: str, default: T | None = None, return_type: type[T] | None = None) -> T | None:
"""
Get a nested item, or return a default if any step fails.

Parameters
----------
obj
object to get attribute from
items
item names to get
default
value to return if any step fails
return_type
type to cast the result to (only for type hinting)
"""
if return_type is not None and default is not None and not isinstance(default, return_type):
raise TypeError('default must be of the same type as return_type')
for item in items:
try:
obj = obj[item]
except (KeyError, TypeError):
return default
return cast(T, obj)
86 changes: 86 additions & 0 deletions tests/utils/test_getnested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dataclasses import dataclass, field

import pytest
from mrpro.utils import getnestedattr, getnesteditem
from typing_extensions import assert_type


@dataclass
class C:
"""Test class for getnestedattr."""

c: int = 1


@dataclass
class B:
"""Test class for getnestedattr."""

b: C = field(default_factory=C)


@dataclass
class A:
"""Test class for getnestedattr."""

a: B = field(default_factory=B)


def test_getnestedattr_value() -> None:
"""Test getnestedattr with a valid path."""
obj = A()
actual = getnestedattr(obj, 'a', 'b', 'c')
assert actual == 1


def test_getnestedattr_default() -> None:
"""Test getnestedattr with a missing path and a default value."""
obj = A()
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', default=2)
assert_type(actual, int)
assert actual == 2


def test_getnestedattr_type() -> None:
"""Test getnestedattr with a missing path no default value, but a return type."""
obj = A()
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', return_type=int)
assert_type(actual, int | None)
assert actual is None


def test_getnestedattr_default_type_error() -> None:
"""Test getnestedattr with a default value and a return type that do not match."""
obj = A()
with pytest.raises(TypeError):
getnestedattr(obj, 'a', default=2, return_type=str)


def test_getnesteditem_value() -> None:
"""Test getnesteditem with a valid path."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'b', 'c')
assert actual == 1


def test_getnesteditem_default() -> None:
"""Test getnesteditem with a missing path and a default value."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', default=2)
assert_type(actual, int)
assert actual == 2


def test_getnesteditem_type() -> None:
"""Test getnesteditem with a missing path no default value, but a return type."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', return_type=int)
assert_type(actual, int | None)
assert actual is None


def test_getnesteditem_default_type_error() -> None:
"""Test getnesteditem with a default value and a return type that do not match."""
obj = {'a': {'b': {'c': 1}}}
with pytest.raises(TypeError):
getnesteditem(obj, 'a', default=2, return_type=str)