-
Notifications
You must be signed in to change notification settings - Fork 60
[WIP] Fused RMSNorm implementation #2205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds fused RMSNorm (Root Mean Square Normalization) support for XPU devices to match PyTorch's recent implementation. RMSNorm is a simpler normalization technique compared to LayerNorm that eliminates the mean centering step.
Key Changes:
- Adds forward and backward RMSNorm kernel registrations in the native functions YAML
- Refactors existing LayerNorm kernels to support both LayerNorm and RMSNorm via a template parameter
- Implements RMSNorm-specific computation paths that skip mean centering
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
yaml/native/native_functions.yaml | Registers _fused_rms_norm and _fused_rms_norm_backward functions with XPU dispatch |
src/ATen/native/xpu/sycl/LayerNormKernels.cpp | Adds rms_norm template parameter to kernel functors and implements RMSNorm computation logic |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
mean_[i] = m1; | ||
rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_); | ||
} else { | ||
rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_); |
Copilot
AI
Oct 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The RMSNorm formula appears incorrect. For RMSNorm, m1
(mean) should be zero since we skip mean computation. The formula should be rsqrt(m2 + eps_)
where m2
represents the mean of squares. The term m1 * m1
should not be added.
rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_); | |
rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_); |
Copilot uses AI. Check for mistakes.
static_cast<T_ACC>(rstd_[i]) * gamma_v + | ||
beta_v; | ||
} else { | ||
Y_[index] = (static_cast<T_ACC>(X_[index])) * |
Copilot
AI
Oct 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Remove the unnecessary inner parentheses around static_cast<T_ACC>(X_[index])
. The expression can be simplified to Y_[index] = static_cast<T_ACC>(X_[index]) * static_cast<T_ACC>(rstd_[i]) * gamma_v;
Y_[index] = (static_cast<T_ACC>(X_[index])) * | |
Y_[index] = static_cast<T_ACC>(X_[index]) * |
Copilot uses AI. Check for mistakes.
U new_mean = curr_sum.mean + delta * (1.f / new_count); | ||
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; | ||
} else { | ||
return {0.f, curr_sum.sigma2 + val * val, 0}; |
Copilot
AI
Oct 22, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return statement uses integer literal 0
for the mean and count fields, but the struct fields are of type float
. For consistency and clarity, use 0.f
for all three fields: return {0.f, curr_sum.sigma2 + val * val, 0.f};
return {0.f, curr_sum.sigma2 + val * val, 0}; | |
return {0.f, curr_sum.sigma2 + val * val, 0.f}; |
Copilot uses AI. Check for mistakes.
Motivation
Fix #1905.
Refer to pytorch/pytorch#153666, add fused RMSNorm support on XPU.