
Commit 7a0de5b

committed
Modified code with the review feedback.
Signed-off-by: Ming Huang <[email protected]>
1 parent 5c07f23 commit 7a0de5b

File tree

4 files changed: +39 -35 lines changed


transformer_engine/jax/cpp_extensions.py

Lines changed: 16 additions & 12 deletions
@@ -2884,7 +2884,7 @@ def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_
                               transpose_axis_boundary=transpose_axis_boundary)
 
 
-class CastPrimitive(BasePrimitive):
+class CastFP8Primitive(BasePrimitive):
     """
     Cast Primitive
     """
@@ -2941,7 +2941,11 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
                                           jax_dtype_to_te_dtype(x_aval.dtype),
                                           jax_dtype_to_te_dtype(out_dtype))
 
-        out = custom_caller(CastPrimitive.name, args, opaque, False, operand_output_aliases={1: 1})
+        out = custom_caller(CastFP8Primitive.name,
+                            args,
+                            opaque,
+                            False,
+                            operand_output_aliases={1: 1})
 
         return out
 
@@ -2950,23 +2954,23 @@ def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_cast implementation
        """
-        assert CastPrimitive.inner_primitive is not None
+        assert CastFP8Primitive.inner_primitive is not None
        casted_x, updated_amax = \
-            CastPrimitive.inner_primitive.bind(
+            CastFP8Primitive.inner_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype)
        return casted_x, updated_amax
 
    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        _check_valid_batch_dims(batch_dims)
-        assert CastPrimitive.outer_primitive is not None
+        assert CastFP8Primitive.outer_primitive is not None
 
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims
 
        out_bdims = x_bdim, x_bdim, amax_bdim
-        return CastPrimitive.outer_primitive.bind(x, amax, scale, scale_inv,
-                                                  out_dtype=out_dtype), out_bdims
+        return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
+                                                     out_dtype=out_dtype), out_bdims
 
    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
@@ -2987,24 +2991,24 @@ def partition(out_dtype, mesh, arg_infos, result_infos):
 
        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = \
-                CastPrimitive.impl(x, amax, scale, scale_inv, out_dtype=out_dtype)
+                CastFP8Primitive.impl(x, amax, scale, scale_inv, out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
 
            return local_cx, global_updated_amax
 
        return mesh, sharded_impl, out_shardings, arg_shardings
 
 
-register_primitive(CastPrimitive)
+register_primitive(CastFP8Primitive)
 
 
-def cast(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
-         out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]:
+def cast_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
+             out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
-    return CastPrimitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
+    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
 
 
 class TransposePrimitive(BasePrimitive):
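
For readers skimming the rename, the sketch below is a rough pure-JAX approximation of what the renamed cast_fp8 wrapper computes: a scaled cast to FP8 plus an amax update. It is illustrative only, not the Transformer Engine implementation, which lowers to a custom CUDA kernel via CastFP8Primitive; the jnp.float8_e4m3fn default standing in for the real TEDType argument and the exact amax-update rule are assumptions.

# Assumption-level reference for cast_fp8's semantics, not the TE kernel itself.
import jax.numpy as jnp

def cast_fp8_reference(x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn):
    """Return (casted_x, updated_amax); scale_inv is kept only for signature parity."""
    del scale_inv  # the cast itself only needs `scale`
    # Assumed update rule: track the running absolute maximum for later scale updates.
    updated_amax = jnp.maximum(amax, jnp.max(jnp.abs(x)))
    # Scaled cast to the FP8 output dtype.
    casted_x = (x * scale).astype(out_dtype)
    return casted_x, updated_amax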

transformer_engine/jax/dot.py

Lines changed: 5 additions & 5 deletions
@@ -113,20 +113,20 @@ def _fp8_dot_fwd_rule(
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
-    casted_kerenl, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)
+    casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)
 
-    output = fp8_dot_impl(casted_x, casted_kerenl, x_scale_inv, kernel_scale_inv, x.dtype,
+    output = fp8_dot_impl(casted_x, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
                          (lhs_contracting_dims, rhs_contracting_dims))
 
-    ctx = (casted_x, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
+    ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape)
    return output, ctx
 
 
 def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # pylint: disable=unused-argument
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
 
-    casted_x, casted_kerenl, fp8_max, amax, scale, scale_inv, \
+    casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \
        updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx
 
    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
@@ -150,7 +150,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # p
        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
    k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
-    dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
+    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
                         (g_constracting_dim, k_constracting_dim))
 
    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax)

transformer_engine/jax/layernorm.py

Lines changed: 7 additions & 7 deletions
@@ -7,7 +7,7 @@
 import jax
 import jax.numpy as jnp
 
-from .cpp_extensions import cast, cast_transpose, transpose
+from .cpp_extensions import cast_fp8, cast_transpose, transpose
 from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
 from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
 from .dot import fp8_dot_impl
@@ -188,14 +188,14 @@ def _layernorm_fp8_dot_fwd_rule(
    # Kernel in (hidden_in, hidden_out...)
    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
-    casted_kerenl, updated_kernel_amax = \
-        cast(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
+    casted_kernel, updated_kernel_amax = \
+        cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
 
    # (batch..., hidden_in) x (hidden_in, hidden_out...)
-    output = fp8_dot_impl(ln_out, casted_kerenl, x_scale_inv, kernel_scale_inv, x.dtype,
+    output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
                          (x_contracting_dims, k_contracting_dims))
 
-    ctx = (ln_out, casted_kerenl, fp8_max, amax, scale, scale_inv, updated_x_amax,
+    ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
           k_contracting_dims)
 
@@ -210,7 +210,7 @@ def _layernorm_fp8_dot_bwd_rule(
        epsilon,
        ctx,
        grad):
-    ln_out_, casted_kerenl, fp8_max, amax, scale, scale_inv, \
+    ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
        updated_x_amax, updated_kernel_amax, \
        x_shape, kernel_shape, mu, rsigma, x, gamma, \
        x_contracting_dims, k_contracting_dims = ctx
@@ -237,7 +237,7 @@ def _layernorm_fp8_dot_bwd_rule(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
    k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
-    dgrad = fp8_dot_impl(casted_grad, casted_kerenl, grad_scale_inv, kernel_scale_inv, grad.dtype,
+    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
                         (g_for_dgrad_constracting_dim, k_constracting_dim))
 
    if layernorm_type == 'layernorm':
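
To see the renamed wrapper in context, here is a trimmed, hypothetical call-site sketch distilled from the forward rule above. The argument order of cast_fp8 and fp8_dot_impl is taken from this diff; the surrounding function, its parameters, and the simplified scale/amax bookkeeping are stand-ins, and running it requires a Transformer Engine JAX build with FP8 support.

# Hypothetical call-site sketch based on _layernorm_fp8_dot_fwd_rule above.
from transformer_engine.jax.cpp_extensions import cast_fp8   # renamed from `cast`
from transformer_engine.jax.dot import fp8_dot_impl

def fp8_projection(ln_out, kernel, kernel_amax, kernel_scale, kernel_scale_inv,
                   x_scale_inv, fwd_dtype, out_dtype, contracting_dims):
    # Cast the kernel to FP8 without an explicit transpose so XLA can keep the
    # FP8 GEMM pattern intact (per the note in this diff).
    casted_kernel, updated_kernel_amax = cast_fp8(kernel, kernel_amax, kernel_scale,
                                                  kernel_scale_inv, fwd_dtype)
    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = fp8_dot_impl(ln_out, casted_kernel, x_scale_inv, kernel_scale_inv,
                          out_dtype, contracting_dims)
    return output, updated_kernel_amax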

transformer_engine/jax/mlp.py

Lines changed: 11 additions & 11 deletions
@@ -9,7 +9,7 @@
 import jax
 import jax.numpy as jnp
 
-from .cpp_extensions import cast, transpose, cast_transpose
+from .cpp_extensions import cast_fp8, transpose, cast_transpose
 from .cpp_extensions import gated_gelu, gated_gelu_fp8
 from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose
 from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
@@ -172,11 +172,11 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
 
    # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
-    casted_kerenl_1, updated_kernel_1_amax = \
-        cast(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
+    casted_kernel_1, updated_kernel_1_amax = \
+        cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
 
    # (batch..., hidden_in) x (hidden_in, 2, hidden_out)
-    dot_1_output = fp8_dot_impl(ln_out, casted_kerenl_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
+    dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype,
                                (x_contracting_dims, (0,)))
 
    gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
@@ -194,14 +194,14 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
-    casted_kerenl_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
+    casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
 
    # (batch..., hidden_in) x (hidden_out, hidden_in)
-    dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kerenl_2, geglu_out_scale_inv,
+    dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv,
                                kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)))
 
-    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kerenl_1,
-           casted_kerenl_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
+    ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1,
+           casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax,
           updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims)
 
    return dot_2_output, ctx
@@ -216,7 +216,7 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
        ctx,
        grad):
    x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \
-        casted_kerenl_1, casted_kerenl_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
+        casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
        updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
        x_contracting_dims, xt_batch_dims = ctx
 
@@ -241,7 +241,7 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
 
    # (batch..., hidden_out) x (hidden_in, hidden_out)
    kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
-    dgrad_2 = fp8_dot_impl(casted_grad, casted_kerenl_2, grad_scale_inv, kernel_2_scale_inv,
+    dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
                           grad.dtype, (x_contracting_dims, (1,)))
 
    gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
@@ -271,7 +271,7 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
    x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple(
        i + 1 for i in x_contracting_dims)
    kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
-    dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kerenl_1, dgeglu_scale_inv, kernel_1_scale_inv,
+    dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv,
                           grad.dtype, (x_contracting_dims_plus_act_dim, (
                               1,
                               2,
