
Commit 298d4fd

[Ascend]Adapt to Most Operators
1 parent d6c1e88 commit 298d4fd


62 files changed, +6589 -349 lines

src/flag_gems/runtime/backend/_ascend/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 vendor_info = VendorInfoBase(
     vendor_name="ascend",
     device_name="npu",
-    triton_extra_name="ascend",
     device_query_cmd="npu-smi info",
     dispatch_key="PrivateUse1",
 )

src/flag_gems/runtime/backend/_ascend/fused/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,11 @@
 from .cross_entropy_loss import cross_entropy_loss
+from .rotary_embedding import apply_rotary_pos_emb
+from .fused_add_rms_norm import fused_add_rms_norm
+from .skip_layernorm import skip_layer_norm

 __all__ = [
     "cross_entropy_loss",
+    "apply_rotary_pos_emb",
+    "fused_add_rms_norm",
+    "skip_layer_norm",
 ]
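
For reference, a quick sketch (not part of the commit) of what the updated export list provides, assuming the backend package is importable on the host:

import flag_gems.runtime.backend._ascend.fused as ascend_fused

# __all__ now carries the three newly wired ops next to cross_entropy_loss,
# matching the import lines added above.
print(sorted(ascend_fused.__all__))
# ['apply_rotary_pos_emb', 'cross_entropy_loss', 'fused_add_rms_norm', 'skip_layer_norm']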

src/flag_gems/runtime/backend/_ascend/fused/cross_entropy_loss.py

Lines changed: 2 additions & 2 deletions
@@ -519,7 +519,7 @@ def sum_and_scale(
 class CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
-        logger.debug("GEMS CrossEntropyLoss")
+        logger.debug("GEMS_ASCEND CrossEntropyLoss")

         shape = list(inp.shape)
         dim = inp.ndim
@@ -607,7 +607,7 @@ def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):

     @staticmethod
     def backward(ctx, out_grad):
-        logger.debug("GEMS CrossEntropyLoss VJP")
+        logger.debug("GEMS_ASCEND CrossEntropyLoss VJP")

         inp, tgt, weight = ctx.saved_tensors
         N = ctx.N

src/flag_gems/runtime/backend/_ascend/fused/fused_add_rms_norm.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import logging
+import math
+
+import triton
+import triton.language as tl
+
+from flag_gems.runtime import torch_device_fn
+from flag_gems.utils import libentry
+from flag_gems.utils import triton_lang_extension as tle
+
+logger = logging.getLogger(__name__)
+
+@libentry()
+@triton.jit(do_not_specialize=["eps"])
+def fused_add_rms_norm_kernel(
+    X,  # pointer to the input
+    R,  # pointer to the residual
+    W,  # pointer to the weight
+    x_stride_r,  # how much to increase the pointer when moving by 1 row
+    x_stride_c,  # how much to increase the pointer when moving by 1 col
+    r_stride_r,  # how much to increase the pointer when moving by 1 row
+    r_stride_c,  # how much to increase the pointer when moving by 1 col
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    X += pid * x_stride_r
+    R += pid * r_stride_r
+
+    _var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
+        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
+        x += r
+        _var_base += x * x / N
+    var = tl.sum(_var_base)
+
+    rrms = 1 / tl.sqrt(var + eps)
+
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
+        r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
+        x += r
+        w = tl.load(W + cols, mask, other=0.0)
+        y = (x * rrms).to(X.dtype.element_ty) * w
+        # write back to residual and input
+        tl.store(R + cols * r_stride_c, x, mask=mask)
+        tl.store(X + cols * x_stride_c, y, mask=mask)
+
+
+def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
+    """
+    This function performs fused residual addition and RMS normalization **in-place**.
+    Both `x` and `residual` tensors will be modified. Use with caution if these tensors
+    are reused elsewhere or require gradients.
+    """
+    logger.debug("GEMS_ASCEND FUSED_ADD_RMS_NORM FORWARD")
+    dim = x.ndim - len(normalized_shape)
+    M = min(math.prod(x.shape[:dim]), 65535)
+    N = math.prod(normalized_shape)
+
+    BLOCK_SIZE = min(triton.next_power_of_2(N), 8192)
+    x = x.contiguous()
+    residual = residual.contiguous()
+    weight = weight.contiguous()
+
+    with torch_device_fn.device(x.device):
+        fused_add_rms_norm_kernel[M,](
+            x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
+        )
+    return x, residual
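
A minimal usage sketch for the new in-place op (not part of the commit; the shapes, dtype, eps value, and "npu" device string are illustrative assumptions, and the direct backend import path is just the module shown above — an installed build may expose the op through the public flag_gems fused namespace instead):

import torch

from flag_gems.runtime.backend._ascend.fused import fused_add_rms_norm

# Hypothetical sizes: batch 4, sequence length 128, hidden size 1024, on the
# "npu" device declared by this backend's VendorInfoBase.
hidden = 1024
x = torch.randn(4, 128, hidden, device="npu", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(hidden, device="npu", dtype=torch.float16)

# Both inputs are overwritten, as the docstring warns:
#   residual <- x + residual
#   x        <- rms_norm(x + residual) * weight
x, residual = fused_add_rms_norm(x, residual, [hidden], weight, eps=1e-6)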

src/flag_gems/runtime/backend/_ascend/fused/rotary_embedding.py

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
+import logging
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+import flag_gems
+from flag_gems.runtime import torch_device_fn
+from flag_gems.utils import libentry
+from flag_gems.utils import triton_lang_extension as tle
+
+
+@triton.jit
+def rotary_embedding_rw_kernel(
+    state_out,
+    state,
+    cos,
+    sin,
+    stride_state_n,
+    stride_state_h,
+    stride_state_d,
+    stride_cos_n,
+    stride_cos_d,
+    num_tokens,
+    num_heads,
+    token_range,
+    head_range,
+    dim_range_x,
+    dim_range_y,
+    rotary_interleaved: tl.constexpr,
+):
+    state_x_offset = (
+        token_range[:, None, None] * stride_state_n
+        + head_range[None, :, None] * stride_state_h
+        + dim_range_x[None, None, :] * stride_state_d
+    )
+    state_y_offset = (
+        token_range[:, None, None] * stride_state_n
+        + head_range[None, :, None] * stride_state_h
+        + dim_range_y[None, None, :] * stride_state_d
+    )
+
+    cos_sim_offset = (
+        token_range[:, None, None] * stride_cos_n
+        + dim_range_x[None, None, :] * stride_cos_d
+    )
+    if rotary_interleaved:
+        sin_sim_offset = (
+            token_range[:, None, None] * stride_cos_n
+            + dim_range_y[None, None, :] * stride_cos_d
+        )
+    else:
+        sin_sim_offset = cos_sim_offset
+
+    state_x = tl.load(
+        state + state_x_offset,
+        mask=(token_range[:, None, None] < num_tokens)
+        & (head_range[None, :, None] < num_heads),
+        other=0.0,
+    )
+    state_y = tl.load(
+        state + state_y_offset,
+        mask=(token_range[:, None, None] < num_tokens)
+        & (head_range[None, :, None] < num_heads),
+        other=0.0,
+    )
+
+    cos_loaded = tl.load(
+        cos + cos_sim_offset,
+        mask=token_range[:, None, None] < num_tokens,
+        other=0.0,
+    ).to(tl.float32)
+    sin_loaded = tl.load(
+        sin + sin_sim_offset,
+        mask=token_range[:, None, None] < num_tokens,
+        other=0.0,
+    ).to(tl.float32)
+
+    out_x = state_x * cos_loaded - state_y * sin_loaded
+    out_y = state_x * sin_loaded + state_y * cos_loaded
+
+    tl.store(
+        state_out + state_x_offset,
+        out_x,
+        mask=(token_range[:, None, None] < num_tokens)
+        & (head_range[None, :, None] < num_heads),
+    )
+    tl.store(
+        state_out + state_y_offset,
+        out_y,
+        mask=(token_range[:, None, None] < num_tokens)
+        & (head_range[None, :, None] < num_heads),
+    )
+
+
+@libentry()
+@triton.jit
+def rotary_embedding_siso_kernel(
+    state_out,  # [num_tokens, head_num, head_dim]
+    state,  # [num_tokens, head_num, head_dim]
+    cos,  # [num_tokens, 1, head_dim // 2]
+    sin,  # [num_tokens, 1, head_dim // 2]
+    stride_state_n,
+    stride_state_h,
+    stride_state_d,
+    stride_cos_n,
+    stride_cos_d,
+    num_tokens,
+    num_heads,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    rotary_interleaved: tl.constexpr,
+):
+    token_index = tl.program_id(0)
+    token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    head_index = tl.program_id(1)
+    head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H)
+
+    if rotary_interleaved:
+        for d in range(0, BLOCK_D // 2):
+            dim_range_x = d * 2
+            dim_range_y = d * 2 + 1
+
+            rotary_embedding_rw_kernel(
+                state_out,
+                state,
+                cos,
+                sin,
+                stride_state_n,
+                stride_state_h,
+                stride_state_d,
+                stride_cos_n,
+                stride_cos_d,
+                num_tokens,
+                num_heads,
+                token_range,
+                head_range,
+                dim_range_x,
+                dim_range_y,
+                rotary_interleaved,
+            )
+    else:
+        dim_range_x = tl.arange(0, BLOCK_D // 2)
+        dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D)
+        rotary_embedding_rw_kernel(
+            state_out,
+            state,
+            cos,
+            sin,
+            stride_state_n,
+            stride_state_h,
+            stride_state_d,
+            stride_cos_n,
+            stride_cos_d,
+            num_tokens,
+            num_heads,
+            token_range,
+            head_range,
+            dim_range_x,
+            dim_range_y,
+            rotary_interleaved,
+        )
+
+def apply_rotary_pos_emb(
+    q,
+    k,
+    cos,
+    sin,
+    position_ids: Optional[torch.IntTensor] = None,
+    rotary_interleaved: bool = False,
+):
+    """
+    Apply rotary position embedding to q and k
+
+    Args:
+        q: (*, q_heads, head_dim)
+        k: (*, k_heads, head_dim)
+        cos: (max_seq_len, head_dim // 2)
+        sin: (max_seq_len, head_dim // 2)
+        position_ids: (*, ), optional, position ids for each token
+        rotary_interleaved: whether the head_dim is rotated in an interleaved way
+
+    Returns:
+        q_embed: (*, q_heads, head_dim)
+        k_embed: (*, k_heads, head_dim)
+    """
+    logging.debug("GEMS_ASCEND ROTARY POS EMBEDDING")
+    assert (
+        k.shape[-1] == q.shape[-1]
+    ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}"
+    assert (
+        cos.shape[-1] == sin.shape[-1]
+    ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}"
+    assert (
+        cos.shape[-1] * 2 == q.shape[-1]
+    ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}"
+    assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension"
+    assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension"
+
+    q_shape = q.shape
+    k_shape = k.shape
+
+    assert (
+        q.shape[:-2] == k.shape[:-2]
+    ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}"
+    if position_ids is None:
+        assert (
+            len(q.shape) == 4
+        ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}"
+        seq_len = q.shape[-3]
+    else:
+        assert (
+            position_ids.shape == q.shape[:-2]
+        ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}"

+        position_ids = position_ids.view(-1)
+        seq_len = None
+
+    q = q.view(-1, q.shape[-2], q.shape[-1])
+    k = k.view(-1, k.shape[-2], k.shape[-1])
+
+    q_embed = torch.empty_like(q)
+    k_embed = torch.empty_like(k)
+
+    def torch_rotary_embedding(state_out, state, cos, sin):
+        num_tokens = state.shape[0]
+        num_heads = state.shape[1]
+        head_dim = state.shape[-1]
+
+        BLOCK_N = 8
+        BLOCK_H = 4
+        grid = (
+            triton.cdiv(num_tokens, BLOCK_N),
+            triton.cdiv(num_heads, BLOCK_H),
+        )
+        with torch_device_fn.device(state_out.device):
+            with flag_gems.use_gems():
+                if position_ids is None:
+                    cos = cos[: q_shape[-3], None, :]
+                    sin = sin[: q_shape[-3], None, :]
+                else:
+                    cos = cos[position_ids, None, :]
+                    sin = sin[position_ids, None, :]
+
+                if rotary_interleaved:
+                    cos = torch.repeat_interleave(cos, 2, dim=-1)
+                    sin = torch.repeat_interleave(sin, 2, dim=-1)
+                orig_cos = cos
+                orig_sin = sin
+                for _ in range(q_shape[0] - 1):
+                    cos = torch.cat((cos, orig_cos), dim=0)
+                    sin = torch.cat((sin, orig_sin), dim=0)
+                rotary_embedding_siso_kernel[grid](
+                    state_out,
+                    state,
+                    cos,
+                    sin,
+                    state.stride(0),
+                    state.stride(1),
+                    state.stride(2),
+                    cos.stride(0),
+                    cos.stride(2),
+                    num_tokens,
+                    num_heads,
+                    BLOCK_N=BLOCK_N,
+                    BLOCK_H=BLOCK_H,
+                    BLOCK_D=head_dim,
+                    rotary_interleaved=rotary_interleaved,
+                )
+
+    torch_rotary_embedding(q_embed, q, cos, sin)
+    torch_rotary_embedding(k_embed, k, cos, sin)
+
+    q_embed = q_embed.view(q_shape)
+    k_embed = k_embed.view(k_shape)
+    return q_embed, k_embed
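
And a matching sketch for apply_rotary_pos_emb (again not part of the commit; the sizes, device string, and the cos/sin table construction are illustrative assumptions chosen to satisfy the asserts in the function above):

import torch

from flag_gems.runtime.backend._ascend.fused import apply_rotary_pos_emb

# Hypothetical sizes: batch 2, seq_len 16, 8 query heads, 2 key heads, head_dim 64.
batch, seq_len, head_dim = 2, 16, 64
q = torch.randn(batch, seq_len, 8, head_dim, device="npu", dtype=torch.float16)
k = torch.randn(batch, seq_len, 2, head_dim, device="npu", dtype=torch.float16)

# The cos/sin tables cover half the head dimension, as the asserts require.
max_seq_len = 4096
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="npu").float() / head_dim))
t = torch.arange(max_seq_len, device="npu").float()
freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
cos, sin = freqs.cos(), freqs.sin()

# Without position_ids, q and k must be 4-D; the first seq_len rows of cos/sin are used.
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, rotary_interleaved=False)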
