Skip to content

Commit

Permalink
add rt option
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 25, 2024
1 parent 8313e41 commit 6384940
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hrdae/conf
6 changes: 6 additions & 0 deletions hrdae/dataloaders/datasets/ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
n, d, h, w = x_3d.size()
assert n == self.PERIOD, f"expected {self.PERIOD} but got {n}"

rt = random.randint(0, self.PERIOD - 1)

# (s,)
slice_idx = self.slice_indexer(x_3d)
# (n, d, h, s)
Expand All @@ -138,9 +140,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
# (d, h, s)
x_2d_0 = x_2d[0]
x_2d_t = x_2d[self.PERIOD // 2]
x_2d_rt = x_2d[rt]
# (d, h, w)
x_3d_0 = x_3d[0]
x_3d_t = x_3d[self.PERIOD // 2]
x_3d_rt = x_3d[rt]

# (n, d, h, s) -> (n, s, d, h)
x_2d = x_2d.permute(0, 3, 1, 2)
Expand All @@ -163,10 +167,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
x_2d,
x_2d_0,
x_2d_t,
x_2d_rt,
x_2d_all,
x_3d,
x_3d_0,
x_3d_t,
x_3d_rt,
x_3d_all,
self.content_phase,
self.motion_phase,
Expand Down
6 changes: 6 additions & 0 deletions hrdae/dataloaders/datasets/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ def optimize_output(
x_2d: Tensor,
x_2d_0: Tensor,
x_2d_t: Tensor,
x_2d_rt: Tensor,
x_2d_all: Tensor,
x_3d: Tensor,
x_3d_0: Tensor,
x_3d_t: Tensor,
x_3d_rt: Tensor,
x_3d_all: Tensor,
content_phase: str,
motion_phase: str,
Expand All @@ -21,6 +23,8 @@ def optimize_output(
xp_0 = x_3d_0
elif content_phase == "t":
xp_0 = x_3d_t
elif content_phase == "random":
xp_0 = x_3d_rt
else:
raise KeyError(f"unknown content phase {content_phase}")

Expand All @@ -33,6 +37,8 @@ def optimize_output(
pass
elif motion_phase == "t":
xm_0 = x_2d_t
elif motion_phase == "random":
xm_0 = x_2d_rt
elif motion_phase == "all":
xm_0 = x_2d_all
else:
Expand Down
8 changes: 8 additions & 0 deletions hrdae/dataloaders/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from random import randint

from torch import Tensor, cat, gather, int64, tensor
from torch.utils.data import Dataset
Expand Down Expand Up @@ -33,6 +34,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
x_2d = super().__getitem__(idx).squeeze(1)

n, h, w = x_2d.size()
rt = randint(0, self.PERIOD - 1)

# (s,)
slice_idx = tensor(self.slice_index, dtype=int64)
Expand All @@ -43,22 +45,26 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
# (h, s)
x_1d_0 = x_1d[0]
x_1d_t = x_1d[self.PERIOD // 2]
x_1d_rt = x_1d[rt]
# (h, w)
x_2d_0 = x_2d[0]
x_2d_t = x_2d[self.PERIOD // 2]
x_2d_rt = x_2d[rt]

# (n, h, s) -> (n, s, h)
x_1d = x_1d.permute(0, 2, 1)
# (h, s) -> (s, h)
x_1d_0 = x_1d_0.permute(1, 0)
x_1d_t = x_1d_t.permute(1, 0)
x_1d_rt = x_1d_rt.permute(1, 0)
# (2 * s, h)
x_1d_all = cat([x_1d_0, x_1d_t], dim=0)
# (n, h, w) -> (n, c, h, w)
x_2d = x_2d.unsqueeze(1)
# (h, w) -> (c, h, w)
x_2d_0 = x_2d_0.unsqueeze(0)
x_2d_t = x_2d_t.unsqueeze(0)
x_2d_rt = x_2d_rt.unsqueeze(0)
# (2 * c, h, w)
x_2d_all = cat([x_2d_0, x_2d_t], dim=0)
# (n, h, s) -> (n, s, h)
Expand All @@ -68,10 +74,12 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
x_1d,
x_1d_0,
x_1d_t,
x_1d_rt,
x_1d_all,
x_2d,
x_2d_0,
x_2d_t,
x_2d_rt,
x_2d_all,
self.content_phase,
self.motion_phase,
Expand Down

0 comments on commit 6384940

Please sign in to comment.