Skip to content

Commit 7686fd8

Browse files
authored
Merge pull request #269 from RaulPPelaez/simplify_data
Fix inifinite recursion error on OSX
2 parents 9619008 + c58bde7 commit 7686fd8

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

torchmdnet/data.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,21 @@ def __init__(self, dataset, dtype=torch.float64):
2525
super(FloatCastDatasetWrapper, self).__init__(
2626
dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
2727
)
28-
self.dataset = dataset
29-
self.dtype = dtype
28+
self._dataset = dataset
29+
self._dtype = dtype
3030

3131
def len(self):
32-
return len(self.dataset)
32+
return len(self._dataset)
3333

3434
def get(self, idx):
35-
data = self.dataset.get(idx)
35+
data = self._dataset.get(idx)
3636
for key, value in data:
3737
if torch.is_tensor(value) and torch.is_floating_point(value):
38-
setattr(data, key, value.to(self.dtype))
38+
setattr(data, key, value.to(self._dtype))
3939
return data
4040

41-
def __getattr__(self, name):
42-
# Check if the attribute exists in the underlying dataset
43-
if hasattr(self.dataset, name):
44-
return getattr(self.dataset, name)
45-
raise AttributeError(
46-
f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
47-
)
41+
def __getattr__(self, __name):
42+
return getattr(self.__dict__["_dataset"], __name)
4843

4944

5045
class DataModule(LightningDataModule):

0 commit comments

Comments
 (0)