Skip to content

Commit ffbd612

Browse files
author
tibuch
committed
Rerun BSD68 example with and without tta.
1 parent 2e4a638 commit ffbd612

File tree

1 file changed

+84
-16
lines changed

1 file changed

+84
-16
lines changed

examples/2D/denoising2D_BSD68/BSD68_reproducibility.ipynb

Lines changed: 84 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@
184184
"name": "stderr",
185185
"output_type": "stream",
186186
"text": [
187-
"/home/tbuchhol/Gitrepos/n2v/n2v/models/n2v_standard.py:405: UserWarning: output path for model already exists, files may be overwritten: /home/tbuchhol/Gitrepos/n2v/examples/2D/denoising2D_BSD68/models/BSD68_reproducability_5x5\n",
187+
"/home/tbuchhol/Gitrepos/n2v/n2v/models/n2v_standard.py:416: UserWarning: output path for model already exists, files may be overwritten: /home/tbuchhol/Gitrepos/n2v/examples/2D/denoising2D_BSD68/models/BSD68_reproducability_5x5\n",
188188
" 'output path for model already exists, files may be overwritten: %s' % str(self.logdir.resolve()))\n",
189189
"Using TensorFlow backend.\n"
190190
]
@@ -396,13 +396,7 @@
396396
"400/400 [==============================] - 117s 292ms/step - loss: 0.1924 - val_loss: 0.1916\n",
397397
"Epoch 74/200\n",
398398
"400/400 [==============================] - 117s 292ms/step - loss: 0.1937 - val_loss: 0.1719\n",
399-
"Epoch 75/200\n"
400-
]
401-
},
402-
{
403-
"name": "stdout",
404-
"output_type": "stream",
405-
"text": [
399+
"Epoch 75/200\n",
406400
"400/400 [==============================] - 117s 292ms/step - loss: 0.1921 - val_loss: 0.1735\n",
407401
"\n",
408402
"Epoch 00075: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-05.\n",
@@ -797,7 +791,7 @@
797791
"pred = []\n",
798792
"psnrs = []\n",
799793
"for gt, img in zip(groundtruth_data, test_data):\n",
800-
" p_ = model.predict(img.astype(np.float32), 'YX');\n",
794+
" p_ = model.predict(img.astype(np.float32), 'YX', tta=False);\n",
801795
" pred.append(p_)\n",
802796
" psnrs.append(PSNR(gt, p_))\n",
803797
"\n",
@@ -815,17 +809,54 @@
815809
"name": "stdout",
816810
"output_type": "stream",
817811
"text": [
818-
"PSNR: 27.81\n"
812+
"PSNR (without test-time augmentation): 27.81\n"
819813
]
820814
}
821815
],
822816
"source": [
823-
"print(\"PSNR:\", np.round(np.mean(psnrs), 2))"
817+
"print(\"PSNR (without test-time augmentation):\", np.round(np.mean(psnrs), 2))"
824818
]
825819
},
826820
{
827821
"cell_type": "code",
828822
"execution_count": 15,
823+
"metadata": {
824+
"scrolled": true
825+
},
826+
"outputs": [],
827+
"source": [
828+
"pred = []\n",
829+
"psnrs = []\n",
830+
"for gt, img in zip(groundtruth_data, test_data):\n",
831+
" p_ = model.predict(img.astype(np.float32), 'YX', tta=True);\n",
832+
" pred.append(p_)\n",
833+
" psnrs.append(PSNR(gt, p_))\n",
834+
"\n",
835+
"psnrs = np.array(psnrs)"
836+
]
837+
},
838+
{
839+
"cell_type": "code",
840+
"execution_count": 16,
841+
"metadata": {
842+
"scrolled": true
843+
},
844+
"outputs": [
845+
{
846+
"name": "stdout",
847+
"output_type": "stream",
848+
"text": [
849+
"PSNR (with test-time augmentation): 27.88\n"
850+
]
851+
}
852+
],
853+
"source": [
854+
"print(\"PSNR (with test-time augmentation):\", np.round(np.mean(psnrs), 2))"
855+
]
856+
},
857+
{
858+
"cell_type": "code",
859+
"execution_count": 17,
829860
"metadata": {},
830861
"outputs": [],
831862
"source": [
@@ -835,14 +866,14 @@
835866
},
836867
{
837868
"cell_type": "code",
838-
"execution_count": 16,
869+
"execution_count": 18,
839870
"metadata": {},
840871
"outputs": [],
841872
"source": [
842873
"pred = []\n",
843874
"psnrs = []\n",
844875
"for gt, img in zip(groundtruth_data, test_data):\n",
845-
" p_ = model.predict(img.astype(np.float32), 'YX')\n",
876+
" p_ = model.predict(img.astype(np.float32), 'YX', tta=False)\n",
846877
" pred.append(p_)\n",
847878
" psnrs.append(PSNR(gt, p_))\n",
848879
"\n",
@@ -851,7 +882,44 @@
851882
},
852883
{
853884
"cell_type": "code",
854-
"execution_count": 17,
885+
"execution_count": 19,
886+
"metadata": {
887+
"scrolled": true
888+
},
889+
"outputs": [
890+
{
891+
"name": "stdout",
892+
"output_type": "stream",
893+
"text": [
894+
"PSNR (without test-time augmentation): 27.77\n"
895+
]
896+
}
897+
],
898+
"source": [
899+
"print(\"PSNR (without test-time augmentation):\", np.round(np.mean(psnrs), 2))"
900+
]
901+
},
902+
{
903+
"cell_type": "code",
904+
"execution_count": 20,
905+
"metadata": {
906+
"scrolled": true
907+
},
908+
"outputs": [],
909+
"source": [
910+
"pred = []\n",
911+
"psnrs = []\n",
912+
"for gt, img in zip(groundtruth_data, test_data):\n",
913+
" p_ = model.predict(img.astype(np.float32), 'YX', tta=True)\n",
914+
" pred.append(p_)\n",
915+
" psnrs.append(PSNR(gt, p_))\n",
916+
"\n",
917+
"psnrs = np.array(psnrs)"
918+
]
919+
},
920+
{
921+
"cell_type": "code",
922+
"execution_count": 21,
855923
"metadata": {
856924
"scrolled": true
857925
},
@@ -860,12 +928,12 @@
860928
"name": "stdout",
861929
"output_type": "stream",
862930
"text": [
863-
"PSNR: 27.77\n"
931+
"PSNR (with test-time augmentation): 27.84\n"
864932
]
865933
}
866934
],
867935
"source": [
868-
"print(\"PSNR:\", np.round(np.mean(psnrs), 2))"
936+
"print(\"PSNR (with test-time augmentation):\", np.round(np.mean(psnrs), 2))"
869937
]
870938
}
871939
],

0 commit comments

Comments
 (0)