@@ -9,11 +9,10 @@
 
 from .. import runtime
 from ..runtime import torch_device_fn
-from ..runtime.moduel_tool import tl_extra_module
-from ..utils import libentry
+from ..utils import libentry, tl_extra_shim
 from ..utils.type_utils import get_accumulator_dtype
 
-rsqrt = tl_extra_module.rsqrt
+rsqrt = tl_extra_shim.rsqrt
 
 
 def make_3d_for_bn(input: Tensor) -> Tensor:
@@ -285,8 +284,12 @@ def batch_norm_backward_kernel(
 
     if affine:
         weight = tl.load(feat_pid + weight_pointer)
-        weight_grad = 0.0
-        bias_grad = 0.0
+        weight_grad_acc = tl.zeros(
+            [BLOCK_SIZE_BATCH, BLOCK_SIZE_SPATIAL], dtype=tl.float32
+        )
+        bias_grad_acc = tl.zeros(
+            [BLOCK_SIZE_BATCH, BLOCK_SIZE_SPATIAL], dtype=tl.float32
+        )
 
     else:
         weight = 1.0
@@ -337,10 +340,12 @@ def batch_norm_backward_kernel(
         )
 
         if affine:
-            weight_grad += tl.sum(curr_pre_lin * curr_output_grad)
-            bias_grad += tl.sum(curr_output_grad)
+            weight_grad_acc += curr_pre_lin * curr_output_grad
+            bias_grad_acc += curr_output_grad
 
     if affine:
+        weight_grad = tl.sum(weight_grad_acc)
+        bias_grad = tl.sum(bias_grad_acc)
        tl.store(feat_pid + weight_grad_pointer, weight_grad)
        tl.store(feat_pid + bias_grad_pointer, bias_grad)
 
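The two hunks above change the affine-gradient reduction strategy: instead of calling `tl.sum` twice per loop iteration, the kernel keeps block-shaped float32 accumulators (`weight_grad_acc`, `bias_grad_acc`) and performs a single full reduction after the loop. The quantities are the standard batch-norm affine gradients: `weight_grad` accumulates `pre_lin * output_grad` (dL/dweight = sum of normalized input times upstream grad) and `bias_grad` accumulates `output_grad` (dL/dbias = sum of upstream grad). Below is a minimal, self-contained sketch of the same accumulate-then-reduce pattern; the kernel name, tensor layout, and sizes are illustrative, not part of this patch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def per_feature_sum_kernel(
    x_ptr, out_ptr, batch_dim, spatial_dim,
    BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_SPATIAL: tl.constexpr,
):
    # One program per feature; input is laid out as [batch, feature, spatial].
    feat_pid = tl.program_id(axis=0)
    num_feats = tl.num_programs(axis=0)
    # Block-shaped float32 accumulator, kept in registers across iterations.
    acc = tl.zeros([BLOCK_SIZE_BATCH, BLOCK_SIZE_SPATIAL], dtype=tl.float32)
    batch_offs = tl.arange(0, BLOCK_SIZE_BATCH)
    for spatial_start in range(0, spatial_dim, BLOCK_SIZE_SPATIAL):
        spatial_offs = spatial_start + tl.arange(0, BLOCK_SIZE_SPATIAL)
        mask = (batch_offs[:, None] < batch_dim) & (spatial_offs[None, :] < spatial_dim)
        offs = (
            batch_offs[:, None] * num_feats * spatial_dim
            + feat_pid * spatial_dim
            + spatial_offs[None, :]
        )
        # Element-wise add only; no reduction inside the loop.
        acc += tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Single full reduction once all tiles have been accumulated.
    tl.store(out_ptr + feat_pid, tl.sum(acc))


# Usage (assumes a CUDA device is available).
x = torch.randn(32, 8, 512, device="cuda")  # [batch, feature, spatial]
out = torch.empty(8, device="cuda")
per_feature_sum_kernel[(8,)](x, out, 32, 512, BLOCK_SIZE_BATCH=32, BLOCK_SIZE_SPATIAL=128)
torch.testing.assert_close(out, x.sum(dim=(0, 2)), rtol=1e-4, atol=1e-4)
```

Pinning the accumulator to `tl.float32` regardless of the input dtype also guards against precision loss when the tensors are half precision, at no extra cost since the accumulator lives in registers either way.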
@@ -385,9 +390,8 @@ def forward(
         running_var = input if running_var is None else running_var
 
         # Launches 1D grid where each program operates over one feature.
-        grid = lambda _: (feat_dim,)
         with torch_device_fn.device(input.device):
-            batch_norm_forward_kernel[grid](
+            batch_norm_forward_kernel[(feat_dim,)](
                 input_3d,
                 weight,
                 bias,
@@ -431,9 +435,8 @@ def backward(ctx, output_grad):
         weight_grad = bias_grad = None
 
         # Launches 1D grid where each program operates over one feature.
-        grid = lambda _: (feat_dim,)
         with torch_device_fn.device(input.device):
-            batch_norm_backward_kernel[grid](
+            batch_norm_backward_kernel[(feat_dim,)](
                 output_grad_3d,
                 input_3d,
                 mean,