Skip to content

Commit d2f8ef3

Browse files
authored
Merge pull request #112 from Leengit/num_threads_1
BUG: Limit `private_threadpool_size` to 1 to keep Tensorflow from freezing on GPU
2 parents 0ffc381 + 19dafa3 commit d2f8ef3

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

histomics_stream/tensorflow.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,21 @@ def __init__(self):
3131
"deterministic": False,
3232
}
3333

34-
def __call__(self, study_description, num_workers=1, worker_index=0):
34+
def __call__(
35+
self,
36+
study_description,
37+
num_workers=None,
38+
worker_index=None,
39+
private_threadpool_size=None,
40+
):
3541
"""
3642
From scratch, creates a tensorflow dataset with one tensorflow element per tile
3743
"""
44+
num_workers = num_workers if num_workers is not None else 1
45+
worker_index = worker_index if worker_index is not None else 0
46+
private_threadpool_size = (
47+
private_threadpool_size if private_threadpool_size is not None else 1
48+
)
3849

3950
# Call to superclass to find the locations for the chunks
4051
# print(f"Build chunks: begin {datetime.datetime.now()}")
@@ -134,6 +145,17 @@ def __call__(self, study_description, num_workers=1, worker_index=0):
134145
lambda elem: (elem, None, None), **self.dataset_map_options
135146
)
136147
# print(f"Build study_dataset pop: end {datetime.datetime.now()}")
148+
149+
# By default `private_threadpool_size` is set to 0, which means that Tensorflow
150+
# is free to choose the number without limit. However, Tensorflow can grind to
151+
# a halt when processing a large dataset with this default behavior on GPU. A
152+
# value of 1 for `private_threadpool_size` runs more quickly than other values
153+
# on some tests we tried. Changing `private_threadpool_size` here is achieved
154+
# as a transformation of the dataset with an `options` object.
155+
options = tf.data.Options()
156+
options.threading.private_threadpool_size = private_threadpool_size
157+
study_dataset = study_dataset.with_options(options)
158+
137159
return study_dataset
138160

139161
@staticmethod

0 commit comments

Comments
 (0)