diff --git a/hrdae/__init__.py b/hrdae/__init__.py index 6cf43b3..5535101 100644 --- a/hrdae/__init__.py +++ b/hrdae/__init__.py @@ -5,6 +5,7 @@ CTDatasetOption, MNISTDatasetOption, MovingMNISTDatasetOption, + SlicedCTDatasetOption, ) from .dataloaders.transforms import ( Crop2dOption, @@ -79,6 +80,11 @@ name="ct", node=CTDatasetOption, ) +cs.store( + group="config/experiment/dataloader/dataset", + name="sliced_ct", + node=SlicedCTDatasetOption, +) cs.store( group="config/experiment/dataloader/transform", name="to_tensor", diff --git a/hrdae/conf b/hrdae/conf index 1606013..103ba58 160000 --- a/hrdae/conf +++ b/hrdae/conf @@ -1 +1 @@ -Subproject commit 16060133b06ee831b466a4480d94c8d2f290c0cd +Subproject commit 103ba5803b845099449e250300b2e3c4cbb9d923 diff --git a/hrdae/dataloaders/datasets/__init__.py b/hrdae/dataloaders/datasets/__init__.py index 2cee7a5..0ce8256 100644 --- a/hrdae/dataloaders/datasets/__init__.py +++ b/hrdae/dataloaders/datasets/__init__.py @@ -10,6 +10,7 @@ ) from .option import DatasetOption from .seq_divide_wrapper import SeqDivideWrapper +from .sliced_ct import SlicedCT, SlicedCTDatasetOption, create_sliced_ct_dataset def create_dataset( @@ -32,4 +33,9 @@ def create_dataset( if not opt.sequential: return SeqDivideWrapper(dataset, CT.PERIOD) return dataset + if isinstance(opt, SlicedCTDatasetOption) and type(opt) is SlicedCTDatasetOption: + dataset = create_sliced_ct_dataset(opt, transform, is_train) + if not opt.sequential: + return SeqDivideWrapper(dataset, SlicedCT.PERIOD) + return dataset raise NotImplementedError(f"dataset {opt.__class__.__name__} not implemented") diff --git a/hrdae/dataloaders/datasets/ct.py b/hrdae/dataloaders/datasets/ct.py index 022f369..2ef65a5 100644 --- a/hrdae/dataloaders/datasets/ct.py +++ b/hrdae/dataloaders/datasets/ct.py @@ -87,7 +87,7 @@ def __init__( super().__init__() self.paths = [] - data_root = root / self.__class__.__name__ + data_root = root / "CT" for i, path in enumerate(sorted(data_root.glob("**/*"))): if is_train and i % (1 + self.TRAIN_PER_TEST) != 0: self.paths.append(path) @@ -151,6 +151,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: # (d, h, s) -> (s, d, h) x_2d_0 = x_2d_0.permute(2, 0, 1) x_2d_t = x_2d_t.permute(2, 0, 1) + x_2d_rt = x_2d_rt.unsqueeze(0) # (2 * s, d, h) x_2d_all = cat([x_2d_0, x_2d_t], dim=0) # (n, d, h, w) -> (n, c, d, h, w) @@ -158,6 +159,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: # (d, h, w) -> (c, d, h, w) x_3d_0 = x_3d_0.unsqueeze(0) x_3d_t = x_3d_t.unsqueeze(0) + x_3d_rt = x_3d_rt.unsqueeze(0) # (2 * c, d, h, w) x_3d_all = cat([x_3d_0, x_3d_t], dim=0) diff --git a/hrdae/models/losses/mstd.py b/hrdae/models/losses/mstd.py index eaf9cdd..c0ea44d 100644 --- a/hrdae/models/losses/mstd.py +++ b/hrdae/models/losses/mstd.py @@ -28,7 +28,9 @@ def forward( cycled_latent: list[Tensor], ) -> Tensor: assert len(latent) == len(cycled_latent) - return sum([ # type: ignore - torch.sqrt(((v1 - v2) ** 2).mean()) - for v1, v2 in zip(latent, cycled_latent) - ]) + return sum( + [ # type: ignore + torch.sqrt(((v1 - v2) ** 2).mean()) + for v1, v2 in zip(latent, cycled_latent) + ] + ) diff --git a/test/dataloaders/datasets/test_sliced_ct.py b/test/dataloaders/datasets/test_sliced_ct.py index 1621692..3e18928 100644 --- a/test/dataloaders/datasets/test_sliced_ct.py +++ b/test/dataloaders/datasets/test_sliced_ct.py @@ -16,7 +16,7 @@ def test_SlicedCT(): with TemporaryDirectory() as root: - data_root = Path(root) / "SlicedCT" + data_root = Path(root) / "CT" data_root.mkdir(parents=True, exist_ok=True) np.savez( data_root / "sample1.npz",