Skip to content

Commit

Permalink
update sliced ct
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jul 1, 2024
1 parent 626d94d commit a06a6d4
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 7 deletions.
6 changes: 6 additions & 0 deletions hrdae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
CTDatasetOption,
MNISTDatasetOption,
MovingMNISTDatasetOption,
SlicedCTDatasetOption,
)
from .dataloaders.transforms import (
Crop2dOption,
Expand Down Expand Up @@ -79,6 +80,11 @@
name="ct",
node=CTDatasetOption,
)
cs.store(
group="config/experiment/dataloader/dataset",
name="sliced_ct",
node=SlicedCTDatasetOption,
)
cs.store(
group="config/experiment/dataloader/transform",
name="to_tensor",
Expand Down
6 changes: 6 additions & 0 deletions hrdae/dataloaders/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from .option import DatasetOption
from .seq_divide_wrapper import SeqDivideWrapper
from .sliced_ct import SlicedCT, SlicedCTDatasetOption, create_sliced_ct_dataset


def create_dataset(
Expand All @@ -32,4 +33,9 @@ def create_dataset(
if not opt.sequential:
return SeqDivideWrapper(dataset, CT.PERIOD)
return dataset
if isinstance(opt, SlicedCTDatasetOption) and type(opt) is SlicedCTDatasetOption:
dataset = create_sliced_ct_dataset(opt, transform, is_train)
if not opt.sequential:
return SeqDivideWrapper(dataset, SlicedCT.PERIOD)
return dataset
raise NotImplementedError(f"dataset {opt.__class__.__name__} not implemented")
4 changes: 3 additions & 1 deletion hrdae/dataloaders/datasets/ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
super().__init__()

self.paths = []
data_root = root / self.__class__.__name__
data_root = root / "CT"
for i, path in enumerate(sorted(data_root.glob("**/*"))):
if is_train and i % (1 + self.TRAIN_PER_TEST) != 0:
self.paths.append(path)
Expand Down Expand Up @@ -151,13 +151,15 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
# (d, h, s) -> (s, d, h)
x_2d_0 = x_2d_0.permute(2, 0, 1)
x_2d_t = x_2d_t.permute(2, 0, 1)
x_2d_rt = x_2d_rt.unsqueeze(0)
# (2 * s, d, h)
x_2d_all = cat([x_2d_0, x_2d_t], dim=0)
# (n, d, h, w) -> (n, c, d, h, w)
x_3d = x_3d.unsqueeze(1)
# (d, h, w) -> (c, d, h, w)
x_3d_0 = x_3d_0.unsqueeze(0)
x_3d_t = x_3d_t.unsqueeze(0)
x_3d_rt = x_3d_rt.unsqueeze(0)
# (2 * c, d, h, w)
x_3d_all = cat([x_3d_0, x_3d_t], dim=0)

Expand Down
10 changes: 6 additions & 4 deletions hrdae/models/losses/mstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def forward(
cycled_latent: list[Tensor],
) -> Tensor:
assert len(latent) == len(cycled_latent)
return sum([ # type: ignore
torch.sqrt(((v1 - v2) ** 2).mean())
for v1, v2 in zip(latent, cycled_latent)
])
return sum(
[ # type: ignore
torch.sqrt(((v1 - v2) ** 2).mean())
for v1, v2 in zip(latent, cycled_latent)
]
)
2 changes: 1 addition & 1 deletion test/dataloaders/datasets/test_sliced_ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def test_SlicedCT():
with TemporaryDirectory() as root:
data_root = Path(root) / "SlicedCT"
data_root = Path(root) / "CT"
data_root.mkdir(parents=True, exist_ok=True)
np.savez(
data_root / "sample1.npz",
Expand Down

0 comments on commit a06a6d4

Please sign in to comment.