@@ -31,10 +31,21 @@ def __init__(self):
             "deterministic": False,
         }

-    def __call__(self, study_description, num_workers=1, worker_index=0):
+    def __call__(
+        self,
+        study_description,
+        num_workers=None,
+        worker_index=None,
+        private_threadpool_size=None,
+    ):
         """
         From scratch, creates a tensorflow dataset with one tensorflow element per tile
         """
+        num_workers = num_workers if num_workers is not None else 1
+        worker_index = worker_index if worker_index is not None else 0
+        private_threadpool_size = (
+            private_threadpool_size if private_threadpool_size is not None else 1
+        )

         # Call to superclass to find the locations for the chunks
         # print(f"Build chunks: begin {datetime.datetime.now()}")
@@ -134,6 +145,17 @@ def __call__(self, study_description, num_workers=1, worker_index=0):
             lambda elem: (elem, None, None), **self.dataset_map_options
         )
         # print(f"Build study_dataset pop: end {datetime.datetime.now()}")
+
+        # By default `private_threadpool_size` is 0, which means that TensorFlow is
+        # free to choose the number of threads without limit. However, TensorFlow can
+        # grind to a halt when processing a large dataset with this default behavior
+        # on GPU. A value of 1 for `private_threadpool_size` ran more quickly than
+        # other values in some tests we tried. Changing `private_threadpool_size` here
+        # is achieved as a transformation of the dataset with an `options` object.
+        options = tf.data.Options()
+        options.threading.private_threadpool_size = private_threadpool_size
+        study_dataset = study_dataset.with_options(options)
+
         return study_dataset

     @staticmethod
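
Outside the diff, here is a minimal standalone sketch of the `tf.data.Options` threading transformation the added lines apply; the `range` dataset and the explicit value of 1 are illustrative assumptions, not taken from this PR:

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)  # placeholder dataset, assumed for illustration

    # The default private_threadpool_size of 0 lets TensorFlow size the pool freely;
    # a small fixed value (1 here, matching the PR's observation) caps the threads
    # this dataset may use.
    options = tf.data.Options()
    options.threading.private_threadpool_size = 1
    dataset = dataset.with_options(options)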