Skip to content

Commit 344c643

Browse files
committed
Resolved a portion of the review comments
1 parent 7589e05 commit 344c643

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

src/flag_gems/ops/batch_norm.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99

1010
from .. import runtime
1111
from ..runtime import torch_device_fn
12-
from ..runtime.moduel_tool import tl_extra_module
13-
from ..utils import libentry
12+
from ..utils import libentry, tl_extra_shim
1413
from ..utils.type_utils import get_accumulator_dtype
1514

16-
rsqrt = tl_extra_module.rsqrt
15+
rsqrt = tl_extra_shim.rsqrt
1716

1817

1918
def make_3d_for_bn(input: Tensor) -> Tensor:
@@ -285,8 +284,12 @@ def batch_norm_backward_kernel(
285284

286285
if affine:
287286
weight = tl.load(feat_pid + weight_pointer)
288-
weight_grad = 0.0
289-
bias_grad = 0.0
287+
weight_grad_acc = tl.zeros(
288+
[BLOCK_SIZE_BATCH, BLOCK_SIZE_SPATIAL], dtype=tl.float32
289+
)
290+
bias_grad_acc = tl.zeros(
291+
[BLOCK_SIZE_BATCH, BLOCK_SIZE_SPATIAL], dtype=tl.float32
292+
)
290293

291294
else:
292295
weight = 1.0
@@ -337,10 +340,12 @@ def batch_norm_backward_kernel(
337340
)
338341

339342
if affine:
340-
weight_grad += tl.sum(curr_pre_lin * curr_output_grad)
341-
bias_grad += tl.sum(curr_output_grad)
343+
weight_grad_acc += curr_pre_lin * curr_output_grad
344+
bias_grad_acc += curr_output_grad
342345

343346
if affine:
347+
weight_grad = tl.sum(weight_grad_acc)
348+
bias_grad = tl.sum(bias_grad_acc)
344349
tl.store(feat_pid + weight_grad_pointer, weight_grad)
345350
tl.store(feat_pid + bias_grad_pointer, bias_grad)
346351

@@ -385,9 +390,8 @@ def forward(
385390
running_var = input if running_var is None else running_var
386391

387392
# Launches 1D grid where each program operates over one feature.
388-
grid = lambda _: (feat_dim,)
389393
with torch_device_fn.device(input.device):
390-
batch_norm_forward_kernel[grid](
394+
batch_norm_forward_kernel[(feat_dim,)](
391395
input_3d,
392396
weight,
393397
bias,
@@ -431,9 +435,8 @@ def backward(ctx, output_grad):
431435
weight_grad = bias_grad = None
432436

433437
# Launches 1D grid where each program operates over one feature.
434-
grid = lambda _: (feat_dim,)
435438
with torch_device_fn.device(input.device):
436-
batch_norm_backward_kernel[grid](
439+
batch_norm_backward_kernel[(feat_dim,)](
437440
output_grad_3d,
438441
input_3d,
439442
mean,

0 commit comments

Comments (0)