Skip to content

Commit 2c7bd71

Browse files
authored
Fix for tf_3dunet_barts workspace (#1197)
Signed-off-by: yes <[email protected]>
1 parent fc4c7ce commit 2c7bd71

File tree

5 files changed

+15
-16
lines changed

5 files changed

+15
-16
lines changed

openfl-workspace/tf_3dunet_brats/plan/cols.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
33

44
collaborators:
5-
- one

openfl-workspace/tf_3dunet_brats/plan/plan.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,6 @@ data_loader:
3636
template: src.tf_brats_dataloader.TensorFlowBratsDataLoader
3737
network:
3838
defaults: plan/defaults/network.yaml
39-
settings:
40-
agg_addr: DESKTOP-AOKV1IJ.localdomain
41-
agg_port: auto
42-
cert_folder: cert
43-
client_reconnect_interval: 5
44-
disable_client_auth: false
45-
disable_tls: false
46-
hash_salt: auto
47-
template: openfl.federation.Network
4839
task_runner:
4940
defaults: plan/defaults/task_runner.yaml
5041
settings:
@@ -80,4 +71,3 @@ tasks:
8071
epochs: 1
8172
metrics:
8273
- loss
83-
num_batches: 1
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
keras==2.13.1
12
nibabel
23
numpy
34

45
setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
5-
tensorflow>=2
6+
tensorflow==2.13.0

openfl-workspace/tf_3dunet_brats/src/dataloader.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,24 @@ def create_file_list(self):
5353
5454
Split into training and testing sets.
5555
"""
56-
searchpath = os.path.join(self.data_path, '*/*_seg.nii.gz')
56+
extension = '_seg.nii.gz'
57+
flair_extension = '_flair.nii.gz'
58+
searchpath = os.path.join(self.data_path, "*/*" + extension)
5759
filenames = tf.io.gfile.glob(searchpath)
5860

61+
# check for uncompressed files
62+
if not filenames:
63+
extension = '_seg.nii'
64+
flair_extension = '_flair.nii'
65+
searchpath = os.path.join(self.data_path, "*/*" + extension)
66+
filenames = tf.io.gfile.glob(searchpath)
67+
5968
# Create a dictionary of tuples with image filename and label filename
6069

6170
self.num_files = len(filenames)
6271
self.filenames = {}
6372
for idx, filename in enumerate(filenames):
64-
self.filenames[idx] = [filename.replace('_seg.nii.gz', '_flair.nii.gz'), filename]
73+
self.filenames[idx] = [filename.replace(extension, flair_extension), filename]
6574

6675
def z_normalize_img(self, img):
6776
"""

openfl-workspace/tf_3dunet_brats/src/tf_3dunet_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def create_model(self,
8080
initial_filters=initial_filters,
8181
batch_norm=batch_norm)
8282

83-
self.optimizer = tf.keras.optimizers.Adam()
83+
self.optimizer = tf.keras.optimizers.legacy.Adam()
8484

8585
model.compile(
8686
loss=dice_loss,
@@ -193,7 +193,7 @@ def create_model(self,
193193
)
194194

195195
model.compile(loss=dice_loss,
196-
optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
196+
optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.01),
197197
metrics=[dice_coef, soft_dice_coef]
198198
)
199199

0 commit comments

Comments
 (0)