Commit 2e4a638
Author: tibuch

Add 8-fold Test-Time Augmentation to predict, by default off.

1 parent 68fbe13

File tree: 3 files changed, 108 additions and 49 deletions

n2v/models/n2v_standard.py

Lines changed: 18 additions & 9 deletions
@@ -22,7 +22,8 @@
 from ..internals.N2V_DataWrapper import N2V_DataWrapper
 from ..internals.n2v_losses import loss_mse, loss_mae
 from ..utils import n2v_utils
-from ..utils.n2v_utils import pm_identity, pm_normal_additive, pm_normal_fitted, pm_normal_withoutCP, pm_uniform_withCP
+from ..utils.n2v_utils import pm_identity, pm_normal_additive, pm_normal_fitted, pm_normal_withoutCP, pm_uniform_withCP, \
+    tta_forward, tta_backward
 from ..nets.unet import build_single_unet_per_channel
 
 from tifffile import imsave
@@ -77,7 +78,7 @@ def __init__(self, config, name=None, basedir='.'):
 
         config is None or isinstance(config, self._config_class) or _raise(
             ValueError("Invalid configuration of type '%s', was expecting type '%s'." % (
-                type(config).__name__, self._config_class.__name__))
+            type(config).__name__, self._config_class.__name__))
         )
         if config is not None and not config.is_valid():
             invalid_attr = config.is_valid(True)[1]
@@ -104,6 +105,7 @@ def __init__(self, config, name=None, basedir='.'):
         if config is None:
             self._find_and_load_weights()
 
+
     def _build(self):
         return self._build_unet(
             n_dim=self.config.n_dim,
@@ -301,7 +303,6 @@ def prepare_for_training(self, optimizer=None, **kwargs):
             self.callbacks.append(
                 TensorBoard(log_dir=str(self.logdir / 'logs'), write_graph=False, profile_batch=0))
 
-
         if self.config.train_reduce_lr is not None:
             from keras.callbacks import ReduceLROnPlateau
             rlrop_params = self.config.train_reduce_lr
@@ -336,7 +337,7 @@ def __normalize__(self, data, means, stds):
     def __denormalize__(self, data, means, stds):
         return (data * stds) + means
 
-    def predict(self, img, axes, resizer=PadAndCropResizer(), n_tiles=None):
+    def predict(self, img, axes, resizer=PadAndCropResizer(), n_tiles=None, tta=False):
         """
         Apply the network to sofar unseen data. This method expects the raw data, i.e. not normalized.
         During prediction the mean and standard deviation, stored with the model (during data generation), are used
@@ -351,6 +352,8 @@ def predict(self, img, axes, resizer=PadAndCropResizer(), n_tiles=None):
         resizer : class(Resizer), optional(default=PadAndCropResizer())
         n_tiles : tuple(int)
             Number of tiles to tile the image into, if it is too large for memory.
+        tta : bool
+            Use test-time augmentation during prediction.
 
         Returns
         -------
@@ -375,9 +378,17 @@ def predict(self, img, axes, resizer=PadAndCropResizer(), n_tiles=None):
         normalized = self.__normalize__(img[..., np.newaxis], means, stds)
         normalized = normalized[..., 0]
 
-        pred = \
-        self._predict_mean_and_scale(normalized, axes=new_axes, normalizer=None, resizer=resizer, n_tiles=new_n_tiles)[
-            0]
+        if tta:
+            aug = tta_forward(normalized)
+            preds = []
+            for img in aug:
+                preds.append(self._predict_mean_and_scale(img, axes=new_axes, normalizer=None, resizer=resizer,
+                                                          n_tiles=new_n_tiles)[0])
+            pred = tta_backward(preds)
+        else:
+            pred = \
+            self._predict_mean_and_scale(normalized, axes=new_axes, normalizer=None, resizer=resizer,
+                                         n_tiles=new_n_tiles)[0]
 
         pred = self.__denormalize__(pred, means, stds)
 
@@ -573,5 +584,3 @@ def get_yml_dict(self, name, description, authors, test_img, axes, patch_shape=N
     @property
     def _config_class(self):
         return N2VConfig
-
-

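The new `tta` flag on `predict` is opt-in and defaults to `False`, so existing calls behave exactly as before. A minimal usage sketch, assuming a previously trained 2D model (the model name, `basedir`, and the input array below are illustrative):

```python
import numpy as np
from n2v.models import N2V

# Load a previously trained model by name; with config=None the stored weights are loaded.
model = N2V(config=None, name='n2v_2D', basedir='models')  # illustrative name/basedir

img = np.random.rand(256, 256).astype(np.float32)  # stand-in for a real noisy image

# Default behaviour: a single forward pass, as before this commit.
pred = model.predict(img, axes='YX')

# With tta=True, predict runs the network on all 8 flip/rotation variants
# of the input and averages the de-augmented results.
pred_tta = model.predict(img, axes='YX', tta=True)
```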
n2v/utils/n2v_utils.py

Lines changed: 60 additions & 20 deletions
@@ -5,14 +5,14 @@
 
 def get_subpatch(patch, coord, local_sub_patch_radius):
     start = np.maximum(0, np.array(coord) - local_sub_patch_radius)
-    end = start + local_sub_patch_radius*2 + 1
+    end = start + local_sub_patch_radius * 2 + 1
 
     shift = np.minimum(0, patch.shape - end)
 
     start += shift
     end += shift
 
-    slices = [ slice(s, e) for s, e in zip(start, end)]
+    slices = [slice(s, e) for s, e in zip(start, end)]
 
     return patch[tuple(slices)]
 
@@ -40,17 +40,19 @@ def normal_withoutCP(patch, coords, dims):
             rand_coords = random_neighbor(patch.shape, coord)
             vals.append(patch[tuple(rand_coords)])
         return vals
+
     return normal_withoutCP
 
 
 def pm_uniform_withCP(local_sub_patch_radius):
     def random_neighbor_withCP_uniform(patch, coords, dims):
         vals = []
         for coord in zip(*coords):
-            sub_patch = get_subpatch(patch, coord,local_sub_patch_radius)
+            sub_patch = get_subpatch(patch, coord, local_sub_patch_radius)
             rand_coords = [np.random.randint(0, s) for s in sub_patch.shape[0:dims]]
             vals.append(sub_patch[tuple(rand_coords)])
         return vals
+
     return random_neighbor_withCP_uniform
 
 
@@ -60,6 +62,7 @@ def pixel_gauss(patch, coords, dims):
         for coord in zip(*coords):
             vals.append(np.random.normal(patch[tuple(coord)], pixel_gauss_sigma))
         return vals
+
     return pixel_gauss
 
 
@@ -71,6 +74,7 @@ def local_gaussian(patch, coords, dims):
             axis = tuple(range(dims))
             vals.append(np.random.normal(np.mean(sub_patch, axis=axis), np.std(sub_patch, axis=axis)))
         return vals
+
     return local_gaussian
 
 
@@ -80,17 +84,18 @@ def identity(patch, coords, dims):
         for coord in zip(*coords):
             vals.append(patch[coord])
         return vals
+
     return identity
 
 
 def manipulate_val_data(X_val, Y_val, perc_pix=0.198, shape=(64, 64), value_manipulation=pm_uniform_withCP(5)):
     dims = len(shape)
     if dims == 2:
-        box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
+        box_size = np.round(np.sqrt(100 / perc_pix)).astype(np.int)
         get_stratified_coords = dw.__get_stratified_coords2D__
         rand_float = dw.__rand_float_coords2D__(box_size)
     elif dims == 3:
-        box_size = np.round(np.sqrt(100/perc_pix)).astype(np.int)
+        box_size = np.round(np.sqrt(100 / perc_pix)).astype(np.int)
        get_stratified_coords = dw.__get_stratified_coords3D__
        rand_float = dw.__rand_float_coords3D__(box_size)
 
@@ -99,7 +104,7 @@ def manipulate_val_data(X_val, Y_val, perc_pix=0.198, shape=(64, 64), value_mani
     Y_val *= 0
     for j in tqdm(range(X_val.shape[0]), desc='Preparing validation data: '):
         coords = get_stratified_coords(rand_float, box_size=box_size,
-                                   shape=np.array(X_val.shape)[1:-1])
+                                       shape=np.array(X_val.shape)[1:-1])
         for c in range(n_chan):
             indexing = (j,) + coords + (c,)
             indexing_mask = (j,) + coords + (c + n_chan,)
@@ -112,17 +117,52 @@
 
 
 def autocorrelation(x):
-    """
-    nD autocorrelation
-    remove mean per-patch (not global GT)
-    normalize stddev to 1
-    value at zero shift normalized to 1...
-    """
-    x = (x - np.mean(x))/np.std(x)
-    x = np.fft.fftn(x)
-    x = np.abs(x)**2
-    x = np.fft.ifftn(x).real
-    x = x / x.flat[0]
-    x = np.fft.fftshift(x)
-    return x
-
+    """
+    nD autocorrelation
+    remove mean per-patch (not global GT)
+    normalize stddev to 1
+    value at zero shift normalized to 1...
+    """
+    x = (x - np.mean(x)) / np.std(x)
+    x = np.fft.fftn(x)
+    x = np.abs(x) ** 2
+    x = np.fft.ifftn(x).real
+    x = x / x.flat[0]
+    x = np.fft.fftshift(x)
+    return x
+
+
+def tta_forward(x):
+    """
+    Augments x 8-fold: all 90 deg rotations plus lr flip of the four rotated versions.
+
+    Parameters
+    ----------
+    x: data to augment
+
+    Returns
+    -------
+    Stack of augmented x.
+    """
+    x_aug = [x, np.rot90(x, 1), np.rot90(x, 2), np.rot90(x, 3)]
+    x_aug_flip = x_aug.copy()
+    for x_ in x_aug:
+        x_aug_flip.append(np.fliplr(x_))
+    return x_aug_flip
+
+
+def tta_backward(x_aug):
+    """
+    Inverts `tta_forward` and averages the 8 images.
+
+    Parameters
+    ----------
+    x_aug: stack of 8-fold augmented images.
+
+    Returns
+    -------
+    average of de-augmented x_aug.
+    """
+    x_deaug = [x_aug[0], np.rot90(x_aug[1], -1), np.rot90(x_aug[2], -2), np.rot90(x_aug[3], -3),
+               np.fliplr(x_aug[4]), np.rot90(np.fliplr(x_aug[5]), -1), np.rot90(np.fliplr(x_aug[6]), -2), np.rot90(np.fliplr(x_aug[7]), -3)]
+    return np.mean(x_deaug, 0)

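`tta_forward` enumerates the 8 symmetries of the square (the four 90° rotations and their left-right mirrored versions), and `tta_backward` applies the exact inverse transform to each element before averaging. A small standalone round-trip sketch, using a plain copy in place of the network call that `predict` makes per augmented copy (the identity stand-in is purely illustrative):

```python
import numpy as np
from n2v.utils.n2v_utils import tta_forward, tta_backward

# An asymmetric image with a trailing channel axis, so rotations/flips actually change it.
img = np.arange(12.0).reshape(3, 4)[..., np.newaxis]  # shape (3, 4, 1)

aug = tta_forward(img)            # list of 8 arrays: x, rot90(x, k), and their lr-flips
assert len(aug) == 8

# Stand-in for the per-copy prediction; in N2V.predict this is _predict_mean_and_scale.
preds = [a.copy() for a in aug]

avg = tta_backward(preds)         # undo each transform, then average the 8 results
assert np.allclose(avg, img)      # with an identity "network" the input is recovered
```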
tests/test_n2v_utils.py

Lines changed: 30 additions & 20 deletions
@@ -1,44 +1,46 @@
 import numpy as np
 from n2v.utils import n2v_utils
+from n2v.utils.n2v_utils import tta_forward, tta_backward
+
 
 def test_get_subpatch():
     patch = np.arange(100)
-    patch.shape = (10,10)
+    patch.shape = (10, 10)
 
     subpatch_target = np.array([[11, 12, 13, 14, 15],
                                 [21, 22, 23, 24, 25],
                                 [31, 32, 33, 34, 35],
                                 [41, 42, 43, 44, 45],
                                 [51, 52, 53, 54, 55]])
 
-    subpatch_test = n2v_utils.get_subpatch(patch, (3,3), 2)
+    subpatch_test = n2v_utils.get_subpatch(patch, (3, 3), 2)
 
     assert np.sum(subpatch_target - subpatch_test) == 0
 
-    subpatch_test = n2v_utils.get_subpatch(patch, (3,3), 1)
+    subpatch_test = n2v_utils.get_subpatch(patch, (3, 3), 1)
 
     assert np.sum(subpatch_target[1:-1, 1:-1] - subpatch_test) == 0
 
     patch = np.arange(1000)
-    patch.shape = (10,10,10)
+    patch.shape = (10, 10, 10)
 
-    subpatch_target = np.array([[[31,32,33],
-                                 [41,42,43],
-                                 [51,52,53]],
-                                [[131,132,133],
-                                 [141,142,143],
-                                 [151,152,153]],
-                                [[231,232,233],
-                                 [241,242,243],
-                                 [251,252,253]]])
+    subpatch_target = np.array([[[31, 32, 33],
+                                 [41, 42, 43],
+                                 [51, 52, 53]],
+                                [[131, 132, 133],
+                                 [141, 142, 143],
+                                 [151, 152, 153]],
+                                [[231, 232, 233],
+                                 [241, 242, 243],
+                                 [251, 252, 253]]])
 
-    subpatch_test = n2v_utils.get_subpatch(patch, (1,4,2), 1)
+    subpatch_test = n2v_utils.get_subpatch(patch, (1, 4, 2), 1)
 
     assert np.sum(subpatch_target - subpatch_test) == 0
 
 
 def test_random_neighbor():
-    coord = np.array([51,52,32])
+    coord = np.array([51, 52, 32])
     shape = [128, 128, 128]
 
     for i in range(1000):
@@ -54,9 +56,9 @@ def test_random_neighbor():
 
 def test_pm_normal_neighbor_withoutCP():
     patch = np.arange(100)
-    patch.shape = (10,10)
+    patch.shape = (10, 10)
 
-    coords = (np.array([2, 4]), np.array([1,3]))
+    coords = (np.array([2, 4]), np.array([1, 3]))
 
     sampler = n2v_utils.pm_normal_withoutCP(1)
 
@@ -68,7 +70,7 @@ def test_pm_normal_neighbor_withoutCP():
     patch = np.arange(1000)
     patch.shape = (10, 10, 10, 1)
 
-    coords = (np.array([2, 4, 6]), np.array([1,3,5]), np.array([3,5,1]))
+    coords = (np.array([2, 4, 6]), np.array([1, 3, 5]), np.array([3, 5, 1]))
 
     for i in range(100):
         val = sampler(patch, coords, len(patch.shape))
@@ -119,7 +121,7 @@ def test_pm_normal_additive():
 
     val = sampler(patch, coords, len(patch.shape))
     for v, z, y, x in zip(val, *coords):
-        assert v == patch[z,y,x]
+        assert v == patch[z, y, x]
 
 
 def test_pm_normal_fitted():
@@ -162,4 +164,12 @@ def test_pm_identity():
 
     val = sampler(patch, coords, len(patch.shape))
     for v, z, y, x in zip(val, *coords):
-        assert v == patch[z, y, x]
+        assert v == patch[z, y, x]
+
+
+def test_tta():
+    img, _ = np.meshgrid(range(200), range(100))
+    img[:50, :50] = 50
+    aug = tta_forward(img[..., np.newaxis])
+    avg = tta_backward(aug)
+    assert np.sum(avg[..., 0] - img) == 0

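The new `test_tta` checks exactly this round trip on a non-square 100 × 200 image, which also exercises the shape changes introduced by `np.rot90`. To run just that test, something like the following should work from the repository root (a sketch; assumes `pytest` is installed):

```python
import pytest

# Select the single new test by its node id; '-q' keeps the output short.
exit_code = pytest.main(["-q", "tests/test_n2v_utils.py::test_tta"])
```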