diff --git a/hrdae/conf b/hrdae/conf index 5419535..0e066ce 160000 --- a/hrdae/conf +++ b/hrdae/conf @@ -1 +1 @@ -Subproject commit 54195352a394541baa92a922e68b33c2f7129b56 +Subproject commit 0e066ce5b7bb03b9fd351c55bad58a5249d8a7e9 diff --git a/hrdae/dataloaders/datasets/ct.py b/hrdae/dataloaders/datasets/ct.py index 275d284..c842ed4 100644 --- a/hrdae/dataloaders/datasets/ct.py +++ b/hrdae/dataloaders/datasets/ct.py @@ -126,6 +126,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: n, d, h, w = x_3d.size() assert n == self.PERIOD, f"expected {self.PERIOD} but got {n}" + rt = random.randint(0, self.PERIOD - 1) + # (s,) slice_idx = self.slice_indexer(x_3d) # (n, d, h, s) @@ -138,9 +140,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: # (d, h, s) x_2d_0 = x_2d[0] x_2d_t = x_2d[self.PERIOD // 2] + x_2d_rt = x_2d[rt] # (d, h, w) x_3d_0 = x_3d[0] x_3d_t = x_3d[self.PERIOD // 2] + x_3d_rt = x_3d[rt] # (n, d, h, s) -> (n, s, d, h) x_2d = x_2d.permute(0, 3, 1, 2) @@ -163,10 +167,12 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: x_2d, x_2d_0, x_2d_t, + x_2d_rt, x_2d_all, x_3d, x_3d_0, x_3d_t, + x_3d_rt, x_3d_all, self.content_phase, self.motion_phase, diff --git a/hrdae/dataloaders/datasets/functions.py b/hrdae/dataloaders/datasets/functions.py index 256c044..f6b09d0 100644 --- a/hrdae/dataloaders/datasets/functions.py +++ b/hrdae/dataloaders/datasets/functions.py @@ -5,10 +5,12 @@ def optimize_output( x_2d: Tensor, x_2d_0: Tensor, x_2d_t: Tensor, + x_2d_rt: Tensor, x_2d_all: Tensor, x_3d: Tensor, x_3d_0: Tensor, x_3d_t: Tensor, + x_3d_rt: Tensor, x_3d_all: Tensor, content_phase: str, motion_phase: str, @@ -21,6 +23,8 @@ def optimize_output( xp_0 = x_3d_0 elif content_phase == "t": xp_0 = x_3d_t + elif content_phase == "random": + xp_0 = x_3d_rt else: raise KeyError(f"unknown content phase {content_phase}") @@ -33,6 +37,8 @@ def optimize_output( pass elif motion_phase == "t": xm_0 = x_2d_t + elif motion_phase == "random": + xm_0 = x_2d_rt elif motion_phase == "all": xm_0 = x_2d_all else: diff --git a/hrdae/dataloaders/datasets/moving_mnist.py b/hrdae/dataloaders/datasets/moving_mnist.py index 7c1effd..9eda5d5 100644 --- a/hrdae/dataloaders/datasets/moving_mnist.py +++ b/hrdae/dataloaders/datasets/moving_mnist.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from random import randint from torch import Tensor, cat, gather, int64, tensor from torch.utils.data import Dataset @@ -33,6 +34,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: x_2d = super().__getitem__(idx).squeeze(1) n, h, w = x_2d.size() + rt = randint(0, self.PERIOD - 1) # (s,) slice_idx = tensor(self.slice_index, dtype=int64) @@ -43,15 +45,18 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: # (h, s) x_1d_0 = x_1d[0] x_1d_t = x_1d[self.PERIOD // 2] + x_1d_rt = x_1d[rt] # (h, w) x_2d_0 = x_2d[0] x_2d_t = x_2d[self.PERIOD // 2] + x_2d_rt = x_2d[rt] # (n, h, s) -> (n, s, h) x_1d = x_1d.permute(0, 2, 1) # (h, s) -> (s, h) x_1d_0 = x_1d_0.permute(1, 0) x_1d_t = x_1d_t.permute(1, 0) + x_1d_rt = x_1d_rt.permute(1, 0) # (2 * s, h) x_1d_all = cat([x_1d_0, x_1d_t], dim=0) # (n, h, w) -> (n, c, h, w) @@ -59,6 +64,7 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: # (h, w) -> (c, h, w) x_2d_0 = x_2d_0.unsqueeze(0) x_2d_t = x_2d_t.unsqueeze(0) + x_2d_rt = x_2d_rt.unsqueeze(0) # (2 * c, h, w) x_2d_all = cat([x_2d_0, x_2d_t], dim=0) # (n, h, s) -> (n, s, h) @@ -68,10 +74,12 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: x_1d, x_1d_0, x_1d_t, + x_1d_rt, x_1d_all, x_2d, x_2d_0, x_2d_t, + x_2d_rt, x_2d_all, self.content_phase, self.motion_phase,