Skip to content

Commit fb5d204

Browse files
committed
Add topk Triton kernel for CUDA backend
Add a Triton-based topk kernel that replaces aten.topk during graph transformation, compiled directly into the AOTInductor .so via wrap_triton (no C++ fallback shim needed). The kernel uses iterative argmax with masking, adapted from FlagGems/aiter. It is registered via @triton_op("triton::topk") and auto-substituted for aten.topk.default through ReplaceEdgeOpWithTritonOpPass. Tests follow the chunk_gated_delta_rule pattern: eager correctness across 8 configs, export validation, and E2E C++ runner comparison.
1 parent e458023 commit fb5d204

6 files changed

Lines changed: 613 additions & 0 deletions

File tree

backends/cuda/tests/test_topk.py

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Export and validate topk triton kernel on CUDA backend.
9+
10+
Usage:
11+
python -m pytest backends/cuda/tests/test_topk.py -v
12+
13+
# Standalone export (produces .pte + .ptd):
14+
python backends/cuda/tests/test_topk.py --output-dir /tmp/exports
15+
"""
16+
17+
import argparse
18+
import os
19+
import subprocess
20+
import sys
21+
import tempfile
22+
import unittest
23+
24+
import numpy as np
25+
import torch
26+
import torch.nn as nn
27+
from torch.export import export
28+
29+
from executorch.backends.cuda.cuda_backend import CudaBackend
30+
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
31+
from executorch.exir import (
32+
EdgeCompileConfig,
33+
ExecutorchBackendConfig,
34+
to_edge_transform_and_lower,
35+
)
36+
from executorch.exir.passes import MemoryPlanningPass
37+
38+
39+
# Absolute path to the C++ end-to-end runner binary produced by the CMake
# build; the E2E test below is skipped when it has not been built.
RUNNER_PATH = os.path.join(
    os.path.dirname(__file__),
    "../../../cmake-out/backends/cuda/tests/topk_runner/topk_runner",
)

# Test configurations: (seed, rows, cols, k, dim, largest, description)
# Covers k=1, k equal to the row width, largest=False, a single row,
# a wide row, and many rows.
TEST_CONFIGS = [
    (42, 4, 8, 2, -1, True, "basic_4x8_k2"),
    (0, 1, 16, 3, -1, True, "single_row_k3"),
    (7, 8, 4, 1, -1, True, "8x4_k1"),
    (99, 4, 8, 2, -1, False, "smallest_k2"),
    (13, 2, 32, 5, -1, True, "wide_k5"),
    (55, 4, 8, 8, -1, True, "k_equals_n"),
    (77, 1, 4, 2, -1, True, "tiny_1x4_k2"),
    (123, 16, 8, 2, -1, True, "many_rows"),
]
class TopKModel(nn.Module):
    """Bias-free square linear projection followed by a top-k selection.

    Returns a ``(values, indices)`` pair from ``torch.topk`` applied to
    the projected input.
    """

    def __init__(self, dim_in=8, k=2, topk_dim=-1, largest=True):
        super().__init__()
        # Square projection keeps the feature size unchanged, so `k` and
        # `topk_dim` apply directly to the projected tensor.
        self.linear = nn.Linear(dim_in, dim_in, bias=False)
        self.k = k
        self.topk_dim = topk_dim
        self.largest = largest

    def forward(self, x):
        projected = self.linear(x)
        result = torch.topk(
            projected, self.k, dim=self.topk_dim, largest=self.largest
        )
        return result.values, result.indices
def _make_inputs(seed, rows, cols, dtype=torch.bfloat16, device="cuda"):
74+
torch.manual_seed(seed)
75+
return (torch.randn(rows, cols, dtype=dtype, device=device),)
76+
77+
78+
def _save_tensor(t, path):
79+
t_cpu = t.cpu().contiguous()
80+
with open(path, "wb") as f:
81+
f.write(bytes(t_cpu.untyped_storage()))
82+
83+
84+
def _load_output(path, shape, dtype):
85+
data = np.fromfile(path, dtype=np.uint8)
86+
return torch.frombuffer(bytearray(data), dtype=dtype).reshape(shape)
87+
88+
89+
def export_topk(output_dir, cols=8, k=2, largest=True):
    """Export a TopKModel to .pte + .ptd.

    Builds a bfloat16 TopKModel on CUDA, exports it with torch.export,
    lowers it through the CUDA partitioner/backend, and writes the
    resulting program to ``<output_dir>/topk.pte`` (plus tensor-data
    files when the program carries external tensor data).

    Returns:
        (pte_path, model): path to the written .pte and the eager model,
        so callers can compare runner output against eager results.
    """
    torch.manual_seed(42)
    model = TopKModel(dim_in=cols, k=k, largest=largest).to(
        device="cuda", dtype=torch.bfloat16
    ).eval()
    # Fixed 4-row example input; the exported graph is traced at this shape.
    inputs = _make_inputs(42, 4, cols)

    # Sanity-run eagerly before export so failures surface early.
    with torch.no_grad():
        ref_vals, ref_idx = model(*inputs)
    print(f"Eager output: values {ref_vals.shape}, indices {ref_idx.shape}")

    with torch.no_grad():
        ep = export(model, inputs, strict=True)
    print("Export OK")

    os.makedirs(output_dir, exist_ok=True)

    # Lower the "forward" method through the CUDA (AOTInductor) backend.
    specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    et_prog = to_edge_transform_and_lower(
        ep,
        partitioner=[CudaPartitioner(specs)],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False, _skip_dim_order=True
        ),
    )
    et_program = et_prog.to_executorch(
        config=ExecutorchBackendConfig(
            extract_delegate_segments=True,
            do_quant_fusion_and_const_prop=True,
            # Inputs are provided by the runner, not planned into the arena.
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        ),
    )

    pte_path = os.path.join(output_dir, "topk.pte")
    with open(pte_path, "wb") as f:
        f.write(et_program.buffer)

    # NOTE(review): `_tensor_data` is a private attribute of the executorch
    # program; guarded with hasattr in case the API changes.
    if hasattr(et_program, "_tensor_data") and et_program._tensor_data:
        et_program.write_tensor_data_to_file(output_dir)

    print(f"Saved to {pte_path} ({os.path.getsize(pte_path) / 1024:.0f} KB)")
    return pte_path, model
def _run_cpp_runner(runner_path, pte_path, ptd_path, input_dir, output_dir):
135+
cmd = [
136+
runner_path,
137+
f"--model_path={pte_path}",
138+
f"--data_path={ptd_path}",
139+
f"--input_dir={input_dir}",
140+
f"--output_dir={output_dir}",
141+
]
142+
return subprocess.run(cmd, capture_output=True, text=True)
143+
144+
145+
class TestTopK(unittest.TestCase):
    """Eager, export, and end-to-end checks for the Triton topk kernel.

    All tests require a CUDA device; the E2E test additionally requires
    the C++ runner binary at RUNNER_PATH.
    """

    def setUp(self):
        # Every test here runs the model on CUDA.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def test_eager(self):
        """Triton topk produces correct shapes and dtypes."""
        model = TopKModel().to(device="cuda", dtype=torch.bfloat16).eval()
        inputs = _make_inputs(42, 4, 8)
        with torch.no_grad():
            vals, idx = model(*inputs)
        self.assertEqual(vals.shape, torch.Size([4, 2]))
        self.assertEqual(idx.shape, torch.Size([4, 2]))
        self.assertEqual(vals.dtype, torch.bfloat16)
        self.assertEqual(idx.dtype, torch.int64)

    def test_eager_correctness(self):
        """Triton topk matches torch.topk across multiple configs."""
        for seed, rows, cols, k, dim, largest, desc in TEST_CONFIGS:
            with self.subTest(desc=desc):
                torch.manual_seed(seed)
                x = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")

                ref_vals, ref_idx = torch.topk(x, k, dim=dim, largest=largest)

                # Imported lazily so the module is only required when CUDA
                # tests actually run.
                from executorch.backends.cuda.triton.kernels.topk import topk as triton_topk

                tri_vals, tri_idx = triton_topk(x, k, dim=dim, largest=largest)

                v_diff = (tri_vals.float() - ref_vals.float()).abs().max().item()
                self.assertLess(v_diff, 1e-3, f"{desc}: value diff {v_diff}")
                # Indices must match exactly, not approximately.
                self.assertTrue(
                    torch.equal(tri_idx, ref_idx),
                    f"{desc}: indices mismatch",
                )

    def test_export_cuda(self):
        """Export succeeds and produces non-empty .pte."""
        with tempfile.TemporaryDirectory() as tmpdir:
            pte_path, _ = export_topk(tmpdir)
            self.assertTrue(os.path.exists(pte_path))
            self.assertGreater(os.path.getsize(pte_path), 0)

    @unittest.skipUnless(os.path.exists(RUNNER_PATH), "C++ runner not built")
    def test_e2e_cpp_runner(self):
        """Export, run C++ runner, compare with eager."""
        with tempfile.TemporaryDirectory() as tmpdir:
            export_dir = os.path.join(tmpdir, "export")
            pte_path, model = export_topk(export_dir)
            # Name written by export_topk's tensor-data dump.
            ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd")

            for seed, rows, cols, k, dim, largest, desc in TEST_CONFIGS:
                # Skip configs that don't match the exported model shape
                # (the .pte was traced at rows=4, cols=8, k=2, largest=True).
                if cols != 8 or k != 2 or not largest or rows != 4:
                    continue

                with self.subTest(desc=desc):
                    inputs = _make_inputs(seed, rows, cols)

                    with torch.no_grad():
                        ref_vals, ref_idx = model(*inputs)

                    input_dir = os.path.join(tmpdir, f"inputs_{desc}")
                    output_dir = os.path.join(tmpdir, f"outputs_{desc}")
                    os.makedirs(input_dir)
                    os.makedirs(output_dir)

                    _save_tensor(inputs[0], os.path.join(input_dir, "x.bin"))

                    result = _run_cpp_runner(
                        RUNNER_PATH, pte_path, ptd_path, input_dir, output_dir
                    )
                    self.assertEqual(
                        result.returncode,
                        0,
                        f"{desc}: C++ runner failed:\n{result.stderr}",
                    )

                    # The runner writes one raw binary file per model output:
                    # output_0 = values (bfloat16), output_1 = indices (int64).
                    cpp_vals = _load_output(
                        os.path.join(output_dir, "output_0.bin"),
                        (rows, k),
                        torch.bfloat16,
                    )
                    cpp_idx = _load_output(
                        os.path.join(output_dir, "output_1.bin"),
                        (rows, k),
                        torch.int64,
                    )

                    v_diff = (
                        (cpp_vals.float() - ref_vals.cpu().float()).abs().max().item()
                    )
                    self.assertLess(v_diff, 0.01, f"{desc}: value diff {v_diff}")
                    self.assertTrue(
                        torch.equal(cpp_idx, ref_idx.cpu()),
                        f"{desc}: indices mismatch\n"
                        f" cpp: {cpp_idx}\n ref: {ref_idx.cpu()}",
                    )
if __name__ == "__main__":
    # Dual entry point: with --output-dir this acts as a standalone export
    # script; otherwise the remaining argv is handed to unittest.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", default=None)
    args, remaining = parser.parse_known_args()

    if args.output_dir:
        export_topk(args.output_dir)
    else:
        # Strip our own flag(s) so unittest.main sees only its arguments.
        sys.argv = [sys.argv[0]] + remaining
        unittest.main()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Build configuration for the topk E2E C++ runner used by test_topk.py.
cmake_minimum_required(VERSION 3.24)
project(topk_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Repository root, relative to backends/cuda/tests/topk_runner.
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# gflags is vendored in the executorch build tree.
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../third-party/gflags)
find_package(gflags REQUIRED)

list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
# Force whole-archive linking so statically registered ops/backends survive.
executorch_target_link_options_shared_lib(executorch)

set(link_libraries executorch gflags)

# Portable CPU kernels used for any non-delegated ops.
list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

list(
  APPEND
  link_libraries
  extension_module
  extension_data_loader
  extension_tensor
  extension_flat_tensor
  extension_named_data_map
)

# CUDA (AOTInductor) backend — only when the CUDA build is enabled.
if(EXECUTORCH_BUILD_CUDA)
  find_package(CUDAToolkit REQUIRED)
  list(APPEND link_libraries aoti_cuda_backend)
  if(NOT MSVC)
    executorch_target_link_options_shared_lib(aoti_cuda_backend)
  endif()
endif()

add_executable(topk_runner main.cpp)
target_include_directories(
  topk_runner PUBLIC ${_common_include_directories}
)
target_link_libraries(topk_runner PUBLIC ${link_libraries})

# Shrink release binaries: GC unused sections, then strip symbols where
# the linker supports it.
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
  target_link_options_gc_sections(topk_runner)
  if(NOT APPLE AND NOT MSVC)
    target_link_options(topk_runner PRIVATE "LINKER:-s")
  endif()
endif()

0 commit comments

Comments (0)