
Conversation

qianjinqiu

PR Category

Type of Change

Description

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance


Summary of Changes

Hello @qianjinqiu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Ascend backend's capabilities by integrating a broad array of new operators and optimizing existing ones through Triton kernel implementations. The changes aim to improve computational efficiency and expand the functional scope of the backend, ensuring better support for diverse deep learning models. Additionally, the update includes refinements to performance heuristics and standardization of debug logging for improved maintainability and clarity.

Highlights

  • Expanded Operator Coverage: A significant number of new operators have been added to the Ascend backend, including rms_norm, cat, cumsum, diag, embedding, flip, full, gather, index, linspace, masked_fill, multinomial, ones, outer, polar, randperm, repeat_interleave, resolve_neg, select_scatter, sort, threshold, unique, var_mean, vector_norm, vstack, where, and zeros variants.
  • New Fused Operators: Three new fused operators, apply_rotary_pos_emb, fused_add_rms_norm, and skip_layer_norm, have been introduced to enhance performance by combining common operations into single kernels.
  • Triton Kernel Implementations: All newly added and many existing operators now leverage Triton kernels for optimized execution on Ascend hardware, demonstrating a shift towards more efficient low-level implementations.
  • Heuristics Configuration Updates: Performance heuristics for various operators like argmin, mm, rand, randn, uniform, and batch_norm have been refined, and num_warps and num_stages parameters have been removed from the tuning configurations for many operators, simplifying the tuning process.
  • Debug Logging Standardization: Debug log messages across several operators (e.g., cross_entropy_loss, addmm, all, amax, fill, max, min, mm, pow, triu) have been standardized to include 'GEMS_ASCEND' for clearer backend identification.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature                Command              Description
Code Review            /gemini review       Performs a code review for the current pull request in its current state.
Pull Request Summary   /gemini summary      Provides a summary of the current pull request in its current state.
Comment                @gemini-code-assist  Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                   /gemini help         Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant number of new operators and fused kernels for the Ascend backend. While this is a great step towards feature parity, the implementation of several new Triton kernels contains critical correctness issues, such as using invalid syntax (Python loops inside JIT functions), which will prevent them from compiling. Additionally, there are multiple instances of severe performance anti-patterns, including launching kernels within loops and performing redundant multi-pass computations over the same data. I have also identified several maintainability concerns like dead code, non-English comments, and confusing logic. Addressing these issues is crucial for the stability, performance, and long-term health of the Ascend backend.

Comment on lines +45 to +66
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    pid = tl.program_id(0).to(tl.int64)
    base_offset = pid * BLOCK_SIZE

    # Initialize the accumulator
    total_sum = 0.0

    # Compute the total number of sub-blocks to process
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB

    # Process each sub-block in a loop
    for sub_block_idx in range(num_sub_blocks):
        sub_offset = base_offset + sub_block_idx * BLOCK_SIZE_SUB
        offsets = sub_offset + tl.arange(0, BLOCK_SIZE_SUB)
        mask = offsets < M

        x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
        sub_sum = tl.sum(x * x)  # sum of squares for the current sub-block
        total_sum += sub_sum  # accumulate into the running total

    # Store the accumulated sum of squares into the Mid array
    tl.store(Mid + pid, total_sum)

critical

The kernel l2_norm_kernel_1 (and similarly l2_norm_kernel_2) incorrectly uses a Python for loop (for sub_block_idx in range(num_sub_blocks):) inside a @triton.jit function. This is invalid syntax because the loop range num_sub_blocks is not a compile-time constant. This will cause a compilation error. The reduction logic needs to be rewritten using a single Triton-native loop over the data elements, which is the standard and correct way to perform reductions in Triton.

@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0).to(tl.int64)
    
    total_sum = 0.0
    # Loop over the data in chunks, not sub-blocks
    for off in range(0, BLOCK_SIZE, 512): # Use a constexpr for sub-block size
        offsets = pid * BLOCK_SIZE + off + tl.arange(0, 512)
        mask = offsets < M
        x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
        total_sum += tl.sum(x * x)
    
    tl.store(Mid + pid, total_sum)

Comment on lines +15 to +32
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE

    # Compute the total number of sub-blocks to process
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB
    num_sub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_SIZE_SUB)

    # Process each sub-block in a loop
    for sub_block_idx in range(num_sub_blocks):
        sub_offset = block_start + sub_block_idx * BLOCK_SIZE_SUB + tl.arange(0, BLOCK_SIZE_SUB)
        mask = sub_offset < n_elements
        tl.store(output_ptr + sub_offset, 0.0, mask=mask)

critical

The zeros_kernel uses a Python for sub_block_idx in range(num_sub_blocks): loop within a JIT-compiled function. This is invalid because num_sub_blocks is not a compile-time constant, and Python loops cannot be used this way inside Triton kernels. This will result in a compilation failure. The kernel should be simplified to a standard grid-stride loop pattern without the incorrect inner Python loop.

Suggested change
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    # Compute the total number of sub-blocks to process
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB
    num_sub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_SIZE_SUB)
    # Process each sub-block in a loop
    for sub_block_idx in range(num_sub_blocks):
        sub_offset = block_start + sub_block_idx * BLOCK_SIZE_SUB + tl.arange(0, BLOCK_SIZE_SUB)
        mask = sub_offset < n_elements
        tl.store(output_ptr + sub_offset, 0.0, mask=mask)
@triton.jit
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(output_ptr + offsets, 0.0, mask=mask)

Comment on lines +62 to +122
def skip_layer_norm_kernel_1(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    loops = tl.cdiv(N, BLOCK_SIZE)
    sum_x = tl.zeros((), dtype=tl.float32)  # explicitly float32
    X += pid * x_stride_r
    R += pid * r_stride_r
    Y += pid * y_stride_r
    # This tiling is special: the whole row must be loaded before computing, so N is passed in
    # Read the full row in blocks
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        sum_x += tl.sum(x, axis=0)
    # Compute the overall mean
    mean = sum_x / N
    sum_var = tl.zeros((), dtype=tl.float32)  # explicitly float32
    # Compute the variance in blocks
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        # Compute variance
        _var = tl.where(mask, x - mean, 0.0)
        _var = _var * _var
        sum_var += tl.sum(_var, axis=0)
    # Compute the overall standard deviation
    var = sum_var / N
    rstd = 1 / tl.sqrt(var + eps)
    # Compute the result in blocks
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        x_hat = (x - mean) * rstd
        y = w * x_hat + b
        y = y.to(Y.dtype.element_ty)
        tl.store(Y + cols * y_stride_c, y, mask=mask)

high

This implementation has several issues:

  1. High Inefficiency: The skip_layer_norm_kernel_1 kernel performs three separate passes over the input data to compute the mean, variance, and final output. In each pass, it re-loads tensors X and R and re-computes their sum. This is highly inefficient and can be optimized into a single pass using Welford's algorithm or a more efficient two-pass approach (see the sketch after this list).
  2. Maintainability: The kernel contains comments in Chinese (e.g., line 80, 84, 93). For consistency and broader accessibility, all code and comments should be in English.
  3. Dead Code: The skip_layer_norm_kernel function is defined but never used. It should be removed to improve code clarity.
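
For reference, here is a minimal sketch of the kind of two-pass fusion the comment suggests. It is illustrative only, not the PR's code: the kernel name skip_layer_norm_two_pass_kernel is hypothetical, the signature simply mirrors the quoted kernel, and the tle/tl imports are assumed to match the surrounding file. The first pass accumulates both sum(x) and sum(x*x) from a single read of X and R, so the mean and variance need no second read; the second pass re-reads the data once to write Y.

@triton.jit
def skip_layer_norm_two_pass_kernel(
    Y, X, R, W, B,
    y_stride_r, y_stride_c,
    x_stride_r, x_stride_c,
    r_stride_r, r_stride_c,
    N, eps,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    X += pid * x_stride_r
    R += pid * r_stride_r
    Y += pid * y_stride_r

    # Pass 1: accumulate sum and sum of squares of (x + r) in one read.
    sum_x = tl.zeros((), dtype=tl.float32)
    sum_xx = tl.zeros((), dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        sum_x += tl.sum(x, axis=0)
        sum_xx += tl.sum(x * x, axis=0)

    mean = sum_x / N
    var = sum_xx / N - mean * mean  # E[x^2] - E[x]^2
    rstd = 1.0 / tl.sqrt(var + eps)

    # Pass 2: normalize, scale and shift, then store.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x_hat = ((x + r) - mean) * rstd
        y = w * x_hat + b
        tl.store(Y + cols * y_stride_c, y.to(Y.dtype.element_ty), mask=mask)

Computing the variance as E[x^2] - mean^2 is slightly less numerically stable than Welford's algorithm but removes the third pass; Welford's would be preferable if precision at large N is a concern.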

Comment on lines +121 to +143
if rotary_interleaved:
    for d in range(0, BLOCK_D // 2):
        dim_range_x = d * 2
        dim_range_y = d * 2 + 1

        rotary_embedding_rw_kernel(
            state_out,
            state,
            cos,
            sin,
            stride_state_n,
            stride_state_h,
            stride_state_d,
            stride_cos_n,
            stride_cos_d,
            num_tokens,
            num_heads,
            token_range,
            head_range,
            dim_range_x,
            dim_range_y,
            rotary_interleaved,
        )

high

The implementation for the rotary_interleaved=True case in rotary_embedding_siso_kernel is a major performance bottleneck. It uses a Python for loop to launch a separate kernel (rotary_embedding_rw_kernel) for each pair of dimensions. Launching kernels in a loop is a significant anti-pattern in Triton that leads to very poor performance due to high launch overhead. This logic should be vectorized to happen within a single kernel launch, similar to how the non-interleaved case is handled.
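
For illustration, a minimal sketch of what a single-launch interleaved variant could look like, under assumptions not taken from the PR: the kernel name and signature are hypothetical, each program handles one token, the head dimension equals BLOCK_D (a power of two, so BLOCK_D // 2 is a valid tl.arange size), and the output shares the input's strides. Each (even, odd) dimension pair is rotated with vectorized offsets instead of a per-pair launch.

@triton.jit
def rotary_embedding_interleaved_kernel(
    out_ptr, state_ptr, cos_ptr, sin_ptr,
    stride_state_n, stride_state_h, stride_state_d,
    stride_cos_n, stride_cos_d,
    num_heads,
    BLOCK_D: tl.constexpr,  # head dim, assumed to be a power of two
):
    token = tl.program_id(0)
    half = BLOCK_D // 2
    pair = tl.arange(0, half)  # rotary pair index i
    dim_even = pair * 2        # dims 0, 2, 4, ...
    dim_odd = pair * 2 + 1     # dims 1, 3, 5, ...

    cos = tl.load(cos_ptr + token * stride_cos_n + pair * stride_cos_d).to(tl.float32)
    sin = tl.load(sin_ptr + token * stride_cos_n + pair * stride_cos_d).to(tl.float32)

    for h in range(0, num_heads):
        base = token * stride_state_n + h * stride_state_h
        x_even = tl.load(state_ptr + base + dim_even * stride_state_d).to(tl.float32)
        x_odd = tl.load(state_ptr + base + dim_odd * stride_state_d).to(tl.float32)
        # Rotate each (even, odd) pair by its per-pair angle.
        o_even = x_even * cos - x_odd * sin
        o_odd = x_even * sin + x_odd * cos
        tl.store(out_ptr + base + dim_even * stride_state_d, o_even.to(out_ptr.dtype.element_ty))
        tl.store(out_ptr + base + dim_odd * stride_state_d, o_odd.to(out_ptr.dtype.element_ty))

The same masking and dtype handling as the non-interleaved path would still apply; the point is only that the per-pair Python loop of launches can move inside the kernel as vectorized offsets.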

try:
    import torch_npu  # noqa: F401

    atan2 = tl_extra_shim.atan2
except:  # noqa: E722

medium

The use of a bare except: clause is a bad practice as it can catch unexpected exceptions (like KeyboardInterrupt or SystemExit), making the program difficult to debug or terminate gracefully. The exception should be more specific.

Suggested change
except: # noqa: E722
except ImportError: # noqa: E722

Comment on lines +63 to +64
temp = tl.full((BLOCK_SIZE,), (pid + 1001) // (pid + 1000), tl.int64)
inp_offset = base + cur_index * inp_dim_stride * temp

medium

The temp variable is calculated using (pid + 1001) // (pid + 1000), which will always evaluate to 1 since pid is a non-negative integer. The subsequent multiplication by temp is redundant and makes the code unnecessarily confusing. This variable should be removed for clarity.

Suggested change
temp = tl.full((BLOCK_SIZE,), (pid + 1001) // (pid + 1000), tl.int64)
inp_offset = base + cur_index * inp_dim_stride * temp
inp_offset = base + cur_index * inp_dim_stride
