Apply func non roundtrippable seq #250

Closed
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Apply func non round-trippable seq ([#250](https://github.com/Lightning-AI/utilities/pull/250))


### Changed
10 changes: 9 additions & 1 deletion src/lightning_utilities/core/apply_func.py
@@ -20,6 +20,14 @@ def is_dataclass_instance(obj: object) -> bool:
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def can_roundtrip_sequence(obj: Sequence) -> bool:
    """Check if sequence can be roundtripped."""
    try:
        return obj == type(obj)(list(obj))  # type: ignore[call-arg]
Contributor

How can you be sure that list(obj) will work for any sequence? Even though this is covered by the try-except block, the `and` condition added at L129 will now skip the is_sequence loop for any sequence that this check does not consider round-trippable.
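
To make this concern concrete, here is a small self-contained sketch of how the helper classifies a few standard sequences. The helper body is copied from the diff above; the sample inputs are illustrative choices and not part of the PR.

```python
from collections.abc import Sequence


def can_roundtrip_sequence(obj: Sequence) -> bool:
    """Copy of the helper proposed in this PR."""
    try:
        return obj == type(obj)(list(obj))
    except (TypeError, ValueError):
        return False


# Plain lists and tuples can be rebuilt from a list of their items.
print(can_roundtrip_sequence([1, 2, 3]))  # True
print(can_roundtrip_sequence((1, 2, 3)))  # True

# range is registered as a Sequence, but range([0, 1, 2]) raises TypeError,
# so the helper reports False and the is_sequence loop would be skipped.
print(isinstance(range(3), Sequence))    # True
print(can_roundtrip_sequence(range(3)))  # False
```

Any other sequence type whose constructor does not accept a list of its own items would be classified the same way as `range` here.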

Given the complexity, and that your data structure is very specific to your problem, could we instead expose an API so that your data structure can define whether it should be treated as a sequence or not?

For instance, that is similar to how pytrees are implemented, which are much more extensible and performant: https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py#L163. In this case, it could also be an attribute of the data structure.
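
A minimal sketch of the attribute-based variant of this idea, under stated assumptions: the attribute name `__apply_as_leaf__`, the helper `treat_as_sequence`, and the simplified `apply_to_sequence` are all hypothetical names made up for illustration. None of them exist in lightning_utilities or in PyTorch's pytree module.

```python
from collections.abc import Sequence
from typing import Any, Callable

# Hypothetical attribute name, invented for this sketch.
_OPT_OUT_ATTR = "__apply_as_leaf__"


def treat_as_sequence(data: Any) -> bool:
    """Return True if `data` should be iterated element by element."""
    if isinstance(data, str) or not isinstance(data, Sequence):
        return False
    # A type opts out of iteration by setting the attribute to True.
    return not getattr(type(data), _OPT_OUT_ATTR, False)


def apply_to_sequence(data: Any, dtype: type, function: Callable[[Any], Any]) -> Any:
    """Simplified stand-in for the sequence branch of apply_to_collection."""
    if isinstance(data, dtype):
        return function(data)
    if treat_as_sequence(data):
        return type(data)(apply_to_sequence(d, dtype, function) for d in data)
    return data  # leaves and opted-out types pass through unchanged


class OpaqueId(tuple):
    """Tuple subclass that asks to be passed through untouched."""

    __apply_as_leaf__ = True


print(apply_to_sequence([1, (2, 3)], int, lambda x: x + 1))       # [2, (3, 4)]
print(apply_to_sequence(OpaqueId((1, 2)), int, lambda x: x + 1))  # (1, 2)
```

With this design the decision lives on the data structure itself, so no round-trip probing is needed at call time.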

Author

> Given the complexity, and that your data structure is very specific to your problem, could we instead expose an API so that your data structure can define whether it should be treated as a sequence or not?

That sounds like the best solution. An exclude list for types that should just be passed through would be ideal. I can work on a PR in early June.
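
A rough sketch of what such an exclude list could look like, assuming a hypothetical `exclude` parameter on a simplified stand-in function. `apply_with_exclude` and `Span` are invented for illustration; this is not the signature of `apply_to_collection`.

```python
from collections.abc import Sequence
from typing import Any, Callable, Tuple


def apply_with_exclude(
    data: Any,
    dtype: type,
    function: Callable[[Any], Any],
    exclude: Tuple[type, ...] = (),  # hypothetical parameter
) -> Any:
    """Apply `function` to `dtype` leaves, passing `exclude` types through."""
    if isinstance(data, exclude):
        return data  # excluded types are never iterated or rebuilt
    if isinstance(data, dtype):
        return function(data)
    if isinstance(data, Sequence) and not isinstance(data, str):
        return type(data)(apply_with_exclude(d, dtype, function, exclude) for d in data)
    return data


class Span(tuple):
    """Tuple subclass whose constructor does not accept a single iterable."""

    def __new__(cls, start: int, stop: int):
        return super().__new__(cls, (start, stop))


batch = [1, Span(2, 5)]
print(apply_with_exclude(batch, int, lambda x: x + 1, exclude=(Span,)))  # [2, (2, 5)]
```

Without `exclude=(Span,)`, the sketch would try to call `Span(...)` with a single iterable argument and fail with a `TypeError`.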

Member

so how is it going here? :)

Author

It's still on the list. I'm just lacking time to work on open-source these days...

I'll close this PR and will open a new one with exclude support when I find the time.

    except (TypeError, ValueError):
        return False


def apply_to_collection(
    data: Any,
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
@@ -118,7 +126,7 @@ def _apply_to_collection_slow(
        return elem_type(OrderedDict(out))

    is_namedtuple_ = is_namedtuple(data)
-   is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
+   is_sequence = isinstance(data, Sequence) and not isinstance(data, str) and can_roundtrip_sequence(data)
    if is_namedtuple_ or is_sequence:
        out = []
        for d in data:
10 changes: 10 additions & 0 deletions tests/unittests/core/test_apply_func.py
@@ -359,3 +359,13 @@ class Foo:
    foo = Foo(0)
    result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True)
    assert foo == result


def test_apply_to_collection_non_roundtrippable_sequence():
    class NonRoundtrippableSequence(list):
        def __init__(self, x: int):
            super().__init__(range(int(x)))

    val = NonRoundtrippableSequence(3)
    result = apply_to_collection(val, int, lambda x: x + 1)
    assert val == result
Contributor

Wouldn't you expect result == [1, 2, 3]? What's a real example where this should become a no-op?

Author

This came up when using: https://github.com/Bayer-Group/pado/blob/635d7b8b57e527254d6302730100a6dab5a2095f/pado/images/ids.py#L126-L351

There, ImageId instances are tuple subclasses, but they don't roundtrip, i.e.:

iid = ImageId("a", "b", "c", site="site-1")

ImageId(list(iid))  # <- is not allowed
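
For readers without pado installed, the same situation can be reproduced with a small stand-in class. `TaggedId` below is hypothetical and only mimics the shape of the problem; its validation rule and the `ValueError` it raises are assumptions, not the actual ImageId implementation.

```python
from typing import Optional


class TaggedId(tuple):
    """Hypothetical stand-in for a tuple subclass such as ImageId."""

    def __new__(cls, *parts: str, site: Optional[str] = None):
        # Assumed validation: every positional part must be a string.
        if not all(isinstance(p, str) for p in parts):
            raise ValueError("all parts must be strings")
        self = super().__new__(cls, parts)
        self.site = site
        return self


iid = TaggedId("a", "b", "c", site="site-1")
print(isinstance(iid, tuple), list(iid))  # True ['a', 'b', 'c']

# Rebuilding from list(iid) passes one list instead of separate strings,
# and the keyword-only `site` is lost as well.
try:
    TaggedId(list(iid))
except ValueError as err:
    print("cannot rebuild from list:", err)
```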

Member

cc: @ap-- ^^

Author

Sorry, it seems my reply was in a pending state for a month and I still had to click on submit review 🤷
