Skip to content

Commit

Permalink
update mmnist
Browse files Browse the repository at this point in the history
  • Loading branch information
nnaakkaaii committed Jun 22, 2024
1 parent fbac68c commit 8acec63
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
9 changes: 7 additions & 2 deletions hrdae/dataloaders/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field

from torch import Tensor, cat, gather, int64, tensor
from torch import Tensor, cat, gather, int64, tensor, device
from torch.cuda import is_available
from torch.utils.data import Dataset
from torchvision import datasets

Expand Down Expand Up @@ -30,7 +31,11 @@ def __init__(self, *args, **kwargs) -> None:

def __getitem__(self, idx: int) -> dict[str, Tensor]:
# (n, h, w)
x_2d = super().__getitem__(idx).squeeze(1)
x_2d = self.data[idx].to(device("cuda:0") if is_available() else device("cpu"))
if self.transform is not None:
x_2d = self.transform(x_2d)
x_2d = x_2d.squeeze(1)

n, h, w = x_2d.size()

# (s,)
Expand Down
7 changes: 2 additions & 5 deletions hrdae/dataloaders/transforms/normalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass

from torch import Tensor, device
from torch.cuda import is_available
from torch import Tensor

from .option import TransformOption

Expand All @@ -13,6 +12,4 @@ class MinMaxNormalizationOption(TransformOption):

class MinMaxNormalization:
def __call__(self, x: Tensor) -> Tensor:
x = x.to(device("cuda:0") if is_available() else device("cpu"))
x = (x - x.min()) / (x.max() - x.min()).cpu()
return x
return (x - x.min()) / (x.max() - x.min())
4 changes: 1 addition & 3 deletions hrdae/models/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
from .r_dae import RDAE2dOption, RDAE3dOption, create_rdae2d, create_rdae3d


def create_network(
out_channels: int, opt: NetworkOption
) -> nn.Module:
def create_network(out_channels: int, opt: NetworkOption) -> nn.Module:
if (
isinstance(opt, AutoEncoder2dNetworkOption)
and type(opt) is AutoEncoder2dNetworkOption
Expand Down

0 comments on commit 8acec63

Please sign in to comment.