Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 127 additions & 35 deletions src/flag_gems/ops/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,46 +52,130 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr


@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_non_inner"))
@triton.jit
def argmax_kernel_non_inner(
    inp,
    out_index,
    M,
    N,
    K,
    TILE_K: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Argmax over a non-innermost dim: inp viewed as (M, N, K), reduced over N.

    One program handles one row-block (pid_m) and one TILE_K-wide slice of the
    innermost stride (pid_k); it writes int64 indices into out_index (M, K).
    NOTE(review): reconstructed from a corrupted diff paste — stale removed
    lines (old BLOCK_M/BLOCK_N signature and m_offset math) were dropped.
    """
    # Program coordinates: pid_m picks the outer row, pid_k the K-slice.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)

    # Accumulate half-precision inputs in fp32 to avoid precision loss in compares.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Identity element for max: masked-out lanes can never win.
    min_value = get_dtype_min(cdtype)

    if ONE_TILE_PER_CTA:
        # Whole reduction axis fits in one tile: single load + reduce.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset
        mask = k_offset < K and n_offset[:, None] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        # tie_break_left keeps the first occurrence, matching torch.argmax.
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, local_argmax, mask=mask1)
    else:
        # Loop over N in TILE_N chunks, carrying running max / argmax per lane.
        max_values = tl.full([TILE_K], dtype=cdtype, value=min_value)
        argmax_values = tl.full([TILE_K], dtype=tl.int64, value=0)

        for start_n in range(0, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset
            mask = k_offset < K and n_offset[:, None] < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict > preserves the earliest (left-most) winner across tiles.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(update, start_n + local_argmax, argmax_values)
        offset_index = pid_m * K + k_offset
        out_index_ptrs = out_index + offset_index
        mask1 = k_offset < K
        tl.store(out_index_ptrs, argmax_values, mask=mask1)


@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax_inner"))
@triton.jit
def argmax_kernel_inner(
    inp,
    out_index,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Argmax over the innermost (contiguous) dim: inp viewed as (M, N).

    One program reduces one full row (pid_m) and writes a single int64 index.
    NOTE(review): reconstructed from a corrupted diff paste — stale removed
    lines (old BLOCK_M/BLOCK_N loop and axis-1 reduction) were dropped.
    """
    pid_m = tle.program_id(0)

    dtype = inp.type.element_ty
    # NOTE(review): acc_type is computed but never used below — looks like a
    # leftover; kept for fidelity to the diff, verify against upstream.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    # Identity element for max: padded lanes can never win.
    min_value = get_dtype_min(dtype)

    if ONE_TILE_PER_CTA:
        # Whole row fits in one tile: single masked load + reduce.
        n_offset = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offset
        mask = n_offset < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        # tie_break_left keeps the first occurrence, matching torch.argmax.
        local_max, local_argmax = tl.max(
            inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
        )
        out_index_ptrs = out_index + pid_m
        tl.store(out_index_ptrs, local_argmax)
    else:
        # Running scalar max / argmax carried across TILE_N-sized chunks.
        max_values = min_value
        argmax_values = 0

        # Full tiles need no mask; the tail chunk is handled separately below.
        loop_time = N // TILE_N
        remainder = N % TILE_N
        for start_n in range(0, loop_time):
            n_offset = start_n * TILE_N + tl.arange(0, TILE_N)
            offset = pid_m * N + n_offset
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            # Strict > preserves the earliest (left-most) winner across chunks.
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(
                update, start_n * TILE_N + local_argmax, argmax_values
            )

        if remainder:
            # Masked tail load; out-of-range lanes read min_value and never win.
            n_offset = loop_time * TILE_N + tl.arange(0, TILE_N)
            offset = pid_m * N + n_offset
            mask = n_offset < N
            inp_ptrs = inp + offset
            inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
            local_max, local_argmax = tl.max(
                inp_vals, 0, return_indices=True, return_indices_tie_break_left=True
            )
            update = local_max > max_values
            max_values = tl.where(update, local_max, max_values)
            argmax_values = tl.where(
                update, loop_time * TILE_N + local_argmax, argmax_values
            )

        out_index_ptrs = out_index + pid_m
        tl.store(out_index_ptrs, argmax_values)


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
Expand Down Expand Up @@ -140,17 +224,25 @@ def argmax(inp, dim=None, keepdim=False, *, dtype=None):
if not keepdim:
out_index = torch.squeeze(out_index, dim)

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
K,
)
with torch_device_fn.device(inp.device):
argmax_kernel[grid](
inp,
out_index,
M,
N,
K,
)

if K > 1:
grid = lambda meta: (
M,
triton.cdiv(K, meta["TILE_K"]),
)
argmax_kernel_non_inner[grid](
inp,
out_index,
M,
N,
K,
)
else:
grid = lambda meta: (M, 1, 1)
argmax_kernel_inner[grid](
inp,
out_index,
M,
N,
)
return out_index
101 changes: 94 additions & 7 deletions src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,92 @@ def simple_elementwise_blocksize_heur(args):
return 1024


