14
14
from .cpp_extensions import dgated_gelu , dgated_gelu_cast_transpose
15
15
from .cpp_extensions import rmsnorm_fwd_fp8 , rmsnorm_bwd
16
16
from .cpp_extensions import layernorm_fwd_fp8 , layernorm_bwd
17
- from .dot import fp8_dot_impl , quantize
17
+ from .dot import fp8_dot_impl , get_precision_of_fp8_dot , quantize
18
18
from .layernorm import canonicalize_layernorm_type
19
19
from .fp8 import FP8Helper , FP8MetaPackage
20
20
@@ -177,7 +177,8 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
177
177
178
178
# (batch..., hidden_in) x (hidden_in, 2, hidden_out)
179
179
dot_1_output = fp8_dot_impl (ln_out , casted_kernel_1 , x_scale_inv , kernel_1_scale_inv , x .dtype ,
180
- (x_contracting_dims , (0 ,)))
180
+ (x_contracting_dims , (0 ,)),
181
+ get_precision_of_fp8_dot (FP8Helper .FP8_2X_ACC_FPROP ))
181
182
182
183
gemm2_x_idx , gemm2_kernel_idx , _ = FP8Helper .get_fp8_meta_indices (1 )
183
184
@@ -198,7 +199,8 @@ def _layernrom_geglu_fp8_mlp_fwd_rule(
198
199
199
200
# (batch..., hidden_in) x (hidden_out, hidden_in)
200
201
dot_2_output = fp8_dot_impl (casted_geglu_out , casted_kernel_2 , geglu_out_scale_inv ,
201
- kernel_2_scale_inv , x .dtype , (x_contracting_dims , (0 ,)))
202
+ kernel_2_scale_inv , x .dtype , (x_contracting_dims , (0 ,)),
203
+ get_precision_of_fp8_dot (FP8Helper .FP8_2X_ACC_FPROP ))
202
204
203
205
ctx = (x , ln_out , mu , rsigma , gamma , dot_1_output , casted_geglu_out , casted_kernel_1 ,
204
206
casted_kernel_2 , fp8_max , amax , scale , scale_inv , updated_x_amax , updated_geglu_amax ,
@@ -237,12 +239,14 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
237
239
# (hidden, batch...,) x (hidden, batch...)
238
240
gemm2_x_scale_inv = scale_inv [gemm2_x_idx ]
239
241
wgrad_2 = fp8_dot_impl (casted_geglu_out_t , casted_grad_t , gemm2_x_scale_inv , grad_scale_inv ,
240
- grad .dtype , (xt_batch_dims , xt_batch_dims ))
242
+ grad .dtype , (xt_batch_dims , xt_batch_dims ),
243
+ get_precision_of_fp8_dot (FP8Helper .FP8_2X_ACC_WGRAD ))
241
244
242
245
# (batch..., hidden_out) x (hidden_in, hidden_out)
243
246
kernel_2_scale_inv = scale_inv [gemm2_kernel_idx ]
244
247
dgrad_2 = fp8_dot_impl (casted_grad , casted_kernel_2 , grad_scale_inv , kernel_2_scale_inv ,
245
- grad .dtype , (x_contracting_dims , (1 ,)))
248
+ grad .dtype , (x_contracting_dims , (1 ,)),
249
+ get_precision_of_fp8_dot (FP8Helper .FP8_2X_ACC_DGRAD ))
246
250
247
251
gemm1_x_idx , gemm1_kernel_idx , gemm1_grad_idx = FP8Helper .get_fp8_meta_indices (0 )
248
252
@@ -265,17 +269,16 @@ def _layernrom_geglu_fp8_mlp_bwd_rule(
265
269
xt_batch_dims_plus_act_dim = tuple (i + 1 for i in xt_batch_dims )
266
270
gemm1_x_scale_inv = scale_inv [gemm1_x_idx ]
267
271
wgrad_1 = fp8_dot_impl (ln_out_t , casted_dgeglu_t , gemm1_x_scale_inv , dgeglu_scale_inv ,
268
- grad .dtype , (xt_batch_dims , xt_batch_dims_plus_act_dim ))
272
+ grad .dtype , (xt_batch_dims , xt_batch_dims_plus_act_dim ),
273
+ get_precision_of_fp8_dot (FP8Helper .FP8_2X_ACC_WGRAD ))
269
274
270
275
# (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out)
271
276
x_contracting_dims_plus_act_dim = (min (x_contracting_dims ),) + tuple (
272
277
i + 1 for i in x_contracting_dims )
273
278
kernel_1_scale_inv = scale_inv [gemm1_kernel_idx ]
274
279
dgrad_1 = fp8_dot_impl (casted_dgeglu , casted_kernel_1 , dgeglu_scale_inv , kernel_1_scale_inv ,
275
- grad .dtype , (x_contracting_dims_plus_act_dim , (
276
- 1 ,
277
- 2 ,
278
- )))
280
+ grad .dtype , (x_contracting_dims_plus_act_dim , (1 , 2 )),
281
+ get_precision_of_fp8_dot (FP8Helper .FP8_2X_ACC_DGRAD ))
279
282
280
283
if layernorm_type == 'layernorm' :
281
284
dx , dgamma , dbeta = layernorm_bwd (dgrad_1 ,
0 commit comments