Skip to content

Commit

Permalink
fix ct
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 16, 2024
1 parent 87cfe49 commit 93b433e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 72 deletions.
37 changes: 19 additions & 18 deletions hrdae/conf/experiment/dataloader/basic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ train_val_ratio: 0.8
# - pad2d
# - normalize2d
# - crop2d
# moving mnist
transform_order_train:
- min_max_normalization
transform_order_val:
- min_max_normalization
# # ct
# # moving mnist
# transform_order_train:
# - min_max_normalization
# - random_shift3d
# - uniform_shape3d
# - pool3d
# transform_order_val:
# - min_max_normalization
# - uniform_shape3d
# - pool3d
# ct
transform_order_train:
- min_max_normalization
- random_shift3d
- uniform_shape3d
- pool3d
transform_order_val:
- min_max_normalization
- uniform_shape3d
- pool3d
defaults:
- /config/experiment/dataloader/basic@_here_
# # mnist
Expand All @@ -34,11 +34,12 @@ defaults:
# - [email protected]: pad2d
# - [email protected]: normalize2d
# - [email protected]: crop2d
# moving mnist
- dataset: moving_mnist
- [email protected]_max_normalization: min_max_normalization
# # ct
# - dataset: ct
# # moving mnist
# - dataset: moving_mnist
# - [email protected]_max_normalization: min_max_normalization
# - [email protected]_shape3d: uniform_shape3d
# - [email protected]: pool3d
# ct
- dataset: ct
- [email protected]_max_normalization: min_max_normalization
- [email protected]_shift3d: random_shift3d
- [email protected]_shape3d: uniform_shape3d
- [email protected]: pool3d
37 changes: 3 additions & 34 deletions hrdae/dataloaders/datasets/ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from tqdm import tqdm

from ..transforms import Transform
from .option import DatasetOption


@dataclass
class CTDatasetOption:
class CTDatasetOption(DatasetOption):
root: Path = MISSING
threshold: float = 0.1
min_occupancy: float = 0.2
Expand Down Expand Up @@ -99,6 +100,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
if self.transform is not None:
x_3d = self.transform(x_3d)

x_3d = x_3d.float()
n, d, h, w = x_3d.size()
assert n == self.PERIOD, f"expected {self.PERIOD} but got {n}"

Expand Down Expand Up @@ -147,36 +149,3 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
"slice_idx": slice_idx, # (s)
"idx_expanded": idx_expanded, # (n, s, d, h)
}


if __name__ == "__main__":

def test():
from torchvision import transforms

from ..transforms import (
MinMaxNormalizationOption,
Pool3dOption,
UniformShape3dOption,
create_transform,
)

option = CTDatasetOption(
root=Path("data"),
)
dataset = create_ct_dataset(
option,
transform=transforms.Compose(
[
create_transform(MinMaxNormalizationOption()),
create_transform(UniformShape3dOption()),
create_transform(Pool3dOption()),
]
),
is_train=True,
)
data = dataset[0]
for k, v in data.items():
print(k, v.shape)

test()
20 changes: 0 additions & 20 deletions hrdae/dataloaders/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,3 @@ def create_moving_mnist_dataset(
download=True,
transform=transform,
)


if __name__ == "__main__":

def test():
from torchvision import transforms

option = MovingMNISTDatasetOption(
root="data",
)
dataset = create_moving_mnist_dataset(
option,
transform=transforms.Compose([]),
is_train=True,
)
data = dataset[0]
for k, v in data.items():
print(k, v.shape)

test()

0 comments on commit 93b433e

Please sign in to comment.