From 138cbf81b55ea83f1f340ad3aa8044bc1473e482 Mon Sep 17 00:00:00 2001 From: Shao Tang Date: Thu, 7 Nov 2024 11:52:35 -0800 Subject: [PATCH] Update group_norm.py --- src/liger_kernel/ops/group_norm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index fab92497..aeb4323f 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -35,8 +35,8 @@ def _group_norm_forward_kernel( RSTD_col_stride, # stride of each column in rstd W_ptr, # pointer to W B_ptr, # pointer to B - hidden_size, # hidden size of X - channels_per_group, # the number of channels per group + hidden_size, # hidden size of X + channels_per_group, # the number of channels per group eps, BLOCK_SIZE: tl.constexpr, ): @@ -280,7 +280,7 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype, ) - + # Return tensors in the original shape return DX.view(*shape), DW, DB