
apply_to_collection doesn't work for cached properties #279

Open
jackdent opened this issue Jul 3, 2024 · 3 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

jackdent commented Jul 3, 2024

Motivation

When running apply_to_collection on a dataclass, cached properties do not get modified. This can cause subtle issues: for example, suppose I initialize a dataclass on the CPU in a data worker and then move it onto the GPU for a model batch. All of the dataclass fields that contain Tensors get moved correctly, but the cached_property values continue to reside on the original device.

Steps to reproduce

import dataclasses
from functools import cached_property

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor


@dataclasses.dataclass
class Data:
    a: Tensor

    @cached_property
    def b(self):
        print("*" * 10)
        print("Computing and cache prop b")
        print("*" * 10)
        return self.a * 2


print("*" * 10)
print("Data on CPU")
print("*" * 10)

data = Data(a=torch.tensor([1, 2, 3], device="cuda"))
print(f"{data.a=}")
print(f"{data.a.device=}")

print(f"{data.b=}")
print(f"{data.b=}")  # do this a second time to make sure we're caching it
print(f"{data.b.device=}")

print("*" * 10)
print("Move Data to GPU")
print("*" * 10)

new_data = apply_to_collection(data, Tensor, lambda x: x.to("cpu"))
print(f"{new_data.a=}")
print(f"{new_data.a.device=}")

print(f"{new_data.b=}")
print(f"{new_data.b=}")  # do this a second time to make sure we're caching it
print(f"{new_data.b.device=}")

Yields the following output:

**********
Start with data on GPU
**********
data.a=tensor([1, 2, 3], device='cuda:0')
data.a.device=device(type='cuda', index=0)
**********
Computing and cache prop b
**********
data.b=tensor([2, 4, 6], device='cuda:0')
data.b=tensor([2, 4, 6], device='cuda:0')
data.b.device=device(type='cuda', index=0)
**********
Move Data to CPU
**********
new_data.a=tensor([1, 2, 3])
new_data.a.device=device(type='cpu')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b.device=device(type='cuda', index=0)
jackdent added the enhancement and help wanted labels on Jul 3, 2024
jackdent (Author) commented Jul 3, 2024

The Lightning apply_to_collection logic is defined here and relies on dataclasses.fields, which doesn't include cached properties.
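
For illustration, a minimal standalone sketch (not library code) of why this happens: dataclasses.fields only reports declared fields, while cached_property stores its value directly in the instance __dict__, so a copy rebuilt from the fields never sees the cached value.

import dataclasses
from functools import cached_property


@dataclasses.dataclass
class Example:
    a: int

    @cached_property
    def b(self):
        return self.a * 2


ex = Example(a=1)
_ = ex.b  # populate the cache

print([f.name for f in dataclasses.fields(ex)])  # ['a'] -- 'b' is not a field
print(ex.__dict__)                               # {'a': 1, 'b': 2} -- the cache lives here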


Borda (Member) commented Jul 13, 2024

@awaelchli, do you have any experience with this one?

awaelchli (Contributor) commented

Hey @jackdent
This is a rare use case and I won't have the bandwidth to look into it. We would be grateful for a contribution here if you're interested. The fix is probably to just reset the cache when running apply_to_collection.
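
For reference, a minimal sketch of that idea, assuming the cached values can simply be dropped from the instance __dict__ so they are recomputed lazily from the already-moved fields. The helper name below is illustrative and not part of lightning_utilities.

from functools import cached_property


def _reset_cached_properties(obj):
    # Drop any cached_property values stored on the instance so they are
    # recomputed on next access. Only checks attributes defined on the class itself.
    for name, attr in type(obj).__dict__.items():
        if isinstance(attr, cached_property):
            obj.__dict__.pop(name, None)
    return obj


# Continuing the reproduction above:
new_data = _reset_cached_properties(apply_to_collection(data, Tensor, lambda x: x.to("cpu")))
print(new_data.b.device)  # b is recomputed from the moved `a`, so it now reports the CPU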
