Skip to content

Commit

Permalink
Add checkpoint_name
Browse files Browse the repository at this point in the history
Signed-off-by: Reese Wang <[email protected]>
  • Loading branch information
zlsh80826 committed Nov 28, 2023
1 parent 666539f commit 767474c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
5 changes: 5 additions & 0 deletions transformer_engine/jax/flax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
Expand Down Expand Up @@ -923,6 +924,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)

x = checkpoint_name(x, 'ffn1')

activations = []
if is_geglu(self.activations):
z = geglu(x)
Expand Down Expand Up @@ -957,4 +960,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

out = checkpoint_name(out, 'ffn2')

return out, ln_output # Output, layner_norm_output
11 changes: 11 additions & 0 deletions transformer_engine/jax/flax/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from jax.ad_checkpoint import checkpoint_name

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
Expand Down Expand Up @@ -211,6 +212,8 @@ def core_attention(query: Array,
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

attn_weights = checkpoint_name(attn_weights, 'logits')

attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))

Expand Down Expand Up @@ -499,6 +502,7 @@ def _check_head_dim(head_dim):
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='qkv',
dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
if not use_fused_attn:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
Expand Down Expand Up @@ -530,6 +534,7 @@ def _check_head_dim(head_dim):
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
if not use_fused_attn:
key, value = jnp.split(kv_proj, [1], axis=-2)
else:
Expand Down Expand Up @@ -574,6 +579,9 @@ def _check_head_dim(head_dim):
residual = ln_out

if not use_fused_attn:
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
Expand Down Expand Up @@ -705,6 +713,8 @@ def convert_to_softmax_type(attn_mask_type, mask):
dtype=self.dtype,
float32_logits=self.float32_logits)

x = checkpoint_name(x, 'context')

x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

attn_context_sharding_constraint = \
Expand All @@ -723,6 +733,7 @@ def convert_to_softmax_type(attn_mask_type, mask):
bias_axes=(W_NO_SHARD_AXES,),
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
return out, residual


Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/jax/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from enum import Enum
from functools import partial
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -91,6 +92,9 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)


Expand Down

0 comments on commit 767474c

Please sign in to comment.