diff --git a/hrdae/dataloaders/datasets/ct.py b/hrdae/dataloaders/datasets/ct.py index 2ef65a5..e14c396 100644 --- a/hrdae/dataloaders/datasets/ct.py +++ b/hrdae/dataloaders/datasets/ct.py @@ -98,8 +98,6 @@ def __init__( if in_memory: for path in tqdm(self.paths, desc="loading datasets..."): t = from_numpy(np.load(path)["arr_0"]) - if transform is not None: - t = transform(t) self.data.append(t) self.slice_indexer = slice_indexer @@ -119,8 +117,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: else: # not in memory assert not self.in_memory x_3d = from_numpy(np.load(str(self.paths[index]))["arr_0"]) - if self.transform is not None: - x_3d = self.transform(x_3d) + if self.transform is not None: + x_3d = self.transform(x_3d) x_3d = x_3d.float() n, d, h, w = x_3d.size()