Skip to content

Commit

Permalink
Update group_norm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lancerts authored Nov 7, 2024
1 parent 5ff64c5 commit 138cbf8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/liger_kernel/ops/group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def _group_norm_forward_kernel(
RSTD_col_stride, # stride of each column in rstd
W_ptr, # pointer to W
B_ptr, # pointer to B
hidden_size, # hidden size of X
channels_per_group, # the number of channels per group
hidden_size, # hidden size of X
channels_per_group, # the number of channels per group
eps,
BLOCK_SIZE: tl.constexpr,
):
Expand Down Expand Up @@ -280,7 +280,7 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
BLOCK_SIZE=BLOCK_SIZE,
dtype=triton_dtype,
)

# Return tensors in the original shape
return DX.view(*shape), DW, DB

Expand Down

0 comments on commit 138cbf8

Please sign in to comment.