def argmax_heur_block_m(args):
return 4 if args["M"] < 4096 else 8
def argmax_heur_tile_k(args):
    """Pick TILE_K (width of the innermost-stride slice per CTA) for the
    non-inner argmax kernel.

    Starts at 64 and doubles while the launch (M * cdiv(K, TILE_K) blocks)
    still over-subscribes the SMs, bounded by K and MAX_TILE_K.
    NOTE(review): reconstructed from a corrupted diff paste — the removed
    `argmax_heur_block_n` lines that sat inside this body were dropped.
    """
    MAX_TILE_K = 512
    NUM_SMS = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count

    K = args["K"]
    M = args["M"]
    # Coarse dtype bucket: everything non-fp32 is treated as the fp16 case.
    dtype = "fp32" if args["inp"].dtype == torch.float32 else "fp16"

    # Hand-tuned special case for a benchmarked shape.
    if M == 64 and K == 512:
        return 64 if dtype == "fp32" else 128

    # Small K: largest power of two <= K (floor, deliberately not next_power_of_2).
    if K <= 128:
        return 1 << (K.bit_length() - 1) if K > 0 else 1

    tile_k = 64
    upper_bound = min(K, MAX_TILE_K)

    # Double TILE_K while more than one full wave of blocks would be launched.
    while tile_k <= upper_bound:
        num_blocks = M * triton.cdiv(K, tile_k)
        num_waves = num_blocks / NUM_SMS

        if num_waves > 1 and (tile_k * 2 <= upper_bound):
            tile_k *= 2
        else:
            break

    return tile_k


def argmax_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner argmax kernel, bounded by the TILE_K budget.

    Small N is covered by a single exact tile; otherwise TILE_N is N rounded up
    to a power of two, clamped to [64, 4096], then shrunk so the 2D tile
    (TILE_N x TILE_K) stays within a 32768-element budget.
    """
    n = args["N"]
    tile_k = args["TILE_K"]

    # Small reduction axis: one tile exactly covers N (not rounded to a power of 2).
    if n <= 128:
        return n

    # Next power of two >= n. The previous min(8192, n) pre-cap was dead code:
    # its effect was always subsumed by the 4096 clamp below, so it is removed.
    tile_n = 1 << (n - 1).bit_length()
    tile_n = max(64, min(tile_n, 4096))

    # Keep the TILE_N x TILE_K tile within a 32K-element budget.
    if tile_n * tile_k > 32768:
        tile_n = max(64, 32768 // tile_k)

    return tile_n


def argmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N tile spans the whole reduction axis N."""
    return args["N"] <= args["TILE_N"]


def argmax_heur_num_warps_non_inner(args):
    """Choose num_warps for the non-inner argmax kernel from TILE_N.

    2 warps for tiny tiles, 4 up to 128, 8 beyond; fp32 inputs are capped at
    4 warps. The original duplicate <=64 / <=128 branches (both returning 4)
    and the dtype-to-string round-trip have been merged away.
    """
    tile_n = args["TILE_N"]

    if tile_n <= 32:
        num_warps = 2
    elif tile_n <= 128:
        num_warps = 4
    else:
        num_warps = 8

    # Cap fp32 kernels at 4 warps; every other dtype keeps the TILE_N-based count.
    if args["inp"].dtype == torch.float32:
        num_warps = min(num_warps, 4)

    return num_warps


def argmax_heur_tile_n_inner(args):
    """TILE_N for the inner argmax kernel: whole row (as a power of two) when
    N fits in 32K elements, otherwise a fixed 4096-wide tile."""
    n = args["N"]
    return triton.next_power_of_2(n) if n <= 32 * 1024 else 4096


def argmax_heur_num_warps_inner(args):
    """Scale num_warps with the inner tile size: 4, 8, or 16 warps."""
    size = args["TILE_N"]
    if size >= 4096:
        return 16
    if size >= 2048:
        return 8
    return 4


def argmin_heur_block_m(args):
Expand Down Expand Up @@ -233,9 +313,16 @@ def vdot_heur_block_size(args):


HEURISTICS_CONFIGS = {
"argmax": {
"BLOCK_M": argmax_heur_block_m,
"BLOCK_N": argmax_heur_block_n,
"argmax_non_inner": {
"TILE_K": argmax_heur_tile_k,
"TILE_N": argmax_heur_tile_n_non_inner,
"ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
"num_warps": argmax_heur_num_warps_non_inner,
},
"argmax_inner": {
"TILE_N": argmax_heur_tile_n_inner,
"ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
"num_warps": argmax_heur_num_warps_inner,
},
"argmin": {
"BLOCK_M": argmin_heur_block_m,
Expand Down
Loading