Skip to content

Commit

Permalink
Merge pull request #269 from RaulPPelaez/simplify_data
Browse files Browse the repository at this point in the history
Fix inifinite recursion error on OSX
  • Loading branch information
RaulPPelaez authored Feb 1, 2024
2 parents 9619008 + c58bde7 commit 7686fd8
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,21 @@ def __init__(self, dataset, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__(
dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
)
self.dataset = dataset
self.dtype = dtype
self._dataset = dataset
self._dtype = dtype

def len(self):
return len(self.dataset)
return len(self._dataset)

def get(self, idx):
data = self.dataset.get(idx)
data = self._dataset.get(idx)
for key, value in data:
if torch.is_tensor(value) and torch.is_floating_point(value):
setattr(data, key, value.to(self.dtype))
setattr(data, key, value.to(self._dtype))
return data

def __getattr__(self, name):
# Check if the attribute exists in the underlying dataset
if hasattr(self.dataset, name):
return getattr(self.dataset, name)
raise AttributeError(
f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
)
def __getattr__(self, __name):
return getattr(self.__dict__["_dataset"], __name)


class DataModule(LightningDataModule):
Expand Down

0 comments on commit 7686fd8

Please sign in to comment.