Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Add batch_norm #362

Merged
merged 19 commits into from
Jan 9, 2025
Merged

[Operator] Add batch_norm #362

merged 19 commits into from
Jan 9, 2025

Conversation

2niuhe
Copy link
Contributor

@2niuhe 2niuhe commented Dec 13, 2024

PR Category

Operator

Type of Change

New Feature

Description

Implement batch_norm operator

  • forward
  • backward

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

@StrongSpoon StrongSpoon self-assigned this Dec 16, 2024
@StrongSpoon StrongSpoon self-requested a review December 16, 2024 05:36
@triton.heuristics(
{
"BLOCK_SIZE_BATCH": lambda args: next_power_of_2(args["batch_dim"]),
"BLOCK_SIZE_SPATIAL": BLOCK_SIZE_SPATIAL_heuristic,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are working on changing one-tile algorithm into loop tiling. please refer to max/min/lof_softmax and update the strategy. ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming the input to batch_norm has the shape (batch, channel, spatial).

The implementation of batch_norm is already based on loop tiling, where tiling occurs along the spatial dimension while fully loading the batch dimension. This approach differs from the operators mentioned, such as max, min, which only require computations along a specified dimension without involving others. In the case of batch_norm, both the batch and spatial dimensions need to be fully loaded.

My previous consideration was that introducing tiling along the batch dimension would result in a nested loop structure, which may not be meaningful when batch is not large.

You might be suggesting that applying loop tiling on the batch dimension could indeed improve performance when the batch size is large, as it would allow for more contiguous memory access in the spatial dimension. For instance, if the batch size is 16,384 and the spatial size loaded per loop iteration is only 1, this could lead to inefficient memory access patterns. I will try implementing loop tiling on both the batch and spatial dimensions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it

benchmark/test_norm_perf.py Show resolved Hide resolved
@2niuhe
Copy link
Contributor Author

2niuhe commented Dec 16, 2024

co-author: @zhangboyue https://github.com/2niuhe/FlagGems/tree/dev_batch_norm

(curr_input - mean) * (curr_input - prev_mean),
0.0,
)
var += tl.sum(deltas)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method cannot fully utilize vectorization/tensorization and leads to sequential computation.

BLOCK_SIZE_SPATIAL, spatial_dim - block_ind * BLOCK_SIZE_SPATIAL
)
curr_count = spatial_count * batch_dim
count += curr_count
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can set reasonable default value of tl.load to avoid keeping counter.

)

if affine:
weight_grad += tl.sum(curr_pre_lin * curr_output_grad)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@StrongSpoon
Copy link
Collaborator

please fix bug in accuracy test:)

Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

tests/test_norm_ops.py Show resolved Hide resolved
@StrongSpoon StrongSpoon merged commit ac17cd1 into FlagOpen:master Jan 9, 2025
8 of 9 checks passed
Gxiandy pushed a commit to Gxiandy/FlagGems that referenced this pull request Jan 12, 2025
* add batch_norm ops

* add batch_norm ops

* add batch_norm forward

* small fix

* add batch_norm ops

* add batch_norm ops

* add unit test

* add batch_norm ops

* add batch_norm perf

* add batch_norm perf

* add note

* add libentry

* fix rsqrt ci error

* update unit tests

* update unit test

* update perf and tune config

* Resolved a portion of the review comments

* resolve runtime call error

* resolve review suggestion
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants