8 changes: 6 additions & 2 deletions src/flag_gems/runtime/backend/_ascend/__init__.py
@@ -3,12 +3,16 @@
vendor_info = VendorInfoBase(
vendor_name="ascend",
device_name="npu",
triton_extra_name="ascend",
device_query_cmd="npu-smi info",
dispatch_key="PrivateUse1",
)

CUSTOMIZED_UNUSED_OPS = ("contiguous",)
CUSTOMIZED_UNUSED_OPS = (
"contiguous",
"sort",
"sort_stable",
"topk",
)


__all__ = ["*"]
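The Ascend backend now also opts out of the sort, sort_stable, and topk customizations. Below is a purely hypothetical sketch of how a skip list like this could be consulted; uses_vendor_kernel is a made-up helper for illustration, not the actual flag_gems registration logic:

```python
# Hypothetical illustration; uses_vendor_kernel is a made-up helper, not part
# of the flag_gems runtime. It only shows how a skip list like this can be read.
CUSTOMIZED_UNUSED_OPS = (
    "contiguous",
    "sort",
    "sort_stable",
    "topk",
)

def uses_vendor_kernel(op_name: str) -> bool:
    # Ops listed here are assumed to fall back to the default implementation
    # rather than the vendor-customized one.
    return op_name not in CUSTOMIZED_UNUSED_OPS

assert not uses_vendor_kernel("topk")
assert uses_vendor_kernel("cross_entropy_loss")
```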
6 changes: 6 additions & 0 deletions src/flag_gems/runtime/backend/_ascend/fused/__init__.py
@@ -1,5 +1,11 @@
from .cross_entropy_loss import cross_entropy_loss
from .rotary_embedding import apply_rotary_pos_emb
from .fused_add_rms_norm import fused_add_rms_norm
from .skip_layernorm import skip_layer_norm

__all__ = [
"cross_entropy_loss",
"apply_rotary_pos_emb",
"fused_add_rms_norm",
"skip_layer_norm",
]
16 changes: 11 additions & 5 deletions src/flag_gems/runtime/backend/_ascend/fused/cross_entropy_loss.py
@@ -3,6 +3,7 @@
import torch
import triton
import triton.language as tl
from torch.nn import _reduction as _Reduction

from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
@@ -519,7 +520,7 @@ def sum_and_scale(
class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):
logger.debug("GEMS CrossEntropyLoss")
logger.debug("GEMS_ASCEND CrossEntropyLoss")

shape = list(inp.shape)
dim = inp.ndim
@@ -607,7 +608,7 @@ def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing):

@staticmethod
def backward(ctx, out_grad):
logger.debug("GEMS CrossEntropyLoss VJP")
logger.debug("GEMS_ASCEND CrossEntropyLoss VJP")

inp, tgt, weight = ctx.saved_tensors
N = ctx.N
@@ -651,8 +652,13 @@ def backward(ctx, out_grad):


def cross_entropy_loss(
inp, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0
inp, target, weight=None, reduction="mean", ignore_index=-100, label_smoothing=0.0
):
return CrossEntropyLoss.apply(
inp, target, weight, reduction, ignore_index, label_smoothing
)
inp,
target,
weight,
_Reduction.get_enum(reduction),
ignore_index,
label_smoothing,
)
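The wrapper now accepts the string reduction modes PyTorch users expect and converts them to the integer codes the autograd.Function uses via torch.nn._reduction.get_enum. A minimal usage sketch follows; the shapes are arbitrary and the commented-out import assumes the Ascend backend is installed:

```python
# Illustrative sketch only; shapes are arbitrary and the flag_gems import
# (commented out) assumes the Ascend backend is available.
import torch
from torch.nn import _reduction as _Reduction

# get_enum maps the string reduction modes to the integer codes the
# autograd.Function expects: "none" -> 0, "mean" -> 1, "sum" -> 2.
assert _Reduction.get_enum("mean") == 1

logits = torch.randn(8, 128)            # (batch, num_classes)
target = torch.randint(0, 128, (8,))    # class indices

# from flag_gems.runtime.backend._ascend.fused import cross_entropy_loss
# loss = cross_entropy_loss(logits, target, reduction="mean")
```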
77 changes: 77 additions & 0 deletions src/flag_gems/runtime/backend/_ascend/fused/fused_add_rms_norm.py
@@ -0,0 +1,77 @@
import logging
import math

import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger(__name__)

@libentry()
@triton.jit(do_not_specialize=["eps"])
def fused_add_rms_norm_kernel(
X, # pointer to the input
R, # pointer to the residual
W, # pointer to the weight
x_stride_r, # how much to increase the pointer when moving by 1 row
x_stride_c, # how much to increase the pointer when moving by 1 col
r_stride_r, # how much to increase the pointer when moving by 1 row
r_stride_c, # how much to increase the pointer when moving by 1 col
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
X += pid * x_stride_r
R += pid * r_stride_r

_var_base = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
x += r
_var_base += x * x / N
var = tl.sum(_var_base)

rrms = 1 / tl.sqrt(var + eps)

for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X + cols, mask, other=0.0).to(tl.float32)
r = tl.load(R + cols, mask, other=0.0).to(tl.float32)
x += r
w = tl.load(W + cols, mask, other=0.0)
y = (x * rrms).to(X.dtype.element_ty) * w
# write back to residual and input
tl.store(R + cols * r_stride_c, x, mask=mask)
tl.store(X + cols * x_stride_c, y, mask=mask)


def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
"""
This function performs fused residual addition and RMS normalization **in-place**.
Both `x` and `residual` tensors will be modified. Use with caution if these tensors
are reused elsewhere or require gradients.
"""
logger.debug("GEMS_ASCEND FUSED_ADD_RMS_NORM FORWARD")
dim = x.ndim - len(normalized_shape)
M = min(math.prod(x.shape[:dim]), 65535)
N = math.prod(normalized_shape)

BLOCK_SIZE = min(triton.next_power_of_2(N), 8192)
x = x.contiguous()
residual = residual.contiguous()
weight = weight.contiguous()

with torch_device_fn.device(x.device):
fused_add_rms_norm_kernel[M,](
x, residual, weight, N, 1, N, 1, N, eps, BLOCK_SIZE
)
return x, residual
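A brief usage sketch for the new fused op, assuming an Ascend device string of "npu" and arbitrary shapes; note that both x and residual are modified in place, as the docstring warns:

```python
# Illustrative sketch only; the "npu" device string and the shapes are
# assumptions and require the Ascend backend plus torch_npu to be installed.
import torch
from flag_gems.runtime.backend._ascend.fused import fused_add_rms_norm

hidden = 4096
x = torch.randn(4, hidden, dtype=torch.float16, device="npu")
residual = torch.randn(4, hidden, dtype=torch.float16, device="npu")
weight = torch.ones(hidden, dtype=torch.float16, device="npu")

# Both tensors are updated in place:
#   residual <- x + residual
#   x        <- rms_norm(x + residual) * weight
x, residual = fused_add_rms_norm(x, residual, (hidden,), weight, eps=1e-6)
```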