Skip to content

Commit 98c06df

Browse files
authored
Merge pull request #146 from nasa/develop
* Support mixed precision * USGS script updates * Minor bugfixes
2 parents dcf7faf + b400ba5 commit 98c06df

File tree

10 files changed

+295
-69
lines changed

10 files changed

+295
-69
lines changed

delta/config/delta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ train:
7979
callbacks: ~
8080
# augmentation functions to apply to training data
8181
augmentations: ~
82+
disable_mixed_precision: false
8283
validation:
8384
# if true, skips the first steps from the training set to use for validation instead
8485
from_training: true

delta/extensions/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def preprocess(self, y_true, y_pred):
199199
nodata = tf.expand_dims(nodata, -1)
200200

201201
# zero all nodata entries
202-
y_pred = y_pred * tf.cast(tf.logical_not(nodata), tf.float32)
202+
y_pred = tf.cast(y_pred, tf.float32) * tf.cast(tf.logical_not(nodata), tf.float32)
203203

204204
true_convert = tf.cast(tf.logical_not(nodata), tf.float32) * true_convert
205205
return (true_convert, y_pred)

delta/extensions/sources/sentinel1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def _prep(self, paths):
178178
assert isinstance(paths, str)
179179
ext = os.path.splitext(paths)[1]
180180

181+
self._name = os.path.splitext(os.path.basename(paths))[0]
182+
181183
tif_path = None
182184
if ext == '.zip': # Need to unpack
183185

delta/extensions/sources/tiff.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,6 @@ def callback_function(output_roi, data, _):
310310
block_x = output_roi.min_x // ts[1]
311311
block_y = output_roi.min_y // ts[0]
312312

313-
print('Write output_roi = ' + str(output_roi))
314-
315313
# Loop on bands
316314
if len(data.shape) == 2:
317315
writer.write_block(data[:, :], block_y, block_x, 0)

delta/imagery/utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@ def progress_bar(text, fill_amount, prefix = '', length = 80): #pylint: disable=
6969
Number of characters to fill as bar
7070
"""
7171
filled_length = int(length * fill_amount)
72-
fill_char = '█' if sys.stdout.encoding.lower() == 'utf-8' else 'X'
72+
fill_char = '█' if str(sys.stdout.encoding).lower() == 'utf-8' else 'X'
7373
prog_bar = fill_char * filled_length + '-' * (length - filled_length)
7474
print('\r%s |%s| %s' % (prefix, prog_bar, text), end = '\r')

delta/ml/ml_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,15 @@ def __init__(self):
180180
self.register_field('steps', int, None, config.validate_non_negative, 'Batches to train per epoch.')
181181
self.register_field('optimizer', (str, dict), None, None, 'Keras optimizer to use.')
182182
self.register_field('callbacks', list, 'callbacks', None, 'Callbacks used to modify training')
183+
self.register_field('disable_mixed_precision', bool, 'disable_mixed_precision', None,
184+
'Disables mixed precision tensorflow policy. By default DELTA will use mixed '
185+
'precision if the hardware supports it. Details on ways to improve mixed '
186+
'precision performance: '
187+
'https://www.tensorflow.org/guide/mixed_precision#summary')
183188
self.register_arg('epochs', '--epochs')
184189
self.register_arg('batch_size', '--batch-size')
185190
self.register_arg('steps', '--steps')
191+
self.register_arg('disable_mixed_precision', '--disable-mixed-precision', action="store_true", type=None)
186192
self.register_field('augmentations', list, None, None, None)
187193
self.register_component(ValidationConfig(), 'validation')
188194
self.register_component(NetworkConfig(), 'network')

delta/subcommands/train.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,54 @@
2222
import sys
2323
import time
2424

25-
#import logging
26-
#logging.getLogger("tensorflow").setLevel(logging.DEBUG)
25+
# import logging
26+
# logging.getLogger("tensorflow").setLevel(logging.DEBUG)
2727

2828
import tensorflow as tf
29+
from tensorflow.keras import mixed_precision
2930

3031
from delta.config import config
3132
from delta.imagery import imagery_dataset
3233
from delta.ml.train import train
3334
from delta.ml.config_parser import config_model
3435
from delta.ml.io import save_model, load_model
3536

36-
#tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
37+
38+
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
39+
40+
def mixed_policy_device_compatible():
41+
42+
# gpu check logic taken from https://github.com/keras-team/keras/blob/70d7d07bd186b929d81f7a8ceafff5d78d8bd701/keras/mixed_precision/device_compatibility_check.py # pylint: disable=line-too-long
43+
gpus = tf.config.list_physical_devices('GPU')
44+
gpu_details_list = [tf.config.experimental.get_device_details(g) for g in gpus]
45+
46+
supported_device_strs = []
47+
unsupported_device_strs = []
48+
for details in gpu_details_list:
49+
name = details.get('device_name', 'Unknown GPU')
50+
cc = details.get('compute_capability')
51+
if cc:
52+
device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
53+
if cc >= (7, 0):
54+
supported_device_strs.append(device_str)
55+
else:
56+
unsupported_device_strs.append(device_str)
57+
else:
58+
unsupported_device_strs.append(
59+
name + ', no compute capability (probably not an Nvidia GPU)')
60+
61+
if unsupported_device_strs or not supported_device_strs:
62+
return False
63+
# else mixed policy is compatible
64+
return True
65+
3766

3867
def main(options):
68+
if mixed_policy_device_compatible() and not config.train.disable_mixed_precision():
69+
mixed_precision.set_global_policy('mixed_float16')
70+
print('Tensorflow Mixed Precision is enabled. This improves training performance on compatible GPUs. '
71+
'However certain precautions should be taken and several additional changes can be made to improve '
72+
'performance further. Details: https://www.tensorflow.org/guide/mixed_precision#summary')
3973

4074
images = config.dataset.images()
4175
if not images:

0 commit comments

Comments
 (0)