[Ascend]Adapt to Most Operators #1002
base: master
Conversation
Summary of Changes

Hello @qianjinqiu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Ascend backend's capabilities by integrating a broad array of new operators and optimizing existing ones through Triton kernel implementations. The changes aim to improve computational efficiency and expand the functional scope of the backend, ensuring better support for diverse deep learning models. Additionally, the update includes refinements to performance heuristics and standardization of debug logging for improved maintainability and clarity.
Code Review
This pull request introduces a significant number of new operators and fused kernels for the Ascend backend. While this is a great step towards feature parity, the implementation of several new Triton kernels contains critical correctness issues, such as using invalid syntax (Python loops inside JIT functions), which will prevent them from compiling. Additionally, there are multiple instances of severe performance anti-patterns, including launching kernels within loops and performing redundant multi-pass computations over the same data. I have also identified several maintainability concerns like dead code, non-English comments, and confusing logic. Addressing these issues is crucial for the stability, performance, and long-term health of the Ascend backend.
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr):
    pid = tl.program_id(0).to(tl.int64)
    base_offset = pid * BLOCK_SIZE

    # Initialize the accumulator
    total_sum = 0.0

    # Compute the total number of sub-blocks to process
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB

    # Loop over each sub-block
    for sub_block_idx in range(num_sub_blocks):
        sub_offset = base_offset + sub_block_idx * BLOCK_SIZE_SUB
        offsets = sub_offset + tl.arange(0, BLOCK_SIZE_SUB)
        mask = offsets < M

        x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
        sub_sum = tl.sum(x * x)  # squared sum of the current sub-block
        total_sum += sub_sum  # accumulate into the running total

    # Store the accumulated squared sum into the Mid array
    tl.store(Mid + pid, total_sum)
The kernel l2_norm_kernel_1 (and similarly l2_norm_kernel_2) incorrectly uses a Python for loop (for sub_block_idx in range(num_sub_blocks):) inside a @triton.jit function. This is invalid syntax because the loop range num_sub_blocks is not a compile-time constant, and it will cause a compilation error. The reduction logic needs to be rewritten using a single Triton-native loop over the data elements, which is the standard and correct way to perform reductions in Triton.
@triton.jit
def l2_norm_kernel_1(X, Mid, M, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0).to(tl.int64)
total_sum = 0.0
# Loop over the data in chunks, not sub-blocks
for off in range(0, BLOCK_SIZE, 512): # Use a constexpr for sub-block size
offsets = pid * BLOCK_SIZE + off + tl.arange(0, 512)
mask = offsets < M
x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32)
total_sum += tl.sum(x * x)
tl.store(Mid + pid, total_sum)
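For context, the two-stage reduction this kernel implements (each program writes one partial sum into Mid, and a second stage combines them) can be modeled on the host in NumPy. This is an illustrative sketch only; the function name and block size are hypothetical, not part of the PR:

```python
import numpy as np

def l2_norm_two_stage(x, block_size):
    """Host-side model of a two-stage L2-norm reduction.

    Stage 1: each "program" reduces its block_size chunk to one partial
    squared sum (what l2_norm_kernel_1 writes into Mid).
    Stage 2: a second kernel (or the host) reduces the partial sums and
    takes the square root.
    """
    m = x.size
    num_blocks = -(-m // block_size)  # ceil division: one program per block
    mid = np.zeros(num_blocks, dtype=np.float64)
    for pid in range(num_blocks):
        chunk = x[pid * block_size:(pid + 1) * block_size].astype(np.float64)
        mid[pid] = np.sum(chunk * chunk)
    return np.sqrt(mid.sum())

x = np.arange(10, dtype=np.float32)
print(l2_norm_two_stage(x, block_size=4))  # matches np.linalg.norm(x)
```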
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE

    # Compute the total number of sub-blocks to process
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB
    num_sub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_SIZE_SUB)

    # Loop over each sub-block
    for sub_block_idx in range(num_sub_blocks):
        sub_offset = block_start + sub_block_idx * BLOCK_SIZE_SUB + tl.arange(0, BLOCK_SIZE_SUB)
        mask = sub_offset < n_elements
        tl.store(output_ptr + sub_offset, 0.0, mask=mask)
The zeros_kernel uses a Python for sub_block_idx in range(num_sub_blocks): loop within a JIT-compiled function. This is invalid because num_sub_blocks is not a compile-time constant, and Python loops cannot be used this way inside Triton kernels. This will result in a compilation failure. The kernel should be simplified to a standard grid-stride loop pattern without the incorrect inner Python loop.
Suggested replacement:

@triton.jit
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(output_ptr + offsets, 0.0, mask=mask)
def skip_layer_norm_kernel_1(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights
    B,  # pointer to the biases
    y_stride_r,
    y_stride_c,
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    loops = tl.cdiv(N, BLOCK_SIZE)
    sum_x = tl.zeros((), dtype=tl.float32)  # explicitly float32
    X += pid * x_stride_r
    R += pid * r_stride_r
    Y += pid * y_stride_r
    # This tiling is special: the whole row must be loaded before computing,
    # which is why N is passed in.
    # Read the full dimension in blocks
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        sum_x += tl.sum(x, axis=0)
    # Compute the overall mean
    mean = sum_x / N
    sum_var = tl.zeros((), dtype=tl.float32)  # explicitly float32
    # Compute the variance in blocks
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        # Compute variance
        _var = tl.where(mask, x - mean, 0.0)
        _var = _var * _var
        sum_var += tl.sum(_var, axis=0)
    # Compute the overall reciprocal standard deviation
    var = sum_var / N
    rstd = 1 / tl.sqrt(var + eps)
    # Compute the result in blocks
    for process in range(loops):
        cols = process * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x += r
        x_hat = (x - mean) * rstd
        y = w * x_hat + b
        y = y.to(Y.dtype.element_ty)
        tl.store(Y + cols * y_stride_c, y, mask=mask)
This implementation has several issues:
- High inefficiency: the skip_layer_norm_kernel_1 kernel performs three separate passes over the input data to compute the mean, variance, and final output. In each pass, it re-loads tensors X and R and re-computes their sum. This is highly inefficient and can be optimized into a single pass using Welford's algorithm or a more efficient two-pass approach.
- Maintainability: the kernel contains comments in Chinese (e.g., lines 80, 84, 93 of the diff). For consistency and broader accessibility, all code and comments should be in English.
- Dead code: the skip_layer_norm_kernel function is defined but never used. It should be removed to improve code clarity.
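For reference, the single-pass mean/variance computation that Welford's algorithm provides can be sketched in plain Python. This is a host-side illustration of the numerical technique, not the Triton kernel itself:

```python
import math

def welford(xs):
    """One pass over xs, returning (mean, variance) without re-reading data."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # second factor uses the updated mean
    return mean, m2 / count  # population variance, as layer norm uses

mean, var = welford([1.0, 2.0, 3.0, 4.0])
# Downstream, the kernel would compute rstd = 1 / sqrt(var + eps)
# and then y = w * (x - mean) * rstd + b in a final pass.
rstd = 1 / math.sqrt(var + 1e-5)
```

In a kernel, the same update runs per block with a vectorized merge step, so X and R are loaded once for the statistics instead of twice.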
if rotary_interleaved:
    for d in range(0, BLOCK_D // 2):
        dim_range_x = d * 2
        dim_range_y = d * 2 + 1

        rotary_embedding_rw_kernel(
            state_out,
            state,
            cos,
            sin,
            stride_state_n,
            stride_state_h,
            stride_state_d,
            stride_cos_n,
            stride_cos_d,
            num_tokens,
            num_heads,
            token_range,
            head_range,
            dim_range_x,
            dim_range_y,
            rotary_interleaved,
        )
The implementation for the rotary_interleaved=True case in rotary_embedding_siso_kernel is a major performance bottleneck. It uses a Python for loop to launch a separate kernel (rotary_embedding_rw_kernel) for each pair of dimensions. Launching kernels in a loop is a significant anti-pattern in Triton that leads to very poor performance due to high launch overhead. This logic should be vectorized to happen within a single kernel launch, similar to how the non-interleaved case is handled.
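For context, the interleaved rotation can be applied to all dimension pairs at once rather than one pair per launch. A NumPy sketch of the vectorized form (function name and shapes are illustrative, not the kernel's actual signature):

```python
import numpy as np

def rope_interleaved(state, cos, sin):
    """Rotate all (even, odd) dimension pairs of `state` in one step.

    state: (..., head_dim); cos/sin: (..., head_dim // 2).
    Strided slicing replaces the per-pair loop: every pair is
    rotated by its own angle in a single vectorized pass.
    """
    x = state[..., 0::2]  # even-indexed dims
    y = state[..., 1::2]  # odd-indexed dims
    out = np.empty_like(state)
    out[..., 0::2] = x * cos - y * sin
    out[..., 1::2] = x * sin + y * cos
    return out

state = np.ones((2, 4))          # 2 tokens, head_dim = 4
cos = np.zeros((2, 2))           # cos(90 degrees)
sin = np.ones((2, 2))            # sin(90 degrees)
rotated = rope_interleaved(state, cos, sin)
print(rotated)  # each (1, 1) pair rotates to (-1, 1)
```

In a Triton kernel the same idea maps to loading the even and odd strided views with tl.arange-based offsets inside one launch.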
try:
    import torch_npu  # noqa: F401
    atan2 = tl_extra_shim.atan2
except:  # noqa: E722
The use of a bare except: clause is a bad practice, as it can catch unexpected exceptions (like KeyboardInterrupt or SystemExit), making the program difficult to debug or terminate gracefully. The exception should be more specific.
Suggested change:
- except:  # noqa: E722
+ except ImportError:
temp = tl.full((BLOCK_SIZE,), (pid + 1001) // (pid + 1000), tl.int64)
inp_offset = base + cur_index * inp_dim_stride * temp
The temp variable is calculated using (pid + 1001) // (pid + 1000), which will always evaluate to 1 since pid is a non-negative integer. The subsequent multiplication by temp is redundant and makes the code unnecessarily confusing. This variable should be removed for clarity.
Suggested change:
- temp = tl.full((BLOCK_SIZE,), (pid + 1001) // (pid + 1000), tl.int64)
- inp_offset = base + cur_index * inp_dim_stride * temp
+ inp_offset = base + cur_index * inp_dim_stride
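As a quick check of the claim: the numerator exceeds the denominator by exactly 1 and the denominator is at least 1000, so the floor quotient is 1 for every non-negative pid:

```python
# (pid + 1001) // (pid + 1000) floors to 1 for all pid >= 0,
# so multiplying by `temp` never changes inp_offset.
assert all((pid + 1001) // (pid + 1000) == 1 for pid in range(100000))
print("temp is always 1")
```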
PR Category
Type of Change
Description
Issue
Progress
Performance