Skip to content

Commit

Permalink
impl slice range instead of slice num
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 1, 2024
1 parent a06a6d4 commit e2c4d58
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion hrdae/conf
Submodule conf updated 1 files
+5 −4 experiment/model/vr.yaml
12 changes: 7 additions & 5 deletions hrdae/dataloaders/datasets/sliced_ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SlicedCTDatasetOption(DatasetOption):
motion_phase: str = "0"
motion_aggregation: str = "concat"
slice_axis: str = "y" # y or z
slice_num: int = MISSING
slice_range: list[int] = MISSING


def create_sliced_ct_dataset(
Expand All @@ -32,14 +32,14 @@ def slice_indexer(_: Tensor) -> Tensor:
return SlicedCT(
root=opt.root,
slice_indexer=slice_indexer,
slice_range=opt.slice_range,
transform=transform,
in_memory=opt.in_memory,
is_train=is_train,
content_phase=opt.content_phase,
motion_phase=opt.motion_phase,
motion_aggregation=opt.motion_aggregation,
slice_axis=opt.slice_axis,
slice_num=opt.slice_num,
)


Expand All @@ -51,14 +51,14 @@ def __init__(
self,
root: Path,
slice_indexer: Callable[[Tensor], Tensor],
slice_range: list[int],
transform: Transform | None = None,
in_memory: bool = True,
is_train: bool = True,
content_phase: str = "all",
motion_phase: str = "0",
motion_aggregation: str = "concat", # "concat" | "sum"
slice_axis: str = "y",
slice_num: int = 0,
) -> None:
super().__init__(
root=root,
Expand All @@ -71,7 +71,9 @@ def __init__(
motion_aggregation=motion_aggregation,
)
self.slice_axis = slice_axis
self.slice_num = slice_num
assert len(slice_range) == 2
self.slice_range = slice_range
self.slice_num = slice_range[1] - slice_range[0]

def __len__(self) -> int:
return len(self.paths) * self.slice_num
Expand All @@ -83,7 +85,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
assert "xp" in output # (n, _, d, h, w)
assert "xp_0" in output # (_, d, h, w)

slice_index = index % self.slice_num
slice_index = index % self.slice_num + self.slice_range[0]
if self.slice_axis == "y":
# slice by h
return {
Expand Down
2 changes: 1 addition & 1 deletion test/dataloaders/datasets/test_sliced_ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_SlicedCT():
motion_phase="0",
motion_aggregation="none",
slice_axis="z",
slice_num=10,
slice_range=[2, 8],
)
transform = transforms.Compose(
[
Expand Down

0 comments on commit e2c4d58

Please sign in to comment.