From 767474c36f1c57c9c059c70af9ec914e0330aaad Mon Sep 17 00:00:00 2001
From: Reese Wang
Date: Thu, 23 Nov 2023 22:29:13 -0800
Subject: [PATCH] Add checkpoint_name

Signed-off-by: Reese Wang
---
 transformer_engine/jax/flax/module.py      |  5 +++++
 transformer_engine/jax/flax/transformer.py | 11 +++++++++++
 transformer_engine/jax/fused_attn.py       |  4 ++++
 3 files changed, 20 insertions(+)

diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py
index 7d80be5878..db1b7bedf2 100644
--- a/transformer_engine/jax/flax/module.py
+++ b/transformer_engine/jax/flax/module.py
@@ -16,6 +16,7 @@
 from jax import lax
 from jax import nn as jax_nn
 from jax import random as jax_random
+from jax.ad_checkpoint import checkpoint_name
 
 from ..dot import type_safe_dot_general
 from ..fp8 import FP8Helper, FP8MetaPackage
@@ -923,6 +924,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
             bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
             x += jnp.reshape(bias, bias_shape)
 
+        x = checkpoint_name(x, 'ffn1')
+
         activations = []
         if is_geglu(self.activations):
             z = geglu(x)
@@ -957,4 +960,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
             bias = bias.astype(self.dtype)
             out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
 
+        out = checkpoint_name(out, 'ffn2')
+
         return out, ln_output    # Output, layner_norm_output
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 989b060696..df2bfafe21 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -20,6 +20,7 @@
 from jax import nn as jax_nn
 from jax import random as jax_random
 from jax import lax, vmap
+from jax.ad_checkpoint import checkpoint_name
 
 from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
@@ -211,6 +212,8 @@ def core_attention(query: Array,
     else:
         attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
 
+    attn_weights = checkpoint_name(attn_weights, 'logits')
+
     attn_weights = _with_sharding_constraint(attn_weights,
                                              (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
@@ -499,6 +502,7 @@ def _check_head_dim(head_dim):
                 bias_axes=(W_JOINED_AXES, W_TP_AXES),
                 name='qkv',
                 dtype=self.dtype)(inputs_q)
+            qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
             if not use_fused_attn:
                 query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
             else:
@@ -530,6 +534,7 @@ def _check_head_dim(head_dim):
                 bias_axes=(W_JOINED_AXES, W_TP_AXES),
                 name='kv',
                 dtype=self.dtype)(inputs_kv)
+            kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
             if not use_fused_attn:
                 key, value = jnp.split(kv_proj, [1], axis=-2)
             else:
@@ -574,6 +579,9 @@ def _check_head_dim(head_dim):
             residual = ln_out
 
         if not use_fused_attn:
+            query = checkpoint_name(query, 'query_proj')
+            key = checkpoint_name(key, 'key_proj')
+            value = checkpoint_name(value, 'value_proj')
             query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
             key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
             value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
@@ -705,6 +713,8 @@ def convert_to_softmax_type(attn_mask_type, mask):
                 dtype=self.dtype,
                 float32_logits=self.float32_logits)
 
+        x = checkpoint_name(x, 'context')
+
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
 
         attn_context_sharding_constraint = \
@@ -723,6 +733,7 @@ def convert_to_softmax_type(attn_mask_type, mask):
             bias_axes=(W_NO_SHARD_AXES,),
             dtype=self.dtype,
             name='out')(x)
+        out = checkpoint_name(out, 'out_proj')
         return out, residual
 
diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py
index 43e8c3dfc1..f2c74e77e4 100644
--- a/transformer_engine/jax/fused_attn.py
+++ b/transformer_engine/jax/fused_attn.py
@@ -5,6 +5,7 @@
 from enum import Enum
 from functools import partial
 
+from jax.ad_checkpoint import checkpoint_name
 import jax
 import jax.numpy as jnp
 
@@ -91,6 +92,9 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training)
+    output = checkpoint_name(output, 'context')
+    softmax_aux = checkpoint_name(softmax_aux, 'context')
+    rng_state = checkpoint_name(rng_state, 'context')
     return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)
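
Usage note (not part of the patch): the tags added above ('ffn1', 'ffn2', 'logits',
'combined_qkv_proj', 'combined_kv_proj', 'query_proj', 'key_proj', 'value_proj',
'context', 'out_proj') are intended to be matched by a JAX rematerialization policy
such as jax.checkpoint_policies.save_only_these_names. Below is a minimal,
self-contained sketch of that mechanism; the toy layer, names, and weights are
illustrative only and are not Transformer Engine code.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint, checkpoint_name

def layer(x, w1, w2):
    # Toy example only: tag intermediates by name, mirroring the pattern in this patch.
    h = checkpoint_name(jnp.dot(x, w1), 'ffn1')
    h = jax.nn.gelu(h)
    return checkpoint_name(jnp.dot(h, w2), 'ffn2')

# Save only the activation tagged 'ffn2'; 'ffn1' is recomputed during the backward pass.
policy = jax.checkpoint_policies.save_only_these_names('ffn2')
remat_layer = checkpoint(layer, policy=policy)

x = jnp.ones((4, 16))
w1 = jnp.ones((16, 64))
w2 = jnp.ones((64, 16))
grads = jax.grad(lambda x_: remat_layer(x_, w1, w2).sum())(x)

The same pattern applies to the names introduced here: a policy listing, e.g.,
'context' and 'ffn2' would keep those activations resident while rematerializing
the rest under jax.ad_checkpoint.checkpoint (a.k.a. jax.remat).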