3
3
import os
4
4
import os .path as osp
5
5
import shutil
6
- from typing import Dict , List , Tuple
6
+ from typing import Dict , List , Optional , Tuple
7
7
8
8
import ignite .distributed as idist
9
9
import numpy as np
12
12
from pytest_mock import MockerFixture
13
13
from torch .nn import Module
14
14
from torch .optim import Optimizer
15
- from torch .utils .data import DataLoader
16
15
17
- from tests .utils .conftest import check_exists
16
+ from tests .utils .conftest import DataLoader32 , DataLoader64 , check_exists
18
17
from utils .configs import GlobalConfig
19
- from utils .data .data_loader import DistributedDataLoader
20
18
from utils .gan .base_trainer import BaseTrainerGAN , BatchType
21
19
22
20
23
- class DataLoader64 (DistributedDataLoader ):
24
- """Data loader for unit tests (data size 64)."""
25
-
26
- def __init__ (self ) -> None :
27
- # pylint: disable=super-init-not-called
28
- self .n_classes = 4
29
-
30
- def loader (self ) -> DataLoader :
31
- """Return pytorch data loader."""
32
-
33
- class Dataset64 (torch .utils .data .Dataset ):
34
- """Dataset for unit tests (data size 32)."""
35
-
36
- def __getitem__ (self , index : int ) -> Tuple [torch .Tensor , int ]:
37
- return torch .randn (4 , 64 , 64 ), 0
38
-
39
- def __len__ (self ) -> int :
40
- return 10
41
-
42
- return torch .utils .data .DataLoader (dataset = Dataset64 (),
43
- batch_size = 2 ,
44
- shuffle = True ,
45
- num_workers = 0 ,
46
- )
47
-
48
-
49
- class DataLoader32 (DistributedDataLoader ):
50
- """Data loader for unit tests (data size 32)."""
51
-
52
- def __init__ (self ) -> None :
53
- # pylint: disable=super-init-not-called
54
- self .n_classes = 4
55
-
56
- def loader (self ) -> DataLoader :
57
- """Return pytorch data loader."""
58
-
59
- class Dataset32 (torch .utils .data .Dataset ):
60
- """Dataset for unit tests (data size 32)."""
61
-
62
- def __getitem__ (self , index : int ) -> Tuple [torch .Tensor , int ]:
63
- return torch .randn (4 , 32 , 32 ), 0
64
-
65
- def __len__ (self ) -> int :
66
- return 10
67
-
68
- return torch .utils .data .DataLoader (dataset = Dataset32 (),
69
- batch_size = 2 ,
70
- shuffle = True ,
71
- num_workers = 0 ,
72
- )
73
-
74
-
75
21
class TrainerTest (BaseTrainerGAN ):
76
22
"""Test class for BaseTrainerGAN."""
77
23
def train_generator (self , gen : Module , g_optimizer : Optimizer ,
@@ -91,7 +37,7 @@ def train_discriminator(self, disc: Module, d_optimizer: Optimizer,
91
37
92
38
def build_model_opt (self ) -> Tuple [Module , Module , Optimizer , Optimizer ]:
93
39
94
- class Generator (torch . nn . Module ):
40
+ class Generator (Module ):
95
41
"""Simple Generator"""
96
42
def __init__ (self ) -> None :
97
43
super ().__init__ ()
@@ -148,7 +94,6 @@ def build_trainers() -> Tuple[BaseTrainerGAN, BaseTrainerGAN]:
148
94
149
95
# Test TrainerSAGAN
150
96
151
-
152
97
def test_init () -> None :
153
98
"""Test init method."""
154
99
build_trainers ()
@@ -174,8 +119,29 @@ def test_train(mocker: MockerFixture) -> None:
174
119
np .save ('tests/datasets/data32.npy' , data32 )
175
120
np .save ('tests/datasets/data64.npy' , data64 )
176
121
122
+ # Mock train_discriminator and train_generator
123
+ def side_effect_disc (disc : Module , d_optimizer : Optimizer ,
124
+ * args : Optional [Tuple ], ** kwargs : Optional [Dict ]
125
+ ) -> Tuple [Module , Optimizer ,
126
+ Dict [str , torch .Tensor ]]:
127
+ # pylint: disable=unused-argument
128
+ losses = {'d_loss' : (torch .tensor (0.3 ), 'red' , 6 )}
129
+ return disc , d_optimizer , losses
130
+
131
+ def side_effect_gen (gen : Module , g_optimizer : Optimizer ,
132
+ * args : Optional [Tuple ], ** kwargs : Optional [Dict ]
133
+ ) -> Tuple [Module , Optimizer ,
134
+ Dict [str , torch .Tensor ]]:
135
+ # pylint: disable=unused-argument
136
+ losses = {'g_loss' : (torch .tensor (0.2 ), 'green' , 6 )}
137
+ return gen , g_optimizer , losses
138
+
177
139
trainers = build_trainers ()
178
140
for trainer in trainers :
141
+ trainer .train_discriminator = mocker .MagicMock ( # type: ignore
142
+ side_effect = side_effect_disc )
143
+ trainer .train_generator = mocker .MagicMock ( # type: ignore
144
+ side_effect = side_effect_gen )
179
145
trainer .train ()
180
146
check_exists ('res/tmp_test/models/generator_step_2.pth' )
181
147
check_exists ('res/tmp_test/models/discriminator_step_2.pth' )
0 commit comments