272 changes: 233 additions & 39 deletions src/flag_gems/ops/mean.py
@@ -1,5 +1,6 @@
import logging
import math
from functools import reduce

import torch
import triton
@@ -21,25 +22,43 @@ def mean_kernel_1(
M,
BLOCK_SIZE: tl.constexpr,
):
# accumulation dtype
if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
inp.dtype.element_ty == tl.bfloat16
):
cdtype = tl.float32
else:
cdtype = inp.dtype.element_ty

pid = tle.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
mid_ptr = mid + pid
tl.store(mid_ptr, sum_val)


@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
mid.dtype.element_ty == tl.bfloat16
):
cdtype = tl.float32
else:
cdtype = mid.dtype.element_ty

offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid + offset
mask = offset < MID_SIZE
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    # divide by total element count M to get mean
    mean_val = sum_val / M
    tl.store(out, mean_val)


def mean(inp, *, dtype=None):
@@ -60,57 +79,232 @@ def mean(inp, *, dtype=None):
return out
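
The global mean above is computed as a two-pass reduction: mean_kernel_1 writes one partial sum per program block into mid (accumulating fp16/bf16 inputs in fp32), and mean_kernel_2 sums those partials and divides by the total element count M. A plain-Python sketch of the same scheme, with an arbitrary block size, purely for illustration:

# Plain-Python sketch of the two-pass reduction used by mean_kernel_1/2.
# The block size here is arbitrary; the real kernels pick BLOCK_SIZE at launch.
def two_pass_mean(values, block_size=4):
    # Pass 1: one partial sum per block, as each mean_kernel_1 program stores
    # into `mid`. Summing in float mirrors the fp16/bf16 -> fp32 promotion.
    mid = [
        float(sum(values[i : i + block_size]))
        for i in range(0, len(values), block_size)
    ]
    # Pass 2: sum the partials and divide by the total element count M,
    # as mean_kernel_2 does.
    return sum(mid) / len(values)


assert abs(two_pass_mean([1, 2, 3, 4, 5]) - 3.0) < 1e-12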


@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
output_ptr,
input_ptr,
M,
N,
K,
TILE_N: tl.constexpr,
TILE_K: tl.constexpr,
ONE_TILE_PER_CTA: tl.constexpr,
):
# accumulation dtype
if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
input_ptr.dtype.element_ty == tl.bfloat16
):
cdtype = tl.float32
else:
cdtype = input_ptr.dtype.element_ty

pid_m = tle.program_id(0)
pid_k = tle.program_id(1)

k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]

if ONE_TILE_PER_CTA:
n_offsets = tl.arange(0, TILE_N)[:, None]
inp_offset = pid_m * N * K + n_offsets * K + k_offsets
mask = (n_offsets < N) & (k_offsets < K)
input_ptrs = input_ptr + inp_offset
inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
        # sum along the reduction axis (N); keep_dims keeps the [1, TILE_K] shape so it matches k_offsets
summed = tl.sum(inp, axis=0, keep_dims=True)
# divide by N to get mean
out = summed / N
out_offset = pid_m * K + k_offsets
output_ptrs = output_ptr + out_offset
tl.store(output_ptrs, out, mask=k_offsets < K)
else:
sum_tile = tl.zeros([TILE_N, TILE_K], dtype=cdtype)
for start_n in range(0, N, TILE_N):
n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
inp_offsets = pid_m * N * K + n_offsets * K + k_offsets
mask = (n_offsets < N) & (k_offsets < K)
inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
sum_tile += inp
summed = tl.sum(sum_tile, axis=0, keep_dims=True)
out = summed / N
out_offset = pid_m * K + k_offsets
output_ptrs = output_ptr + out_offset
tl.store(output_ptrs, out, mask=k_offsets < K)


@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
output_ptr,
input_ptr,
M,
N,
TILE_N: tl.constexpr,
ONE_TILE_PER_CTA: tl.constexpr,
):
if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
input_ptr.dtype.element_ty == tl.bfloat16
):
cdtype = tl.float32
else:
cdtype = input_ptr.dtype.element_ty

pid_m = tle.program_id(0)
if ONE_TILE_PER_CTA:
n_offsets = tl.arange(0, TILE_N)
inp_offset = pid_m * N + n_offsets
input_ptrs = input_ptr + inp_offset
mask = n_offsets < N
inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
summed = tl.sum(inp, axis=0)
out = summed / N
out_offset = pid_m
output_ptrs = output_ptr + out_offset
tl.store(output_ptrs, out)
else:
sum_vec = tl.zeros(
[
TILE_N,
],
dtype=cdtype,
)
for start_n in range(0, N, TILE_N):
n_offsets = start_n + tl.arange(0, TILE_N)
inp_offsets = pid_m * N + n_offsets
mask = n_offsets < N
inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
sum_vec += inp
summed = tl.sum(sum_vec, axis=0)
out = summed / N
out_offset = pid_m
output_ptrs = output_ptr + out_offset
tl.store(output_ptrs, out)


@libentry()
@libtuner(
configs=runtime.get_tuned_config("naive_reduction"),
key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
inp,
out,
M,
N,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
inp.dtype.element_ty == tl.bfloat16
):
cdtype = tl.float32
else:
cdtype = inp.dtype.element_ty

# Map the program id to the row of inp it should compute.
pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
row_mask = pid < M

# Compute mean
    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
for off in range(0, N, BLOCK_N):
cols = off + tl.arange(0, BLOCK_N)[None, :]
col_mask = cols < N
mask = row_mask and col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    summed = tl.sum(_sum, axis=1)[:, None]
    mean = summed / N
    tl.store(out, mean, row_mask)

def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
if dtype is torch.bool:
inp = inp.to(torch.int64)
dtype = torch.int64

if dim == []:
# mean over all elements
if not keepdim:
return mean(inp, dtype=dtype)
else:
dim_num = inp.ndim
return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)

shape = list(inp.shape)

# -------- normalize dim to a list of ints --------
if isinstance(dim, int):
dim = [dim]
else:
try:
dim = list(dim)
except TypeError:
raise TypeError(
f"dim must be an int, iterable of ints, or [], got {type(dim)}"
)

dim = [d % inp.ndim for d in dim]
# -------------------------------------------------

if len(dim) == 1:
dim0 = dim[0]
N = inp.shape[dim0] # reduction length
# product of dims before dim0; use initializer 1 for empty slice
M = reduce(lambda x, y: x * y, shape[:dim0], 1)
inp = inp.contiguous()
K = inp.numel() // M // N
shape[dim0] = 1
if out is None:
out = torch.empty(shape, dtype=dtype, device=inp.device)

with torch_device_fn.device(inp.device):
if K > 1:
grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
mean_dim_kernel_non_inner[grid](
out,
inp,
M,
N,
K,
)
else:
grid = (M, 1, 1)
mean_dim_kernel_inner[grid](
out,
inp,
M,
N,
)
if not keepdim:
out = out.squeeze(dim=dim0)
return out
else:
inp = dim_compress(inp, dim)
N = 1
for i in dim:
N *= shape[i]
shape[i] = 1
M = inp.numel() // N
if out is None:
out = torch.empty(shape, dtype=dtype, device=inp.device)

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
with torch_device_fn.device(inp.device):
mean_dim_kernel[grid](inp, out, M, N)
if not keepdim:
out = out.squeeze(dim=dim)
return out


def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
logger.debug("GEMS MEAN_DIM (wrapper)")
return mean_dim_comm(inp, dim, keepdim, dtype=dtype)
Comment on lines +226 to +310
Contributor

medium

For simplicity and clarity, consider combining mean_dim_comm and the mean_dim wrapper into a single mean_dim function. The current implementation with a simple wrapper adds an unnecessary layer of indirection. Exposing the out parameter in the public mean_dim function is also consistent with the PyTorch API.

def mean_dim(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    if dim == []:
        # mean over all elements
        if not keepdim:
            return mean(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)

    # -------- normalize dim to a list of ints --------
    if isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            )

    dim = [d % inp.ndim for d in dim]
    # -------------------------------------------------

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out
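
As a quick illustration of the proposed call pattern, a usage sketch follows. The import path is assumed from the file location in this diff and is not confirmed by the PR; the shapes are arbitrary.

# Usage sketch (not part of the PR). Import path is an assumption based on the
# file location src/flag_gems/ops/mean.py; tensor shapes are arbitrary examples.
import torch

from flag_gems.ops.mean import mean_dim  # assumed import path

x = torch.randn(8, 1024, 32, device="cuda", dtype=torch.float16)

# dim=1 is not the innermost axis, so K = 32 > 1 and the call takes the
# mean_dim_kernel_non_inner path.
y = mean_dim(x, dim=1, keepdim=True)  # shape (8, 1, 32), dtype float16

# dim=-1 reduces the innermost axis, so K == 1 and mean_dim_kernel_inner runs.
z = mean_dim(x, dim=-1)  # shape (8, 1024)

# More than one reduced dim falls back to dim_compress plus the generic
# mean_dim_kernel.
w = mean_dim(x, dim=[0, 2])  # shape (1024,)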

47 changes: 47 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
@@ -2,6 +2,11 @@
import triton


_MIN_TILE_N = 64
_MAX_TILE_N_PER_ROW = 4096
_MAX_ONE_TILE_N = 2048


def simple_elementwise_blocksize_heur(args):
return 1024

@@ -232,6 +237,42 @@ def vdot_heur_block_size(args):
return 1024


def mean_heur_tile_k(args):
MAX_TILE_K = 512
NUM_SMS = torch.cuda.get_device_properties(
torch.cuda.current_device()
).multi_processor_count
tile_k = 1
upper_bound = min(args["K"], MAX_TILE_K)
max_tile_k_allowed_by_tile_n = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
upper_bound = min(upper_bound, max_tile_k_allowed_by_tile_n)
while tile_k <= upper_bound:
num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
num_waves = num_blocks / NUM_SMS
if (num_waves > 1) and (tile_k * 2 <= upper_bound):
tile_k *= 2
else:
break
return tile_k


def mean_heur_tile_n_non_inner(args):
tile_k = args.get("TILE_K", 1)
limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
N = args.get("N", 1)
desired = min(max(N, _MIN_TILE_N), limit_by_k)
desired = min(desired, _MAX_ONE_TILE_N, limit_by_k)
Contributor

medium

The limit_by_k variable in this min call is redundant. The value of desired is already constrained by limit_by_k in the previous line. Removing the redundant variable will make the code clearer.

Suggested change
desired = min(desired, _MAX_ONE_TILE_N, limit_by_k)
desired = min(desired, _MAX_ONE_TILE_N)

tile_n = triton.next_power_of_2(desired)
if tile_n > limit_by_k:
tile_n = limit_by_k
tile_n = max(tile_n, _MIN_TILE_N)
return tile_n


def mean_heur_one_tile_per_cta(args):
return args["TILE_N"] >= args["N"]


HEURISTICS_CONFIGS = {
"argmax": {
"BLOCK_M": argmax_heur_block_m,
@@ -279,6 +320,12 @@ def vdot_heur_block_size(args):
"ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
"num_warps": softmax_heur_num_warps_non_inner,
},
"mean_non_inner": {
"TILE_K": mean_heur_tile_k,
"TILE_N": mean_heur_tile_n_non_inner,
"ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
"num_warps": softmax_heur_num_warps_non_inner,
},
"softmax_inner": {
"TILE_N": softmax_heur_tile_n_inner,
"ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
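To see how the new heuristics interact, the following dependency-free sketch mirrors mean_heur_tile_k and mean_heur_tile_n_non_inner. The SM count is hard-coded to 108 as an assumption (the real code queries the current device), and the example shape is arbitrary.

# Dependency-free sketch of the TILE_K / TILE_N heuristics above. NUM_SMS is
# hard-coded as an assumption; the real code queries the device via torch.cuda.
import math

_MIN_TILE_N = 64
_MAX_TILE_N_PER_ROW = 4096
_MAX_ONE_TILE_N = 2048
NUM_SMS = 108  # assumed SM count (e.g. an A100)


def tile_k(M, K, max_tile_k=512):
    # Double TILE_K while the launch still oversubscribes the SMs, capped so
    # that TILE_N can stay at least _MIN_TILE_N.
    upper = min(K, max_tile_k, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
    tk = 1
    while tk <= upper:
        num_blocks = M * math.ceil(K / tk)  # stand-in for triton.cdiv
        if num_blocks / NUM_SMS > 1 and tk * 2 <= upper:
            tk *= 2
        else:
            break
    return tk


def tile_n_non_inner(N, tk):
    # Pick a power-of-two TILE_N, bounded by limit_by_k so TILE_N * TILE_K does
    # not blow past _MAX_TILE_N_PER_ROW, and floored at _MIN_TILE_N.
    limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tk)
    desired = min(max(N, _MIN_TILE_N), limit_by_k, _MAX_ONE_TILE_N)
    tn = 1 << (desired - 1).bit_length()  # next power of two
    return max(min(tn, limit_by_k), _MIN_TILE_N)


# Example: reducing dim=1 of a (32, 1000, 1000) tensor gives M=32, N=1000, K=1000.
tk = tile_k(M=32, K=1000)             # 64
tn = tile_n_non_inner(N=1000, tk=tk)  # 64
print(tk, tn, tn >= 1000)  # 64 64 False -> ONE_TILE_PER_CTA is False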