[Operator] Add batch_norm #362
src/flag_gems/ops/batch_norm.py
@triton.heuristics(
    {
        "BLOCK_SIZE_BATCH": lambda args: next_power_of_2(args["batch_dim"]),
        "BLOCK_SIZE_SPATIAL": BLOCK_SIZE_SPATIAL_heuristic,
We are working on changing the one-tile algorithm into loop tiling. Please refer to max/min/log_softmax and update the strategy. ;)
Assume the input to batch_norm has the shape (batch, channel, spatial).
The implementation of batch_norm is already based on loop tiling: it tiles along the spatial dimension while loading the batch dimension in full. This differs from the operators mentioned, such as max and min, which only reduce along one specified dimension and do not involve the others. For batch_norm, the reduction for each channel has to cover both the batch and spatial dimensions.
My earlier reasoning was that tiling along the batch dimension as well would introduce a nested loop, which may not be worthwhile when the batch is not large.
You might be suggesting that loop tiling on the batch dimension could improve performance when the batch size is large, since it would allow more contiguous memory access along the spatial dimension. For instance, if the batch size is 16,384 and only 1 spatial element is loaded per loop iteration, the memory access pattern becomes inefficient. I will try implementing loop tiling on both the batch and spatial dimensions.
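For readers following the thread, the nested tiling discussed here can be sketched in plain Python. This is only an illustration of the loop structure, not the kernel from this PR: the function name `channel_mean_var_tiled` and its parameters are hypothetical, the tensor is a nested list indexed as `x[batch][channel][spatial]`, and the tile bounds are clamped where a Triton kernel would use a mask.

```python
def channel_mean_var_tiled(x, channel, block_batch, block_spatial):
    """Biased mean/variance of one channel of x[batch][channel][spatial].

    Hypothetical helper: walks the (batch, spatial) plane in fixed-size tiles.
    """
    batch_dim = len(x)
    spatial_dim = len(x[0][channel])
    total = batch_dim * spatial_dim

    def tiles():
        # Outer loop tiles the batch axis, inner loop tiles the spatial axis.
        for b0 in range(0, batch_dim, block_batch):
            for s0 in range(0, spatial_dim, block_spatial):
                # Clamp the tile to the tensor bounds (a kernel would mask instead).
                b1 = min(b0 + block_batch, batch_dim)
                s1 = min(s0 + block_spatial, spatial_dim)
                yield [x[b][channel][s]
                       for b in range(b0, b1)
                       for s in range(s0, s1)]

    # First pass over the tiles: mean. Second pass: centered squares.
    mean = sum(sum(t) for t in tiles()) / total
    var = sum(sum((v - mean) ** 2 for v in t) for t in tiles()) / total
    return mean, var


x = [
    [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]],
    [[4.0, 5.0, 6.0], [1.0, 1.0, 1.0]],
]
mean, var = channel_mean_var_tiled(x, channel=0, block_batch=1, block_spatial=2)
```

With both block sizes bounded, no single iteration has to hold the whole batch dimension, which is the property the reviewer seems to be asking for.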
got it
src/flag_gems/ops/batch_norm.py
(curr_input - mean) * (curr_input - prev_mean),
0.0,
)
var += tl.sum(deltas)
This method cannot fully utilize vectorization/tensorization and leads to sequential computation.
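One possible reading of this comment is that the running update reduces every tile to a scalar before the next tile can be processed. A common alternative, sketched below in plain Python, keeps one partial sum and one partial sum of squares per lane and reduces across lanes only once after the loop. The name `channel_stats_lanes` is hypothetical, and this is not necessarily what the reviewer had in mind. Note also that the sum-of-squares formula is less numerically robust than the running-mean update in the diff.

```python
def channel_stats_lanes(values, block):
    """One pass over `values`; per-lane partial sums, reduced once at the end."""
    lane_sum = [0.0] * block
    lane_sq = [0.0] * block
    for start in range(0, len(values), block):
        tile = values[start:start + block]
        # Elementwise update: each lane is independent of the others, so a
        # kernel can vectorize this step. Lanes past the end add nothing.
        for lane, v in enumerate(tile):
            lane_sum[lane] += v
            lane_sq[lane] += v * v
    n = len(values)
    mean = sum(lane_sum) / n              # single cross-lane reduction
    var = sum(lane_sq) / n - mean * mean  # biased variance, E[x^2] - E[x]^2
    return mean, var


lane_mean, lane_var = channel_stats_lanes([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], block=4)
```

The loop body has no dependency between lanes, so the only serial step left is the final reduction.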
src/flag_gems/ops/batch_norm.py
BLOCK_SIZE_SPATIAL, spatial_dim - block_ind * BLOCK_SIZE_SPATIAL
)
curr_count = spatial_count * batch_dim
count += curr_count
We can set a reasonable default value for tl.load to avoid keeping a counter.
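A rough plain-Python sketch of this suggestion follows. `tl.load` accepts `mask` and `other` arguments, and `masked_load` below is a hypothetical stand-in for that behaviour: out-of-range lanes read as `other`. When the padding value is 0, padded lanes add nothing to a sum, so the element count is simply `batch_dim * spatial_dim` and does not need to be accumulated per tile.

```python
def masked_load(buf, start, block, other=0.0):
    """Stand-in for a masked load: lanes past the end of `buf` yield `other`."""
    return [buf[i] if i < len(buf) else other
            for i in range(start, start + block)]


def mean_without_counter(rows, block):
    """Mean over rows[batch][spatial] using zero-padded tiles and a fixed count."""
    batch_dim = len(rows)
    spatial_dim = len(rows[0])
    total = 0.0
    for s0 in range(0, spatial_dim, block):
        for row in rows:
            # Padded lanes are 0.0, so the last partial tile needs no special case.
            total += sum(masked_load(row, s0, block))
    # The count is known up front; no running counter is required.
    return total / (batch_dim * spatial_dim)


padded_mean = mean_without_counter([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], block=2)
```

This only works directly for plain sums. For centered terms such as `(x - mean)`, a zero-padded lane would contribute `mean ** 2`, so those lanes still have to be masked out, as the `0.0` branch in the diff above appears to do.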
src/flag_gems/ops/batch_norm.py
)

if affine:
    weight_grad += tl.sum(curr_pre_lin * curr_output_grad)
ditto
Please fix the bug in the accuracy test :)
LG
* add batch_norm ops
* add batch_norm ops
* add batch_norm forward
* small fix
* add batch_norm ops
* add batch_norm ops
* add unit test
* add batch_norm ops
* add batch_norm perf
* add batch_norm perf
* add note
* add libentry
* fix rsqrt ci error
* update unit tests
* update unit test
* update perf and tune config
* Resolved a portion of the review comments
* resolve runtime call error
* resolve review suggestion
PR Category
Operator
Type of Change
New Feature
Description
Implement batch_norm operator
Issue
Progress