When running apply_to_collection on a dataclass, cached properties are not modified. This can cause subtle issues. For example, suppose I initialize a dataclass on CPU in a data worker and then move it onto the GPU for a model batch. All of the dataclass fields that contain Tensors are moved correctly, but the values cached by cached_property continue to reside on the original device.
Steps to reproduce
import dataclasses
from functools import cached_property

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor


@dataclasses.dataclass
class Data:
    a: Tensor

    @cached_property
    def b(self):
        print("*" * 10)
        print("Computing and cache prop b")
        print("*" * 10)
        return self.a * 2


print("*" * 10)
print("Start with data on GPU")
print("*" * 10)
data = Data(a=torch.tensor([1, 2, 3], device="cuda"))
print(f"{data.a=}")
print(f"{data.a.device=}")
print(f"{data.b=}")
print(f"{data.b=}")  # do this a second time to make sure we're caching it
print(f"{data.b.device=}")

print("*" * 10)
print("Move Data to CPU")
print("*" * 10)
new_data = apply_to_collection(data, Tensor, lambda x: x.to("cpu"))
print(f"{new_data.a=}")
print(f"{new_data.a.device=}")
print(f"{new_data.b=}")
print(f"{new_data.b=}")  # do this a second time to make sure we're caching it
print(f"{new_data.b.device=}")
Yields the following output:
**********
Start with data on GPU
**********
data.a=tensor([1, 2, 3], device='cuda:0')
data.a.device=device(type='cuda', index=0)
**********
Computing and cache prop b
**********
data.b=tensor([2, 4, 6], device='cuda:0')
data.b=tensor([2, 4, 6], device='cuda:0')
data.b.device=device(type='cuda', index=0)
**********
Move Data to CPU
**********
new_data.a=tensor([1, 2, 3])
new_data.a.device=device(type='cpu')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b.device=device(type='cuda', index=0)
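The stale device on the last three lines can be reproduced without a GPU. The torch-free sketch below assumes, consistent with the output above, that apply_to_collection rewrites only the declared dataclass fields on a copy of the original instance; the deepcopy plus per-field setattr is a stand-in for that behaviour, not the library's actual code. Because functools.cached_property stores its result in the instance __dict__ rather than as a dataclass field, the copy carries the cached b along even though a is replaced.

```python
import copy
import dataclasses
from functools import cached_property


@dataclasses.dataclass
class Data:
    a: list

    @cached_property
    def b(self):
        # Computed on first access, then stored in the instance __dict__ under "b"
        return [x * 2 for x in self.a]


data = Data(a=[1, 2, 3])
_ = data.b  # populate the cache
print(vars(data))  # {'a': [1, 2, 3], 'b': [2, 4, 6]}

# Only "a" is a declared field; "b" is invisible to dataclasses.fields()
print([f.name for f in dataclasses.fields(data)])  # ['a']

# Stand-in for a field-wise transform: copy the instance, then rewrite each field.
new_data = copy.deepcopy(data)
for field in dataclasses.fields(new_data):
    setattr(new_data, field.name, [x + 10 for x in getattr(data, field.name)])

print(new_data.a)  # [11, 12, 13]
print(new_data.b)  # [2, 4, 6]: stale, still derived from the old "a"
```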
Hey @jackdent
This is a rare use case and I won't have the bandwidth to look into it. We would be grateful for a contribution here if you're interested. The fix is probably to just reset the cache when running apply_to_collection.
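A minimal sketch of the suggested cache reset, under stated assumptions: clear_cached_properties is a hypothetical helper, not part of lightning_utilities, and the deepcopy plus field assignment again stands in for apply_to_collection. The helper pops every cached_property value from the instance __dict__, so the next access recomputes the property from the already-transformed fields.

```python
import copy
import dataclasses
from functools import cached_property
from typing import Any


def clear_cached_properties(obj: Any) -> Any:
    """Hypothetical helper: drop values cached by functools.cached_property on obj."""
    instance_dict = getattr(obj, "__dict__", None)
    if instance_dict is None:
        # No instance __dict__ (e.g. ints, __slots__ classes): nothing can be cached.
        return obj
    seen = set()
    for klass in type(obj).__mro__:
        for name, attr in vars(klass).items():
            if name in seen:
                continue  # a subclass definition overrides the same name further up the MRO
            seen.add(name)
            if isinstance(attr, cached_property):
                # cached_property stores its result under the attribute name
                instance_dict.pop(name, None)
    return obj


@dataclasses.dataclass
class Data:
    a: list

    @cached_property
    def b(self):
        return [x * 2 for x in self.a]


data = Data(a=[1, 2, 3])
_ = data.b  # populate the cache

# Stand-in for apply_to_collection(data, Tensor, fn): copy, then rewrite the field.
new_data = copy.deepcopy(data)
new_data.a = [x + 10 for x in data.a]
clear_cached_properties(new_data)
print(new_data.b)  # [22, 24, 26]: recomputed from the new "a"
```

Called only on the top-level result, such a helper would not reach dataclasses nested inside lists or dicts; doing the reset inside apply_to_collection for each dataclass it rebuilds would cover those. The trade-off is one extra evaluation of each cached property after the move.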