From f6e661878f189f4d4e070379d5c3bc3c1bc8f5f3 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 13 Feb 2023 12:00:58 -0800 Subject: [PATCH 1/5] refactor learners.py to perform gradient clipping on global batch when using ShardedStaticAccumulator --- paxml/learners.py | 177 +++++++++++++++++++++++++++-------------- paxml/learners_test.py | 5 +- 2 files changed, 120 insertions(+), 62 deletions(-) diff --git a/paxml/learners.py b/paxml/learners.py index 868a8fb54..2b5d04450 100644 --- a/paxml/learners.py +++ b/paxml/learners.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Google LLC. +# Copyright 2022 The Pax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -179,9 +179,43 @@ def get_grad_tx( self._hparams.repeat_prefix_sep, ) + def get_individual_grad_norms( + self, + raw_grads, + optimizer_name): + p = self._hparams + # Compute gradient norm. + + if p.grad_norm_individual_vars: + grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads) + var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms) + + def add_grad_norm_summary(key, value): + base_layer.add_global_summary( + f'per_var_grad_norm/{optimizer_name}{key}', + value, + SummaryType.AGGREGATE_SCALAR, + ) + + jax.tree_map(add_grad_norm_summary, var_keys, grad_norms) + + def keep_step( + self, + grad_norm): + p = self._hparams + keep_threshold = p.skip_step_gradient_norm_value + if keep_threshold: + return jnp.logical_and( + jnp.all(jnp.isfinite(grad_norm)), + jnp.all(jnp.less(grad_norm, keep_threshold)), + ) + else: + return jnp.all(jnp.isfinite(grad_norm)) + def scale_gradients( self, raw_grads: NestedMap, + raw_grad_norm: JTensor, optimizer_name: Optional[str] = None, clip_gradient_norm_to_value: Optional[float] = None, clip_gradient_single_norm_to_value: Optional[float] = None, @@ -203,57 +237,20 @@ def scale_gradients( have anomaly detected (e.g. Nan or Inf, or excessively big gradient norm) and should not be skipped. """ + p = self._hparams + if optimizer_name is None: optimizer_name = '' else: optimizer_name = optimizer_name + '/' + if clip_gradient_norm_to_value is None: clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value if clip_gradient_single_norm_to_value is None: clip_gradient_single_norm_to_value = ( p.optimizer.clip_gradient_single_norm_to_value ) - # Compute gradient norm. 
- - if p.grad_norm_individual_vars: - grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads) - var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms) - - def add_grad_norm_summary(key, value): - base_layer.add_global_summary( - f'per_var_grad_norm/{optimizer_name}{key}', - value, - SummaryType.AGGREGATE_SCALAR, - ) - - jax.tree_map(add_grad_norm_summary, var_keys, grad_norms) - - if ( - p.grad_norm_summary - or p.check_valid_step - or clip_gradient_norm_to_value - or clip_gradient_single_norm_to_value - ): - raw_grad_norm = _compute_grad_norm(raw_grads) - if p.grad_norm_summary: - base_layer.add_global_summary( - 'learning/' + optimizer_name + 'raw_grad_norm', - raw_grad_norm, - SummaryType.AGGREGATE_SCALAR, - ) - else: - raw_grad_norm = None - - def keep_step(grad_norm): - keep_threshold = p.skip_step_gradient_norm_value - if keep_threshold: - return jnp.logical_and( - jnp.all(jnp.isfinite(grad_norm)), - jnp.all(jnp.less(grad_norm, keep_threshold)), - ) - else: - return jnp.all(jnp.isfinite(grad_norm)) def clip_grads(grads, grad_norm): if clip_gradient_norm_to_value: @@ -282,17 +279,6 @@ def scale_gradient(grad, norm): grad_scale = jnp.array(1.0) return grads, grad_scale - if p.check_valid_step: - # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan - # or Inf, or excessively big gradient norm). - valid_step = keep_step(raw_grad_norm) - base_layer.add_global_summary( - 'learning/' + optimizer_name + 'is_valid_step', - valid_step.astype(jnp.float32), - SummaryType.AGGREGATE_SCALAR, - ) - else: - valid_step = True grads, grad_scale = clip_grads(raw_grads, raw_grad_norm) base_layer.add_global_summary( 'learning/' + optimizer_name + 'grad_scale', @@ -307,7 +293,55 @@ def scale_gradient(grad, norm): clipped_grad_norm, SummaryType.AGGREGATE_SCALAR, ) - return grads, valid_step + return grads + + + def get_grad_norm_valid_step( + self, + raw_grads, + optimizer_name: Optional[str] = None, + clip_gradient_norm_to_value: Optional[float] = None, + clip_gradient_single_norm_to_value: Optional[float] = None + ) -> Tuple[JTensor, JTensor]: + + p = self._hparams + + if optimizer_name is None: + optimizer_name = '' + else: + optimizer_name = optimizer_name + '/' + self.get_individual_grad_norms(raw_grads, + optimizer_name) + + if ( + p.grad_norm_summary + or p.check_valid_step + or clip_gradient_norm_to_value + or clip_gradient_single_norm_to_value + ): + raw_grad_norm = _compute_grad_norm(raw_grads) + if p.grad_norm_summary: + base_layer.add_global_summary( + 'learning/' + optimizer_name + 'raw_grad_norm', + raw_grad_norm, + SummaryType.AGGREGATE_SCALAR, + ) + else: + raw_grad_norm = None + + if p.check_valid_step: + # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan + # or Inf, or excessively big gradient norm). + valid_step = self.keep_step(raw_grad_norm) + base_layer.add_global_summary( + 'learning/' + optimizer_name + 'is_valid_step', + valid_step.astype(jnp.float32), + SummaryType.AGGREGATE_SCALAR, + ) + else: + valid_step = True + + return raw_grad_norm, valid_step def update_states( self, @@ -328,12 +362,20 @@ def update_states( transformed_grad, new_states pair. 
""" p = self._hparams - - grads, valid_step = self.scale_gradients(grads) + + grad_norm, valid_step = self.get_grad_norm_valid_step(grads) + + using_ga = hasattr(p.optimizer, 'num_sub_batches') + + # When using gradient accumulation, gradient scaling happens within base + # optimizer update + if not using_ga: + grads = self.scale_gradients(grads, grad_norm) + transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update( grads, states, old_vars ) - + if p.enable_skip_step_on_gradient_anomalies: # Set grads to 0 if the step is invalid. transformed_grad = jax.tree_map( @@ -351,6 +393,7 @@ def _update(updated, original): new_states = jax.tree_map( _update, new_states, states, is_leaf=py_utils.is_optax_masked_node ) + # Final applied grad norm. if p.grad_norm_summary: applied_grad_norm = _compute_grad_norm(transformed_grad) @@ -581,9 +624,17 @@ def scale_gradients_by_optimizer( self, raw_grads: NestedMap, var_weight_hparams: NestedWeightHParams ) -> Tuple[NestedMap, JTensor]: optimizer_mask, default_mask = self.get_masks(var_weight_hparams) + + raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask) - all_grads, all_valid_step = self.scale_gradients( - jax.tree_map(lambda x, y: x * y, raw_grads, default_mask), + grad_norm, all_valid_step = self.get_grad_norm_valid_step( + raw_grads, + optimizer_name='main', + ) + + all_grads = self.scale_gradients( + raw_grads, + grad_norm, optimizer_name='main', ) @@ -594,8 +645,13 @@ def scale_gradients_by_optimizer( ): assert optimizer.clip_gradient_norm_to_value is not None assert optimizer.clip_gradient_single_norm_to_value is not None - grads, valid_step = self.scale_gradients( + grad_norm, valid_step = self.get_grad_norm_valid_step( + raw_grads, + optimizer_name=name, + ) + grads = self.scale_gradients( jax.tree_map(lambda x, y: x * y, raw_grads, mask), + grad_norm, optimizer_name=name, clip_gradient_norm_to_value=optimizer.clip_gradient_norm_to_value, clip_gradient_single_norm_to_value=optimizer.clip_gradient_single_norm_to_value, @@ -627,7 +683,8 @@ def update_states( grads, var_weight_hparams ) else: - grads, valid_step = self.scale_gradients(grads) + grad_norm, valid_step = self.get_grad_norm_valid_step(grads) + grads = self.scale_gradients(grads, grad_norm) grad_tx = self.get_grad_tx(var_weight_hparams) transformed_grad, new_states = grad_tx.update(grads, states, old_vars) if self._hparams.enable_skip_step_on_gradient_anomalies: diff --git a/paxml/learners_test.py b/paxml/learners_test.py index e58fe96d7..45b6d50c1 100644 --- a/paxml/learners_test.py +++ b/paxml/learners_test.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Google LLC. +# Copyright 2022 The Pax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -64,7 +64,8 @@ def test_learner_clip_gradients(self, g1a, g1b, g2, global_clip_norm, grad2=jnp.array([g2], dtype=jnp.float32)) with base_layer.JaxContext.new_context(): - transformed_grads, _ = learner_instance.scale_gradients(grads) + grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads) + transformed_grads = learner_instance.scale_gradients(grads, grad_norm) global_norm = np.linalg.norm([g1a, g1b, g2]) local_norm1 = np.linalg.norm([g1a, g1b]) From f57f77af2f8a6f202c81a52abc3c745f02c67c04 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 13 Feb 2023 12:01:30 -0800 Subject: [PATCH 2/5] add AUTHORS file --- AUTHORS | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 AUTHORS diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..f0ebbcd0c --- /dev/null +++ b/AUTHORS @@ -0,0 +1,8 @@ +# This is the list of Pax's significant contributors. +# +# This does not necessarily list everyone who has contributed code, +# especially since many employees of one corporation may be contributing. +# To see the full list of contributors, see the revision history in +# source control. +Google LLC +NVIDIA Corporation From b69771e377828c18335cecb4a6b23d288ad0b8c4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 23 Feb 2023 12:03:49 -0800 Subject: [PATCH 3/5] remove AUTHORS file --- AUTHORS | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 AUTHORS diff --git a/AUTHORS b/AUTHORS deleted file mode 100644 index f0ebbcd0c..000000000 --- a/AUTHORS +++ /dev/null @@ -1,8 +0,0 @@ -# This is the list of Pax's significant contributors. -# -# This does not necessarily list everyone who has contributed code, -# especially since many employees of one corporation may be contributing. -# To see the full list of contributors, see the revision history in -# source control. -Google LLC -NVIDIA Corporation From 57e567cd2a1a90982e8bc82762cf3ba07d78c464 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 6 Mar 2023 13:14:24 -0800 Subject: [PATCH 4/5] fix formatting --- paxml/learners.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/paxml/learners.py b/paxml/learners.py index 51e6bce88..aab3c0e7d 100644 --- a/paxml/learners.py +++ b/paxml/learners.py @@ -206,7 +206,7 @@ def add_grad_norm_summary(key, value): jax.tree_map(add_grad_norm_summary, var_keys, grad_norms) def keep_step( - self, + self, grad_norm): p = self._hparams keep_threshold = p.skip_step_gradient_norm_value @@ -243,14 +243,14 @@ def scale_gradients( have anomaly detected (e.g. Nan or Inf, or excessively big gradient norm) and should not be skipped. 
""" - + p = self._hparams - + if optimizer_name is None: optimizer_name = '' else: optimizer_name = optimizer_name + '/' - + if clip_gradient_norm_to_value is None: clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value if clip_gradient_single_norm_to_value is None: @@ -301,7 +301,7 @@ def scale_gradient(grad, norm): ) return grads # pytype: disable=bad-return-type # jax-ndarray - + def get_grad_norm_valid_step( self, raw_grads, @@ -309,16 +309,16 @@ def get_grad_norm_valid_step( clip_gradient_norm_to_value: Optional[float] = None, clip_gradient_single_norm_to_value: Optional[float] = None ) -> Tuple[JTensor, JTensor]: - + p = self._hparams - + if optimizer_name is None: optimizer_name = '' else: optimizer_name = optimizer_name + '/' self.get_individual_grad_norms(raw_grads, optimizer_name) - + if ( p.grad_norm_summary or p.check_valid_step @@ -334,7 +334,7 @@ def get_grad_norm_valid_step( ) else: raw_grad_norm = None - + if p.check_valid_step: # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan # or Inf, or excessively big gradient norm). @@ -346,7 +346,7 @@ def get_grad_norm_valid_step( ) else: valid_step = True - + return raw_grad_norm, valid_step def update_states( @@ -368,20 +368,20 @@ def update_states( transformed_grad, new_states pair. """ p = self._hparams - + grad_norm, valid_step = self.get_grad_norm_valid_step(grads) - - using_ga = hasattr(p.optimizer, 'num_sub_batches') - + + using_ga = hasattr(p.optimizer, 'num_sub_batches') + # When using gradient accumulation, gradient scaling happens within base - # optimizer update + # optimizer update if not using_ga: grads = self.scale_gradients(grads, grad_norm) - + transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update( grads, states, old_vars ) - + if p.enable_skip_step_on_gradient_anomalies: # Set grads to 0 if the step is invalid. transformed_grad = jax.tree_map( @@ -399,7 +399,7 @@ def _update(updated, original): new_states = jax.tree_map( _update, new_states, states, is_leaf=py_utils.is_optax_masked_node ) - + # Final applied grad norm. 
if p.grad_norm_summary: applied_grad_norm = _compute_grad_norm(transformed_grad) @@ -630,17 +630,17 @@ def scale_gradients_by_optimizer( self, raw_grads: NestedMap, var_weight_hparams: NestedWeightHParams ) -> Tuple[NestedMap, JTensor]: optimizer_mask, default_mask = self.get_masks(var_weight_hparams) - + raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask) grad_norm, all_valid_step = self.get_grad_norm_valid_step( raw_grads, optimizer_name='main', ) - + all_grads = self.scale_gradients( raw_grads, - grad_norm, + grad_norm, optimizer_name='main', ) From 2303ad84111a61a4dd0e8e199264d387c28fea33 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sat, 18 Mar 2023 15:53:27 -0700 Subject: [PATCH 5/5] address PR comments --- paxml/learners.py | 51 ++++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/paxml/learners.py b/paxml/learners.py index aab3c0e7d..e3e153905 100644 --- a/paxml/learners.py +++ b/paxml/learners.py @@ -205,19 +205,6 @@ def add_grad_norm_summary(key, value): jax.tree_map(add_grad_norm_summary, var_keys, grad_norms) - def keep_step( - self, - grad_norm): - p = self._hparams - keep_threshold = p.skip_step_gradient_norm_value - if keep_threshold: - return jnp.logical_and( - jnp.all(jnp.isfinite(grad_norm)), - jnp.all(jnp.less(grad_norm, keep_threshold)), - ) - else: - return jnp.all(jnp.isfinite(grad_norm)) - def scale_gradients( self, raw_grads: NestedMap, @@ -316,8 +303,14 @@ def get_grad_norm_valid_step( optimizer_name = '' else: optimizer_name = optimizer_name + '/' - self.get_individual_grad_norms(raw_grads, - optimizer_name) + self.get_individual_grad_norms(raw_grads, optimizer_name) + + if clip_gradient_norm_to_value is None: + clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value + if clip_gradient_single_norm_to_value is None: + clip_gradient_single_norm_to_value = ( + p.optimizer.clip_gradient_single_norm_to_value + ) if ( p.grad_norm_summary @@ -335,10 +328,20 @@ def get_grad_norm_valid_step( else: raw_grad_norm = None + def keep_step(grad_norm): + keep_threshold = p.skip_step_gradient_norm_value + if keep_threshold: + return jnp.logical_and( + jnp.all(jnp.isfinite(grad_norm)), + jnp.all(jnp.less(grad_norm, keep_threshold)), + ) + else: + return jnp.all(jnp.isfinite(grad_norm)) + if p.check_valid_step: # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan # or Inf, or excessively big gradient norm). 
- valid_step = self.keep_step(raw_grad_norm) + valid_step = keep_step(raw_grad_norm) base_layer.add_global_summary( 'learning/' + optimizer_name + 'is_valid_step', valid_step.astype(jnp.float32), @@ -371,11 +374,11 @@ def update_states( grad_norm, valid_step = self.get_grad_norm_valid_step(grads) - using_ga = hasattr(p.optimizer, 'num_sub_batches') + using_grad_accum = hasattr(p.optimizer, 'num_sub_batches') # When using gradient accumulation, gradient scaling happens within base # optimizer update - if not using_ga: + if not using_grad_accum: grads = self.scale_gradients(grads, grad_norm) transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update( @@ -631,15 +634,15 @@ def scale_gradients_by_optimizer( ) -> Tuple[NestedMap, JTensor]: optimizer_mask, default_mask = self.get_masks(var_weight_hparams) - raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask) + grads_after_default_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask) grad_norm, all_valid_step = self.get_grad_norm_valid_step( - raw_grads, + grads_after_default_mask, optimizer_name='main', ) all_grads = self.scale_gradients( - raw_grads, + grads_after_default_mask, grad_norm, optimizer_name='main', ) @@ -651,12 +654,14 @@ def scale_gradients_by_optimizer( ): assert optimizer.clip_gradient_norm_to_value is not None assert optimizer.clip_gradient_single_norm_to_value is not None + + grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, mask) grad_norm, valid_step = self.get_grad_norm_valid_step( - raw_grads, + grads_after_mask, optimizer_name=name, ) grads = self.scale_gradients( - jax.tree_map(lambda x, y: x * y, raw_grads, mask), + grads_after_mask, grad_norm, optimizer_name=name, clip_gradient_norm_to_value=optimizer.clip_gradient_norm_to_value,
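
Resulting call pattern, for reference. After this series, gradient-norm computation and step-validity checking (get_grad_norm_valid_step) are decoupled from gradient scaling (scale_gradients), which now takes the precomputed norm as an argument. The sketch below mirrors the updated test_learner_clip_gradients hunk above and assumes a constructed Learner instance (learner_instance) and a NestedMap of gradients (grads) built as in that test; it is a minimal illustration of the new API shape, not a full runnable test.

    with base_layer.JaxContext.new_context():
      # Raw global gradient norm, plus whether the step should be applied
      # (no NaN/Inf and, if p.skip_step_gradient_norm_value is set, norm
      # below that threshold).
      grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads)
      # Clip/scale the gradients using the precomputed norm.
      transformed_grads = learner_instance.scale_gradients(grads, grad_norm)

When the optimizer exposes num_sub_batches (e.g. a ShardedStaticAccumulator wrapper), update_states skips the scale_gradients call, so clipping is applied to the accumulated global-batch gradient inside the base optimizer update rather than to each sub-batch, as described in the subject of the first patch.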