
Commit 4344d18

Merge pull request #48 from valentingol/conditional
🆙 Update to 2.0.2
2 parents 734db96 + 16566db


18 files changed: +284 additions, -155 deletions


.pylintrc

Lines changed: 2 additions & 1 deletion
@@ -148,7 +148,8 @@ confidence=HIGH,
 disable=no-member,
         no-name-in-module,
         not-callable,
-        redefined-outer-name
+        redefined-outer-name,
+        duplicate-code,
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

configs/default/models/cond_sagan.yaml

Lines changed: 6 additions & 1 deletion
@@ -6,11 +6,16 @@ data_size: 64 # could be either 32, 64, 128 or 256
 # NOTE: the number of blocks depends on the data size only:
 # data_size 32: 4 blocks, data_size 64: 5 blocks, data_size 128: 6 blocks, data_size 256: 7 blocks
 attn_layer_num: [3, 4]
-full_values: False
 z_dim: 128
 g_conv_dim: 64 # number of channels before the last layer of the generator
 d_conv_dim: 64 # number of channels after the first layer of the discriminator
 
 # cond_dim_ratio: define the number of channels of conditional feature maps
 # cond_dim = base_dim // cond_dim_ratio
 cond_dim_ratio: 8
+
+attention: !attention
+  n_heads: 1 # number of attention heads
+  out_layer: True # whether to apply a linear layer to the output of the attention
+  qk_ratio: 8 # dimension for queries and keys is input_dim // qk_ratio
+  v_ratio: 2 # dimension for values is input_dim // v_ratio, should be 1 if out_layer=False
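
The comments on the ratio options above describe plain integer-division arithmetic. The sketch below only restates that arithmetic with the default values; the helper names (derive_attention_dims, derive_cond_dim) are hypothetical and not part of this commit.

# Illustrative only: reproduces the integer-division arithmetic from the
# config comments above. Function names are hypothetical, not from this repo.
def derive_attention_dims(in_dim: int, qk_ratio: int, v_ratio: int) -> tuple:
    qk_dim = in_dim // qk_ratio  # dimension of queries and keys
    v_dim = in_dim // v_ratio    # dimension of values
    return qk_dim, v_dim

def derive_cond_dim(base_dim: int, cond_dim_ratio: int) -> int:
    return base_dim // cond_dim_ratio  # channels of the conditional feature maps

# With the defaults above and an assumed in_dim/base_dim of 64:
# qk_dim = 8, v_dim = 32, cond_dim = 8.
print(derive_attention_dims(64, qk_ratio=8, v_ratio=2), derive_cond_dim(64, 8))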

configs/default/models/sagan.yaml

Lines changed: 6 additions & 1 deletion
@@ -6,8 +6,13 @@ data_size: 64 # could be either 32, 64, 128 or 256
 # NOTE: the number of blocks depends on the data size only:
 # data_size 32: 4 blocks, data_size 64: 5 blocks, data_size 128: 6 blocks, data_size 256: 7 blocks
 attn_layer_num: [3, 4]
-full_values: False
 z_dim: 128
 g_conv_dim: 64 # number of channels before the last layer of the generator
 d_conv_dim: 64 # number of channels after the first layer of the discriminator
 cond_dim_ratio: -1 # not used here (only for conditional models)
+
+attention: !attention
+  n_heads: 1 # number of attention heads
+  out_layer: True # whether to apply a linear layer to the output of the attention
+  qk_ratio: 8 # dimension for queries and keys is input_dim // qk_ratio
+  v_ratio: 2 # dimension for values is input_dim // v_ratio, should be 1 if out_layer=False

configs/unittest/data32.yaml

Lines changed: 4 additions & 1 deletion
@@ -15,8 +15,11 @@ model.attn_layer_num: [1, 2, 3]
 model.data_size: 32
 model.d_conv_dim: 8
 model.g_conv_dim: 8
-model.full_values: True
 model.init_method: orthogonal
+model.attention.n_heads: 1
+model.attention.out_layer: True
+model.attention.qk_ratio: 8
+model.attention.v_ratio: 2
 
 training.adv_loss: wgan-gp
 training.ema_start_step: 0

configs/unittest/data64.yaml

Lines changed: 4 additions & 1 deletion
@@ -13,8 +13,11 @@ model.attn_layer_num: [4]
 model.data_size: 64
 model.d_conv_dim: 12
 model.g_conv_dim: 12
-model.full_values: False
 model.init_method: normal
+model.attention.n_heads: 4
+model.attention.out_layer: False
+model.attention.qk_ratio: 1
+model.attention.v_ratio: 1
 
 training.adv_loss: wgan-gp
 training.ema_start_step: 0
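
The unit-test configs above override the nested defaults with flat, dot-separated keys such as model.attention.n_heads. The project's GlobalConfig performs this merging itself; the sketch below only illustrates the dotted-key convention on a plain nested dict, and apply_dotted_overrides is a hypothetical helper, not code from this repository.

# Illustration of dot-separated override keys on a nested dict (not the
# project's GlobalConfig implementation, just the naming convention).
from typing import Any, Dict

def apply_dotted_overrides(config: Dict[str, Any],
                           overrides: Dict[str, Any]) -> Dict[str, Any]:
    for dotted_key, value in overrides.items():
        node = config
        *parents, leaf = dotted_key.split('.')
        for key in parents:
            node = node.setdefault(key, {})  # descend into (or create) sub-dicts
        node[leaf] = value
    return config

defaults = {'model': {'attention': {'n_heads': 1, 'qk_ratio': 8}}}
apply_dotted_overrides(defaults, {'model.attention.n_heads': 4,
                                  'model.attention.qk_ratio': 1})
print(defaults)  # {'model': {'attention': {'n_heads': 4, 'qk_ratio': 1}}}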

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 # Installation
 config = {
     'name': 'sagan-facies-modeling',
-    'version': '2.0.1',
+    'version': '2.0.2',
     'description': 'Facies modeling with SAGAN.',
     'author': 'Valentin Goldite',
     'author_email': '[email protected]',

tests/utils/conftest.py

Lines changed: 55 additions & 0 deletions
@@ -4,9 +4,64 @@
 
 import numpy as np
 import pytest
+import torch
 from pytest_check import check_func
+from torch.utils.data import DataLoader
 
 from utils.configs import GlobalConfig
+from utils.data.data_loader import DistributedDataLoader
+
+
+class DataLoader64(DistributedDataLoader):
+    """Data loader for unit tests (data size 64)."""
+
+    def __init__(self) -> None:
+        # pylint: disable=super-init-not-called
+        self.n_classes = 4
+
+    def loader(self) -> DataLoader:
+        """Return pytorch data loader."""
+
+        class Dataset64(torch.utils.data.Dataset):
+            """Dataset for unit tests (data size 64)."""
+
+            def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
+                return torch.randn(4, 64, 64), 0
+
+            def __len__(self) -> int:
+                return 10
+
+        return torch.utils.data.DataLoader(dataset=Dataset64(),
+                                            batch_size=2,
+                                            shuffle=True,
+                                            num_workers=0,
+                                            )
+
+
+class DataLoader32(DistributedDataLoader):
+    """Data loader for unit tests (data size 32)."""
+
+    def __init__(self) -> None:
+        # pylint: disable=super-init-not-called
+        self.n_classes = 4
+
+    def loader(self) -> DataLoader:
+        """Return pytorch data loader."""
+
+        class Dataset32(torch.utils.data.Dataset):
+            """Dataset for unit tests (data size 32)."""
+
+            def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
+                return torch.randn(4, 32, 32), 0
+
+            def __len__(self) -> int:
+                return 10
+
+        return torch.utils.data.DataLoader(dataset=Dataset32(),
+                                            batch_size=2,
+                                            shuffle=True,
+                                            num_workers=0,
+                                            )
 
 
 @check_func
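
A quick sanity check of the shared loaders added above: a test can instantiate one of them and draw a batch. This is a minimal usage sketch only (the variable names are illustrative); it assumes DataLoader64 is imported from tests.utils.conftest as in the test file below.

# Illustrative usage of the test loaders defined above.
loader = DataLoader64().loader()
images, labels = next(iter(loader))
assert images.shape == (2, 4, 64, 64)  # batch_size=2, 4 channels, 64x64 maps
assert int(labels[0]) == 0             # dummy class label returned by the dataset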

tests/utils/gan/cond_sagan/test_cond_modules.py

Lines changed: 3 additions & 3 deletions
@@ -18,9 +18,9 @@ def test_sa_generator(configs: Tuple[GlobalConfig, GlobalConfig]) -> None:
     data, att_list = gen(z, pixel_maps, with_attn=True)
     check.equal(data.shape, (5, 4, 32, 32))
     check.equal(len(att_list), 3)
-    check.equal(att_list[0].shape, (5, 16, 16))
-    check.equal(att_list[1].shape, (5, 64, 64))
-    check.equal(att_list[2].shape, (5, 256, 256))
+    check.equal(att_list[0].shape, (5, 1, 16, 16))
+    check.equal(att_list[1].shape, (5, 1, 64, 64))
+    check.equal(att_list[2].shape, (5, 1, 256, 256))
     data = gen(z, pixel_maps, with_attn=False)
     check.is_instance(data, torch.Tensor)
 

tests/utils/gan/test_attention.py

Lines changed: 21 additions & 5 deletions
@@ -3,16 +3,32 @@
 import pytest_check as check
 import torch
 
+from utils.configs import ConfigType
 from utils.gan.attention import SelfAttention
 
 
-def test_self_attention() -> None:
+def test_self_attention(configs: ConfigType) -> None:
     """Test SelfAttention."""
-    attention = SelfAttention(in_dim=16, att_dim=8, full_values=False)
+    config_32, config_64 = configs
+
+    # Case with out_layer = True
+    attention = SelfAttention(in_dim=16,
+                              attention_config=config_32.model.attention)
     out, attention = attention(torch.rand(2, 16, 9, 9))
     check.equal(out.shape, (2, 16, 9, 9))
-    check.equal(attention.shape, (2, 81, 81))
-    attention = SelfAttention(in_dim=16, att_dim=None, full_values=True)
+    check.equal(attention.shape, (2, 1, 81, 81))
+
+    # Case with out_layer = False
+    attention = SelfAttention(in_dim=16,
+                              attention_config=config_64.model.attention)
     out, attention = attention(torch.rand(2, 16, 9, 9))
     check.equal(out.shape, (2, 16, 9, 9))
-    check.equal(attention.shape, (2, 81, 81))
+    check.equal(attention.shape, (2, 4, 81, 81))
+
+    # Case with out_layer = False and v_ratio != 1
+    config_64_bis = config_64.copy()
+    config_64_bis.merge({"model": {"attention": {"v_ratio": 2}}},
+                        do_not_pre_process=True)
+    with check.raises(ValueError):
+        SelfAttention(in_dim=16,
+                      attention_config=config_64_bis.model.attention)
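
The last case above expects a ValueError: per the config comment, v_ratio should be 1 when out_layer=False, presumably because without the output linear layer the values are returned directly and must already match the input channel count. Below is a hedged sketch of such a guard, inferred from the test and the comment rather than taken from the repository's SelfAttention implementation.

# Hypothetical consistency check mirroring the ValueError case tested above;
# not the repository's actual SelfAttention code.
def check_attention_config(out_layer: bool, v_ratio: int) -> None:
    if not out_layer and v_ratio != 1:
        raise ValueError("v_ratio must be 1 when out_layer=False: without the "
                         "output layer, the values are not projected back to "
                         "in_dim channels.")

check_attention_config(out_layer=True, v_ratio=2)   # fine
check_attention_config(out_layer=False, v_ratio=1)  # fine
# check_attention_config(out_layer=False, v_ratio=2)  # would raise ValueError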

tests/utils/gan/test_base_trainer.py

Lines changed: 24 additions & 58 deletions
@@ -3,7 +3,7 @@
 import os
 import os.path as osp
 import shutil
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import ignite.distributed as idist
 import numpy as np
@@ -12,66 +12,12 @@
 from pytest_mock import MockerFixture
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
 
-from tests.utils.conftest import check_exists
+from tests.utils.conftest import DataLoader32, DataLoader64, check_exists
 from utils.configs import GlobalConfig
-from utils.data.data_loader import DistributedDataLoader
 from utils.gan.base_trainer import BaseTrainerGAN, BatchType
 
 
-class DataLoader64(DistributedDataLoader):
-    """Data loader for unit tests (data size 64)."""
-
-    def __init__(self) -> None:
-        # pylint: disable=super-init-not-called
-        self.n_classes = 4
-
-    def loader(self) -> DataLoader:
-        """Return pytorch data loader."""
-
-        class Dataset64(torch.utils.data.Dataset):
-            """Dataset for unit tests (data size 32)."""
-
-            def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
-                return torch.randn(4, 64, 64), 0
-
-            def __len__(self) -> int:
-                return 10
-
-        return torch.utils.data.DataLoader(dataset=Dataset64(),
-                                            batch_size=2,
-                                            shuffle=True,
-                                            num_workers=0,
-                                            )
-
-
-class DataLoader32(DistributedDataLoader):
-    """Data loader for unit tests (data size 32)."""
-
-    def __init__(self) -> None:
-        # pylint: disable=super-init-not-called
-        self.n_classes = 4
-
-    def loader(self) -> DataLoader:
-        """Return pytorch data loader."""
-
-        class Dataset32(torch.utils.data.Dataset):
-            """Dataset for unit tests (data size 32)."""
-
-            def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
-                return torch.randn(4, 32, 32), 0
-
-            def __len__(self) -> int:
-                return 10
-
-        return torch.utils.data.DataLoader(dataset=Dataset32(),
-                                            batch_size=2,
-                                            shuffle=True,
-                                            num_workers=0,
-                                            )
-
-
 class TrainerTest(BaseTrainerGAN):
     """Test class for BaseTrainerGAN."""
     def train_generator(self, gen: Module, g_optimizer: Optimizer,
@@ -91,7 +37,7 @@ def train_discriminator(self, disc: Module, d_optimizer: Optimizer,
 
     def build_model_opt(self) -> Tuple[Module, Module, Optimizer, Optimizer]:
 
-        class Generator(torch.nn.Module):
+        class Generator(Module):
            """Simple Generator"""
            def __init__(self) -> None:
                super().__init__()
@@ -148,7 +94,6 @@ def build_trainers() -> Tuple[BaseTrainerGAN, BaseTrainerGAN]:
 
 # Test TrainerSAGAN
 
-
 def test_init() -> None:
     """Test init method."""
     build_trainers()
@@ -174,8 +119,29 @@ def test_train(mocker: MockerFixture) -> None:
     np.save('tests/datasets/data32.npy', data32)
     np.save('tests/datasets/data64.npy', data64)
 
+    # Mock train_discriminator and train_generator
+    def side_effect_disc(disc: Module, d_optimizer: Optimizer,
+                         *args: Optional[Tuple], **kwargs: Optional[Dict]
+                         ) -> Tuple[Module, Optimizer,
+                                    Dict[str, torch.Tensor]]:
+        # pylint: disable=unused-argument
+        losses = {'d_loss': (torch.tensor(0.3), 'red', 6)}
+        return disc, d_optimizer, losses
+
+    def side_effect_gen(gen: Module, g_optimizer: Optimizer,
+                        *args: Optional[Tuple], **kwargs: Optional[Dict]
+                        ) -> Tuple[Module, Optimizer,
+                                   Dict[str, torch.Tensor]]:
+        # pylint: disable=unused-argument
+        losses = {'g_loss': (torch.tensor(0.2), 'green', 6)}
+        return gen, g_optimizer, losses
+
     trainers = build_trainers()
     for trainer in trainers:
+        trainer.train_discriminator = mocker.MagicMock(  # type: ignore
+            side_effect=side_effect_disc)
+        trainer.train_generator = mocker.MagicMock(  # type: ignore
+            side_effect=side_effect_gen)
         trainer.train()
     check_exists('res/tmp_test/models/generator_step_2.pth')
     check_exists('res/tmp_test/models/discriminator_step_2.pth')
