fix bugs in contrastive loss and mstd loss
nnaakkaaii committed Jul 1, 2024
1 parent cb52347 commit 741c9eb
Showing 21 changed files with 318 additions and 107 deletions.
2 changes: 1 addition & 1 deletion hrdae/__init__.py
@@ -154,7 +154,7 @@
)
cs.store(
group="config/experiment/model/loss",
name="perceptual2d",
name="perceptual",
node=Perceptual2dLossOption,
)
cs.store(
10 changes: 1 addition & 9 deletions hrdae/dataloaders/datasets/ct.py
@@ -160,10 +160,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
x_3d_t = x_3d_t.unsqueeze(0)
# (2 * c, d, h, w)
x_3d_all = cat([x_3d_0, x_3d_t], dim=0)
# (n, d, h, s) -> (n, s, d, h)
idx_expanded = idx_expanded.permute(0, 3, 1, 2)

output = optimize_output(
return optimize_output(
x_2d,
x_2d_0,
x_2d_t,
@@ -178,9 +176,3 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
self.motion_phase,
self.motion_aggregation,
)
output["x"] = x_3d
output["t"] = x_3d
output["slice_idx"] = slice_idx
output["idx_expanded"] = idx_expanded

return output
10 changes: 1 addition & 9 deletions hrdae/dataloaders/datasets/moving_mnist.py
@@ -67,10 +67,8 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
x_2d_rt = x_2d_rt.unsqueeze(0)
# (2 * c, h, w)
x_2d_all = cat([x_2d_0, x_2d_t], dim=0)
# (n, h, s) -> (n, s, h)
idx_expanded = idx_expanded.permute(0, 2, 1)

output = optimize_output(
return optimize_output(
x_1d,
x_1d_0,
x_1d_t,
@@ -85,12 +83,6 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
self.motion_phase,
self.motion_aggregator,
)
output["x"] = x_2d
output["t"] = x_2d
output["slice_idx"] = slice_idx
output["idx_expanded"] = idx_expanded

return output


def create_moving_mnist_dataset(
103 changes: 103 additions & 0 deletions hrdae/dataloaders/datasets/sliced_ct.py
@@ -0,0 +1,103 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

from omegaconf import MISSING
from torch import Tensor, int64, tensor
from torch.utils.data import Dataset

from ..transforms import Transform
from .ct import CT
from .option import DatasetOption


@dataclass
class SlicedCTDatasetOption(DatasetOption):
root: Path = MISSING
slice_index: list[int] = MISSING
in_memory: bool = False
content_phase: str = "all"
motion_phase: str = "0"
motion_aggregation: str = "concat"
slice_axis: str = "y" # y or z
slice_num: int = MISSING


def create_sliced_ct_dataset(
opt: SlicedCTDatasetOption, transform: Transform, is_train: bool
) -> Dataset:
def slice_indexer(_: Tensor) -> Tensor:
return tensor(opt.slice_index, dtype=int64)

return SlicedCT(
root=opt.root,
slice_indexer=slice_indexer,
transform=transform,
in_memory=opt.in_memory,
is_train=is_train,
content_phase=opt.content_phase,
motion_phase=opt.motion_phase,
motion_aggregation=opt.motion_aggregation,
slice_axis=opt.slice_axis,
slice_num=opt.slice_num,
)


class SlicedCT(CT):
TRAIN_PER_TEST = 4
PERIOD = 10

def __init__(
self,
root: Path,
slice_indexer: Callable[[Tensor], Tensor],
transform: Transform | None = None,
in_memory: bool = True,
is_train: bool = True,
content_phase: str = "all",
motion_phase: str = "0",
motion_aggregation: str = "concat", # "concat" | "sum"
slice_axis: str = "y",
slice_num: int = 0,
) -> None:
super().__init__(
root=root,
slice_indexer=slice_indexer,
transform=transform,
in_memory=in_memory,
is_train=is_train,
content_phase=content_phase,
motion_phase=motion_phase,
motion_aggregation=motion_aggregation,
)
self.slice_axis = slice_axis
self.slice_num = slice_num

def __len__(self) -> int:
return len(self.paths) * self.slice_num

def __getitem__(self, index: int) -> dict[str, Tensor]:
output = super().__getitem__(index // self.slice_num)
assert "xm" in output # (n, _, d, h)
assert "xm_0" in output # (_, d, h)
assert "xp" in output # (n, _, d, h, w)
assert "xp_0" in output # (_, d, h, w)

slice_index = index % self.slice_num
if self.slice_axis == "y":
# slice by h
return {
"xm": output["xm"][:, :, :, slice_index],
"xm_0": output["xm_0"][:, :, slice_index],
"xp": output["xp"][:, :, :, slice_index],
"xp_0": output["xp_0"][:, :, slice_index],
}
elif self.slice_axis == "z":
# slice by d
return {
"xm": output["xm"][:, :, slice_index],
"xm_0": output["xm_0"][:, slice_index],
"xp": output["xp"][:, :, slice_index],
"xp_0": output["xp_0"][:, slice_index],
}
raise KeyError(f"unknown slice axis {self.slice_axis}")
14 changes: 6 additions & 8 deletions hrdae/models/basic_model.py
@@ -73,13 +73,11 @@ def train(
running_loss = 0.0

for idx, data in enumerate(train_loader):
assert "x" in data and "t" in data, 'Data must have keys "x" and "t"'

if max_iter and max_iter <= idx:
break

x = data["x"].to(self.device)
t = data["t"].to(self.device)
x = data["xp"].to(self.device)
t = data["xp"].to(self.device)

b, n = x.size()[:2]

@@ -114,8 +112,8 @@
if max_iter and max_iter <= idx:
break

x = data["x"].to(self.device)
t = data["t"].to(self.device)
x = data["xp"].to(self.device)
t = data["xp"].to(self.device)

b, n = x.size()[:2]

@@ -160,8 +158,8 @@
if epoch % 10 == 0:
data = next(iter(val_loader))

x = data["x"].to(self.device)
t = data["t"].to(self.device)
x = data["xp"].to(self.device)
t = data["xp"].to(self.device)

b, n = x.size()[:2]

20 changes: 15 additions & 5 deletions hrdae/models/gan_model.py
@@ -103,10 +103,15 @@ def train(

# train generator
self.optimizer_g.zero_grad()
y, latent = self.generator(xm, xp_0, xm_0)
y, latent, cycled_latent = self.generator(xm, xp_0, xm_0)

y_pred = self.discriminator(y, xp)
loss_g_basic = self.criterion(y, xp, latent=latent)
loss_g_basic = self.criterion(
y,
xp,
latent=latent,
cycled_latent=cycled_latent,
)
loss_g_adv = self.criterion_g(y_pred, torch.ones_like(y_pred))

loss_g = loss_g_basic + 0.001 * loss_g_adv
@@ -172,12 +177,17 @@
xm_0 = data["xm_0"].to(self.device)
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)
y, latent = self.generator(xm, xp_0, xm_0)
y, cs, ds = self.generator(xm, xp_0, xm_0)

y_pred = self.discriminator(y, xp)
xp_pred = self.discriminator(xp, xp)
y = y.detach().clone()
loss_g_basic = self.criterion(y, xp, latent=latent)
loss_g_basic = self.criterion(
y,
xp,
latent=cs,
cycled_latent=ds,
)
loss_g_adv = self.criterion_g(y_pred, torch.ones_like(y_pred))
loss_g = loss_g_basic + loss_g_adv
loss_d_adv_real = self.criterion_d(
@@ -266,7 +276,7 @@ def train(
xp = data["xp"].to(self.device)
xp_0 = data["xp_0"].to(self.device)

y, _ = self.generator(xm, xp_0, xm_0)
y, _, _ = self.generator(xm, xp_0, xm_0)

save_reconstructed_images(
xp.data.cpu().clone().detach().numpy()[:10],
18 changes: 14 additions & 4 deletions hrdae/models/losses/mstd.py
@@ -18,7 +18,17 @@ def create_mstd_loss() -> nn.Module:
class MStdLoss(nn.Module):
@property
def required_kwargs(self) -> list[str]:
return ["latent"]

def forward(self, input: Tensor, target: Tensor, latent: list[Tensor]) -> Tensor:
return sum([torch.sqrt(torch.mean(v**2)) for v in latent]) # type: ignore
return ["latent", "cycled_latent"]

def forward(
self,
input: Tensor,
target: Tensor,
latent: list[Tensor],
cycled_latent: list[Tensor],
) -> Tensor:
assert len(latent) == len(cycled_latent)
return sum([ # type: ignore
torch.sqrt(((v1 - v2) ** 2).mean())
for v1, v2 in zip(latent, cycled_latent)
])
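A small numerical sketch of the corrected MStd term (example tensor shapes, not taken from the model): the loss is now the summed root-mean-square distance between each latent and its cycled counterpart, so it is zero exactly when the cycle reproduces every latent, instead of penalizing the raw magnitude of the latents themselves.

import torch

latent = [torch.randn(2, 4, 8, 8), torch.randn(2, 8, 4, 4)]
cycled_latent = [torch.randn(2, 4, 8, 8), torch.randn(2, 8, 4, 4)]

# summed RMS distance over all latent levels, mirroring MStdLoss.forward above
loss = sum(
    torch.sqrt(((v1 - v2) ** 2).mean())
    for v1, v2 in zip(latent, cycled_latent)
)
print(loss)  # non-negative scalar tensor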
29 changes: 15 additions & 14 deletions hrdae/models/networks/hr_dae.py
@@ -300,7 +300,7 @@ def forward(
x_1d: Tensor,
x_2d_0: Tensor,
x_1d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
) -> tuple[Tensor, list[Tensor], list[Tensor]]:
c, cs = self.content_encoder(x_2d_0)
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h, w = m.size()
@@ -318,7 +318,7 @@
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y, [c] + cs
return y, [c] + cs, []


class CycleHRDAE2d(HRDAE2d):
@@ -327,18 +327,18 @@ def forward(
x_1d: Tensor,
x_2d_0: Tensor,
x_1d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
y, cs = super().forward(x_1d, x_2d_0, x_1d_0)
) -> tuple[Tensor, list[Tensor], list[Tensor]]:
y, cs, _ = super().forward(x_1d, x_2d_0, x_1d_0)
b, t, c, h, w = y.size()
y_seq = y.reshape(b * t, c, h, w)
d, ds = self.content_encoder(y_seq)
assert d.size(0) == b * t
d = d.reshape(b, t, *d.size()[1:]) - cs[0].unsqueeze(1)
d = d.reshape(b, t, *d.size()[1:])
assert len(cs) == 1 + len(ds)
for i, di in enumerate(ds):
assert di.size(0) == b * t
ds[i] = di.reshape(b, t, *di.size()[1:]) - cs[1 + i].unsqueeze(1)
return y, [d] + ds
ds[i] = di.reshape(b, t, *di.size()[1:])
return y, [c.unsqueeze(1) for c in cs], [d] + ds


class HRDAE3d(nn.Module):
@@ -401,7 +401,7 @@ def forward(
x_2d: Tensor,
x_3d_0: Tensor,
x_2d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
) -> tuple[Tensor, list[Tensor], list[Tensor]]:
c, cs = self.content_encoder(x_3d_0)
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h, w = m.size()
@@ -419,7 +419,7 @@
y = y.reshape(b, t, c_, d, h, w)
if self.activation is not None:
y = self.activation(y)
return y, [c] + cs
return y, [c] + cs, []


class CycleHRDAE3d(HRDAE3d):
@@ -428,15 +428,16 @@ def forward(
x_2d: Tensor,
x_3d_0: Tensor,
x_2d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
y, cs = super().forward(x_2d, x_3d_0, x_2d_0)
) -> tuple[Tensor, list[Tensor], list[Tensor]]:
y, cs, _ = super().forward(x_2d, x_3d_0, x_2d_0)
b, t, c, d_, h, w = y.size()
y_seq = y.reshape(b * t, c, d_, h, w)
d, ds = self.content_encoder(y_seq)

assert d.size(0) == b * t
d = d.reshape(b, t, *d.size()[1:]) - cs[0].unsqueeze(1)
d = d.reshape(b, t, *d.size()[1:])
assert len(cs) == 1 + len(ds)
for i, di in enumerate(ds):
assert di.size(0) == b * t
ds[i] = di.reshape(b, t, *di.size()[1:]) - cs[1 + i].unsqueeze(1)
return y, [d] + ds
ds[i] = di.reshape(b, t, *di.size()[1:])
return y, [c.unsqueeze(1) for c in cs], [d] + ds
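With this change the cycle models no longer subtract the content latents inside forward; they return the original latents (each given a singleton time axis) and the re-encoded latents as two separate lists, and the subtraction happens in MStdLoss. A shape-only sketch of how the two lists line up (hypothetical sizes, not taken from the networks):

import torch

b, t = 2, 5
cs = [torch.randn(b, 1, 16, 8, 8), torch.randn(b, 1, 8, 16, 16)]  # content latents, unsqueezed over time
ds = [torch.randn(b, t, 16, 8, 8), torch.randn(b, t, 8, 16, 16)]  # re-encoded latents for all t frames

# the singleton time axis of each cs entry broadcasts against ds in MStdLoss
loss = sum(torch.sqrt(((c - d) ** 2).mean()) for c, d in zip(cs, ds))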
8 changes: 4 additions & 4 deletions hrdae/models/networks/r_ae.py
@@ -107,7 +107,7 @@ def forward(
x_1d: Tensor,
x_2d_0: Tensor | None = None,
x_1d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
) -> tuple[Tensor, list[Tensor], list[Tensor]]:
m = self.motion_encoder(x_1d, x_1d_0)
b, t, c_, h, w = m.size()
m = m.reshape(b * t, c_, h, w)
@@ -117,7 +117,7 @@
y = y.reshape(b, t, c_, h, w)
if self.activation is not None:
y = self.activation(y)
return y, []
return y, [], []


class RAE3d(nn.Module):
@@ -149,7 +149,7 @@ def forward(
x_2d: Tensor,
x_3d_0: Tensor | None = None,
x_2d_0: Tensor | None = None,
) -> tuple[Tensor, list[Tensor]]:
) -> tuple[Tensor, list[Tensor], list[Tensor]]:
m = self.motion_encoder(x_2d, x_2d_0)
b, t, c_, d, h, w = m.size()
m = m.reshape(b * t, c_, d, h, w)
Expand All @@ -161,4 +161,4 @@ def forward(
y = y.reshape(b, t, c_, d, h, w)
if self.activation is not None:
y = self.activation(y)
return y, []
return y, [], []
