From e2c4d587eb7da1259b252bc3721fdcaf199b1ad7 Mon Sep 17 00:00:00 2001 From: nnaakkaaii Date: Tue, 2 Jul 2024 08:47:08 +0900 Subject: [PATCH] impl slice range instead of slice num --- hrdae/conf | 2 +- hrdae/dataloaders/datasets/sliced_ct.py | 12 +++++++----- test/dataloaders/datasets/test_sliced_ct.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/hrdae/conf b/hrdae/conf index 103ba58..a4b831f 160000 --- a/hrdae/conf +++ b/hrdae/conf @@ -1 +1 @@ -Subproject commit 103ba5803b845099449e250300b2e3c4cbb9d923 +Subproject commit a4b831f33e62bfe99707e4147d388ead5c2d3bc6 diff --git a/hrdae/dataloaders/datasets/sliced_ct.py b/hrdae/dataloaders/datasets/sliced_ct.py index 7c69c2c..3cb14b1 100644 --- a/hrdae/dataloaders/datasets/sliced_ct.py +++ b/hrdae/dataloaders/datasets/sliced_ct.py @@ -20,7 +20,7 @@ class SlicedCTDatasetOption(DatasetOption): motion_phase: str = "0" motion_aggregation: str = "concat" slice_axis: str = "y" # y or z - slice_num: int = MISSING + slice_range: list[int] = MISSING def create_sliced_ct_dataset( @@ -32,6 +32,7 @@ def slice_indexer(_: Tensor) -> Tensor: return SlicedCT( root=opt.root, slice_indexer=slice_indexer, + slice_range=opt.slice_range, transform=transform, in_memory=opt.in_memory, is_train=is_train, @@ -39,7 +40,6 @@ def slice_indexer(_: Tensor) -> Tensor: motion_phase=opt.motion_phase, motion_aggregation=opt.motion_aggregation, slice_axis=opt.slice_axis, - slice_num=opt.slice_num, ) @@ -51,6 +51,7 @@ def __init__( self, root: Path, slice_indexer: Callable[[Tensor], Tensor], + slice_range: list[int], transform: Transform | None = None, in_memory: bool = True, is_train: bool = True, @@ -58,7 +59,6 @@ def __init__( motion_phase: str = "0", motion_aggregation: str = "concat", # "concat" | "sum" slice_axis: str = "y", - slice_num: int = 0, ) -> None: super().__init__( root=root, @@ -71,7 +71,9 @@ def __init__( motion_aggregation=motion_aggregation, ) self.slice_axis = slice_axis - self.slice_num = slice_num + assert len(slice_range) == 2 + self.slice_range = slice_range + self.slice_num = slice_range[1] - slice_range[0] def __len__(self) -> int: return len(self.paths) * self.slice_num @@ -83,7 +85,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: assert "xp" in output # (n, _, d, h, w) assert "xp_0" in output # (_, d, h, w) - slice_index = index % self.slice_num + slice_index = index % self.slice_num + self.slice_range[0] if self.slice_axis == "y": # slice by h return { diff --git a/test/dataloaders/datasets/test_sliced_ct.py b/test/dataloaders/datasets/test_sliced_ct.py index 3e18928..a08ca99 100644 --- a/test/dataloaders/datasets/test_sliced_ct.py +++ b/test/dataloaders/datasets/test_sliced_ct.py @@ -32,7 +32,7 @@ def test_SlicedCT(): motion_phase="0", motion_aggregation="none", slice_axis="z", - slice_num=10, + slice_range=[2, 8], ) transform = transforms.Compose( [