
Commit ca3f0a5

Adding 2xACC control to FP8 GEMMs.

Signed-off-by: Ming Huang <[email protected]>
1 parent 7a0de5b commit ca3f0a5

File tree: 3 files changed, +38 -19 lines changed

transformer_engine/jax/dot.py

Lines changed: 18 additions & 5 deletions
@@ -11,6 +11,8 @@
 from .cpp_extensions import cast_transpose
 from .fp8 import FP8Helper, FP8MetaPackage
 
+Precision = jax.lax.Precision
+
 
 def type_safe_dot_general(
         x,
@@ -62,7 +64,8 @@ def fp8_dot_impl(
         lhs_scale_inv: jnp.ndarray,
         rhs_scale_inv: jnp.ndarray,
         ctype: jnp.dtype,    # computing type
-        contracting_dims: Tuple[Sequence[int], Sequence[int]]):
+        contracting_dims: Tuple[Sequence[int], Sequence[int]],
+        precision: Precision = None):
     """
     FP8 GEMM for XLA pattern match
     """
@@ -71,7 +74,14 @@ def fp8_dot_impl(
     lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
     rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
 
-    return jax.lax.dot_general(lhs, rhs, dim_nums)
+    return jax.lax.dot_general(lhs, rhs, dim_nums, precision=precision)
+
+
+def get_precision_of_fp8_dot(enable_2xACC: bool):
+    """
+    Get Precision of FP8 DOT.
+    """
+    return jax.lax.Precision.HIGHEST if enable_2xACC else jax.lax.Precision.DEFAULT
 
 
 @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
@@ -116,7 +126,8 @@ def _fp8_dot_fwd_rule(
     casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)
 
     output = fp8_dot_impl(casted_x, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
-                          (lhs_contracting_dims, rhs_contracting_dims))
+                          (lhs_contracting_dims, rhs_contracting_dims),
+                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
 
     ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
            updated_kernel_amax, x.shape, kernel.shape)
@@ -144,14 +155,16 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # p
     gt_constracting_dim = tuple(range(grad.ndim - len(x_constracting_dim), grad.ndim))
     x_scale_inv = scale_inv[gemm_x_idx]
     wgrad = fp8_dot_impl(casted_x, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
-                         (x_constracting_dim, gt_constracting_dim))
+                         (x_constracting_dim, gt_constracting_dim),
+                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
 
     g_constracting_dim = tuple(
         range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
     k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
     kernel_scale_inv = scale_inv[gemm_kernel_idx]
     dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
-                         (g_constracting_dim, k_constracting_dim))
+                         (g_constracting_dim, k_constracting_dim),
+                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
 
     amax = amax.at[gemm_x_idx, 0].set(updated_x_amax)
     amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax)
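For context, the precision argument threaded through above is the standard jax.lax.Precision knob: HIGHEST requests the most accurate accumulation the backend supports (the "2xACC" path this commit targets for FP8 GEMMs), while DEFAULT keeps the fastest mode. A minimal, repository-independent sketch of the knob on a bare dot_general; any numeric difference between the two calls depends on the accelerator (on CPU they typically match):

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key1, (256, 256), dtype=jnp.bfloat16)
b = jax.random.normal(key2, (256, 256), dtype=jnp.bfloat16)

# Contract a's dim 1 with b's dim 0; no batch dims.
dims = (((1,), (0,)), ((), ()))

fast = jax.lax.dot_general(a, b, dims, precision=jax.lax.Precision.DEFAULT)
accurate = jax.lax.dot_general(a, b, dims, precision=jax.lax.Precision.HIGHEST)

# HIGHEST trades speed for accuracy on hardware whose matrix units
# accumulate at reduced precision.
print(jnp.max(jnp.abs(fast.astype(jnp.float32) - accurate.astype(jnp.float32))))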

transformer_engine/jax/layernorm.py

Lines changed: 7 additions & 4 deletions
@@ -10,7 +10,7 @@
 from .cpp_extensions import cast_fp8, cast_transpose, transpose
 from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
 from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
-from .dot import fp8_dot_impl
+from .dot import fp8_dot_impl, get_precision_of_fp8_dot
 from .fp8 import FP8Helper, FP8MetaPackage
 
 
@@ -193,7 +193,8 @@ def _layernorm_fp8_dot_fwd_rule(
 
     # (batch..., hidden_in) x (hidden_in, hidden_out...)
     output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
-                          (x_contracting_dims, k_contracting_dims))
+                          (x_contracting_dims, k_contracting_dims),
+                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
 
     ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
            updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
@@ -231,14 +232,16 @@ def _layernorm_fp8_dot_bwd_rule(
     gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
     x_scale_inv = scale_inv[gemm_x_idx]
     wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
-                         (xt_constracting_dim, gt_constracting_dim))
+                         (xt_constracting_dim, gt_constracting_dim),
+                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
 
     g_for_dgrad_constracting_dim = tuple(
         range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
     k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
     kernel_scale_inv = scale_inv[gemm_kernel_idx]
     dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
-                         (g_for_dgrad_constracting_dim, k_constracting_dim))
+                         (g_for_dgrad_constracting_dim, k_constracting_dim),
+                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
 
     if layernorm_type == 'layernorm':
         dx, dgamma, dbeta = layernorm_bwd(dgrad,
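The dispatch pattern is identical in every file: the forward GEMM consults FP8Helper.FP8_2X_ACC_FPROP, and the two backward GEMMs consult the WGRAD and DGRAD flags. A hedged sketch of that mapping, using a stand-in namespace for FP8Helper since its definition (and the flags' real defaults) are not part of this diff; the flag values below are illustrative only:

import jax
from types import SimpleNamespace

# Stand-in for FP8Helper (defined in transformer_engine/jax/fp8.py);
# these flag values are illustrative, not the library's defaults.
FP8Helper = SimpleNamespace(FP8_2X_ACC_FPROP=False,
                            FP8_2X_ACC_WGRAD=True,
                            FP8_2X_ACC_DGRAD=True)

def get_precision_of_fp8_dot(enable_2xACC: bool):
    # Mirrors the helper added to dot.py in this commit.
    return jax.lax.Precision.HIGHEST if enable_2xACC else jax.lax.Precision.DEFAULT

# Each of the three GEMMs in a fwd/bwd pair picks its own precision:
fprop_prec = get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)  # DEFAULT
wgrad_prec = get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)  # HIGHEST
dgrad_prec = get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)  # HIGHEST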

transformer_engine/jax/mlp.py

Lines changed: 13 additions & 10 deletions
@@ -14,7 +14,7 @@
 from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
 from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
 from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
-from .dot import fp8_dot_impl, quantize
+from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
 from .layernorm import canonicalize_layernorm_type
 from .fp8 import FP8Helper, FP8MetaPackage
 
@@ -177,7 +177,8 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
 
     # (batch..., hidden_in) x (hidden_in, 2, hidden_out)
     dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
-                                (x_contracting_dims, (0,)))
+                                (x_contracting_dims, (0,)),
+                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
 
     gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
 
@@ -198,7 +199,8 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
 
     # (batch..., hidden_in) x (hidden_out, hidden_in)
     dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
-                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)))
+                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)),
+                                get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
 
     ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
            casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
@@ -237,12 +239,14 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
     # (hidden, batch...,) x (hidden, batch...)
     gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
     wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv,
-                           grad.dtype, (xt_batch_dims, xt_batch_dims))
+                           grad.dtype, (xt_batch_dims, xt_batch_dims),
+                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
 
     # (batch..., hidden_out) x (hidden_in, hidden_out)
     kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
     dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
-                           grad.dtype, (x_contracting_dims, (1,)))
+                           grad.dtype, (x_contracting_dims, (1,)),
+                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
 
     gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
 
@@ -265,17 +269,16 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
     xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims)
     gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
     wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv,
-                           grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim))
+                           grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim),
+                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
 
     # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
     x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
         i + 1 for i in x_contracting_dims)
     kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
     dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
-                           grad.dtype, (x_contracting_dims_plus_act_dim, (
-                               1,
-                               2,
-                           )))
+                           grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)),
+                           get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
 
     if layernorm_type == 'layernorm':
         dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
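Because precision defaults to None, existing callers of fp8_dot_impl are unaffected. A hedged, self-contained sketch of a direct call under the new signature, with a simplified stand-in for the repo's dequantize helper and hand-picked scale-inverses in place of the real quantization plumbing (assumes a JAX build with float8 dtypes):

import jax
import jax.numpy as jnp

def dequantize(q, ctype, scale_inv):
    # Simplified stand-in for the repo's dequantize helper.
    return q.astype(ctype) * scale_inv

def fp8_dot_impl(q_lhs, q_rhs, lhs_scale_inv, rhs_scale_inv, ctype,
                 contracting_dims, precision=None):
    # Same shape as the patched function: dequantize both operands,
    # then dot_general with the caller-chosen accumulation precision.
    dim_nums = (contracting_dims, ((), ()))
    lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
    rhs = dequantize(q_rhs, ctype, rhs_scale_inv)
    return jax.lax.dot_general(lhs, rhs, dim_nums, precision=precision)

q_x = jnp.ones((8, 16), dtype=jnp.float8_e4m3fn)
q_w = jnp.ones((16, 4), dtype=jnp.float8_e4m3fn)
out = fp8_dot_impl(q_x, q_w, 0.5, 0.25, jnp.bfloat16, ((1,), (0,)),
                   precision=jax.lax.Precision.HIGHEST)
print(out.shape)  # (8, 4)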
