@guangyey
Contributor

Motivation

Fix #1905.
Following pytorch/pytorch#153666, add fused RMSNorm support on XPU.

Copilot AI review requested due to automatic review settings, October 22, 2025 09:46
@guangyey marked this pull request as draft, October 22, 2025 09:46
@guangyey changed the title from "Fused RMSNorm implementation" to "[WIP] Fused RMSNorm implementation", Oct 22, 2025

Copilot AI left a comment

Pull Request Overview

This PR adds fused RMSNorm (Root Mean Square Normalization) support for XPU devices, matching PyTorch's recent implementation. RMSNorm is simpler than LayerNorm: it omits the mean-centering step.

Key Changes:

  • Adds forward and backward RMSNorm kernel registrations in the native functions YAML
  • Refactors existing LayerNorm kernels to support both LayerNorm and RMSNorm via a template parameter
  • Implements RMSNorm-specific computation paths that skip mean centering
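The template-parameter approach described in the list above can be illustrated with a small host-side sketch. This is not the SYCL kernel from the PR: `normalize_row` and its parameters are hypothetical names chosen for illustration, and the code only assumes the standard definitions of LayerNorm and RMSNorm.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference, not the SYCL kernel from this PR.
// rms_norm == false: LayerNorm (center by the mean, scale by 1/std, add beta).
// rms_norm == true:  RMSNorm (no centering, scale by 1/RMS, no beta).
template <bool rms_norm>
std::vector<float> normalize_row(
    const std::vector<float>& x,
    const std::vector<float>& gamma,
    const std::vector<float>& beta,
    float eps) {
  const float n = static_cast<float>(x.size());
  float sum = 0.f;
  float sum_sq = 0.f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float mean = sum / n;
  const float mean_sq = sum_sq / n;

  std::vector<float> y(x.size());
  // rms_norm is a compile-time constant, so the unused branch is dead code.
  if (!rms_norm) {
    const float var = mean_sq - mean * mean;  // biased variance
    const float rstd = 1.f / std::sqrt(var + eps);
    for (std::size_t i = 0; i < x.size(); ++i) {
      y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
    }
  } else {
    const float rstd = 1.f / std::sqrt(mean_sq + eps);
    for (std::size_t i = 0; i < x.size(); ++i) {
      y[i] = x[i] * rstd * gamma[i];
    }
  }
  return y;
}
```

Calling `normalize_row<false>(x, gamma, beta, eps)` gives the LayerNorm result and `normalize_row<true>(...)` the RMSNorm result; sharing one function body is the design choice the refactor appears to follow.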

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| yaml/native/native_functions.yaml | Registers _fused_rms_norm and _fused_rms_norm_backward functions with XPU dispatch |
| src/ATen/native/xpu/sycl/LayerNormKernels.cpp | Adds rms_norm template parameter to kernel functors and implements RMSNorm computation logic |

mean_[i] = m1;
rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
} else {
rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_);

Copilot AI Oct 22, 2025

The RMSNorm formula appears incorrect. For RMSNorm, m1 (mean) should be zero since we skip mean computation. The formula should be rsqrt(m2 + eps_) where m2 represents the mean of squares. The term m1 * m1 should not be added.

Suggested change
- rstd_[i] = c10::xpu::compat::rsqrt(m2 + m1 * m1 + eps_);
+ rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_);
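Whether the `m1 * m1` term belongs in the RMSNorm branch depends on what `m2` holds at that point, which the excerpt does not show. The following host-side sketch, using hypothetical names (`Moments`, `row_moments`) that do not come from the PR, checks the identity E[x^2] = Var[x] + E[x]^2. If `m2` is a variance, `m2 + m1 * m1` is the mean of squares; if `m2` is already the mean of squares, the extra term only leaves the result unchanged when `m1` is zero.

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper types, not taken from the PR.
struct Moments {
  float mean;     // E[x]
  float var;      // E[(x - mean)^2], biased variance
  float mean_sq;  // E[x^2]
};

// Computes the three moments of one row directly from their definitions.
Moments row_moments(const std::vector<float>& x) {
  const float n = static_cast<float>(x.size());
  float sum = 0.f;
  float sum_sq = 0.f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float mean = sum / n;
  float var = 0.f;
  for (float v : x) {
    var += (v - mean) * (v - mean);
  }
  var /= n;
  return {mean, var, sum_sq / n};
}

// RMSNorm scale computed from the mean of squares.
float rstd_from_mean_sq(float mean_sq, float eps) {
  return 1.f / std::sqrt(mean_sq + eps);
}

// The same scale recovered from variance and mean,
// using E[x^2] = Var[x] + E[x]^2.
float rstd_from_var_and_mean(float var, float mean, float eps) {
  return 1.f / std::sqrt(var + mean * mean + eps);
}
```

For any row, the two `rstd_*` helpers agree up to rounding, so the form of the expression alone does not settle which version is right; the definition of `m2` in the surrounding kernel does.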

Copilot uses AI. Check for mistakes.

static_cast<T_ACC>(rstd_[i]) * gamma_v +
beta_v;
} else {
Y_[index] = (static_cast<T_ACC>(X_[index])) *

Copilot AI Oct 22, 2025

[nitpick] Remove the unnecessary inner parentheses around static_cast<T_ACC>(X_[index]). The expression can be simplified to Y_[index] = static_cast<T_ACC>(X_[index]) * static_cast<T_ACC>(rstd_[i]) * gamma_v;

Suggested change
- Y_[index] = (static_cast<T_ACC>(X_[index])) *
+ Y_[index] = static_cast<T_ACC>(X_[index]) *

U new_mean = curr_sum.mean + delta * (1.f / new_count);
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
} else {
return {0.f, curr_sum.sigma2 + val * val, 0};

Copilot AI Oct 22, 2025

The return statement uses the integer literal 0 for the count field, but the struct fields are of type float. For consistency and clarity, use 0.f for all three fields: return {0.f, curr_sum.sigma2 + val * val, 0.f};

Suggested change
- return {0.f, curr_sum.sigma2 + val * val, 0};
+ return {0.f, curr_sum.sigma2 + val * val, 0.f};

Development

Successfully merging this pull request may close these issues.

implement fused rms norm

1 participant