1 change: 1 addition & 0 deletions Project.toml
@@ -12,6 +12,7 @@ CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd"
IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
10 changes: 8 additions & 2 deletions ext/CUDAExt.jl
@@ -8,7 +8,7 @@ using CompilerCaching: CacheView, method_instance, results

import Core.Compiler as CC

-using CUDA: CuModule, CuFunction, cudacall, device, capability
+using CUDA: CUDA, CuModule, CuFunction, cudacall, device, capability
using CUDA_Compiler_jll

public launch
@@ -53,14 +53,18 @@ function check_tile_ir_support()
return VersionNumber(cuda_ver.major, cuda_ver.minor)
end

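# Taken in emit_binary below to serialize concurrent emit_code calls.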
const EMIT_CODE_LOCK = ReentrantLock()

"""
emit_binary(cache, mi; const_argtypes=nothing) -> Vector{UInt8}

Binary phase: compile Tile IR bytecode to CUBIN using tileiras.
"""
function emit_binary(cache::CacheView, mi::Core.MethodInstance;
const_argtypes::Union{Vector{Any}, Nothing}=nothing)
-bytecode = emit_code(cache, mi; const_argtypes)
+bytecode = lock(EMIT_CODE_LOCK) do
+    emit_code(cache, mi; const_argtypes)
+end

ci = get(cache, mi)
res = const_argtypes !== nothing ? results(cache, ci, const_argtypes) : results(cache, ci)
@@ -260,4 +264,6 @@ Other values pass through unchanged.
to_tile_arg(x) = x
to_tile_arg(arr::AbstractArray) = TileArray(arr)

include("autotune/autotune.jl")

end
300 changes: 300 additions & 0 deletions ext/autotune/autotune.jl
@@ -0,0 +1,300 @@
import cuTile.Experimental: autotune_launch, clear_autotune_cache
using cuTile.Experimental: AbstractSearchSpace, CartesianSpace, FixedSpace

using Random

const AUTOTUNE_LOCK = ReentrantLock()
const AUTOTUNE_CACHE = Dict{Any, Dict{Any, Any}}()

struct VerificationError <: Exception
msg::String
end

const TUNING_PRESETS = (
fast = (warmup=1, reps=3, refine_topk=0, refine_reps=2),
default = (warmup=2, reps=5, refine_topk=2, refine_reps=4),
thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6),
)

function normalize_tuning(tuning::NamedTuple)
preset = get(tuning, :preset, :default)
preset isa Symbol || throw(ArgumentError("tuning.preset must be a Symbol"))
hasproperty(TUNING_PRESETS, preset) ||
throw(ArgumentError("Unknown preset `$preset`; use :fast, :default, or :thorough"))

base = merge(getproperty(TUNING_PRESETS, preset),
(seed=nothing, force=false, precompile_workers=Threads.nthreads()))

# Apply user overrides (excluding :preset)
overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset)
return merge(base, overrides)
end
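
For illustration (an editor's sketch, not part of the diff), the merge above gives preset-plus-overrides semantics:

cfg = normalize_tuning((preset = :fast, reps = 10))
# → (warmup = 1, reps = 10, refine_topk = 0, refine_reps = 2,
#    seed = nothing, force = false, precompile_workers = Threads.nthreads())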

# Extract hint fields (occupancy, num_ctas) from a config for launch kwargs
function hints_from_cfg(cfg)
n = hasproperty(cfg, :num_ctas) ? cfg.num_ctas : nothing
o = hasproperty(cfg, :occupancy) ? cfg.occupancy : nothing
return (num_ctas=n, occupancy=o)
end
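
Any property-bearing config works here; absent hint fields simply map to `nothing` (illustrative, not part of the diff):

hints_from_cfg((BLOCK = 128, occupancy = 4))  # → (num_ctas = nothing, occupancy = 4)
hints_from_cfg((BLOCK = 128,))                # → (num_ctas = nothing, occupancy = nothing)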

function time_ms(run_once::Function, get_args::Function;
warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing,
reset::Union{Nothing, Function}=nothing)
CUDA.synchronize()
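# Force at least one warm-up run when verifying, so verify() below sees real kernel output.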
for _ in 1:max(warmup, verify !== nothing ? 1 : 0)
reset !== nothing && reset()
run_once(get_args())
end

if verify !== nothing
CUDA.synchronize()
verify() || throw(VerificationError("config produced incorrect output"))
end

best_ms = Inf32
for _ in 1:reps
reset !== nothing && reset()
args = get_args()
CUDA.synchronize()
elapsed_s = CUDA.@elapsed run_once(args)
CUDA.synchronize()
best_ms = min(best_ms, Float32(elapsed_s * 1000))
end
return best_ms
end

function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function;
sm_arch::String, opt_level::Int, warmup::Int, reps::Int,
verify::Union{Nothing, Function}=nothing,
reset::Union{Nothing, Function}=nothing)
run_once = args -> cuTile.launch(f, grid_fn(cfg), args...;
sm_arch, opt_level, hints_from_cfg(cfg)...)
return time_ms(run_once, () -> args_fn(cfg); warmup, reps, verify, reset)
end

function precompile_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function;
sm_arch::String, opt_level::Int)
grid_fn(cfg)
args = args_fn(cfg)
tile_args = map(to_tile_arg, args)

# Mirror launch's Constant handling
unwrapped_types = map(tile_args) do arg
arg isa Constant ? constant_eltype(typeof(arg)) : typeof(arg)
end
argtypes = Tuple{unwrapped_types...}

world = Base.get_world_counter()
mi = method_instance(f, argtypes; world)
mi === nothing && throw(MethodError(f, argtypes))

has_consts = any(x -> x isa Constant, tile_args)
const_argtypes = if has_consts
cats = Any[CC.Const(f)]
for arg in tile_args
push!(cats, arg isa Constant ? CC.Const(arg[]) : typeof(arg))
end
cats
else
nothing
end

hints = hints_from_cfg(cfg)
bytecode_version = check_tile_ir_support()
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=hints.num_ctas, occupancy=hints.occupancy,
bytecode_version=bytecode_version)
cache = CacheView{CuTileResults}((:cuTile, opts), world)
emit_function(cache, mi; const_argtypes)
end

function precompile_candidates(@nospecialize(f), configs::Vector{Any},
grid_fn::Function, args_fn::Function;
sm_arch::String, opt_level::Int, workers::Int)
isempty(configs) && return configs, nothing
iszero(workers) && return configs, nothing

workers = min(workers, Threads.nthreads(), length(configs))
compiled = fill(true, length(configs))
errors = Vector{Any}(nothing, length(configs))
sem = Base.Semaphore(workers)
cancelled = Threads.Atomic{Bool}(false)

try
@sync for (i, cfg) in enumerate(configs)
Threads.@spawn begin
cancelled[] && return
Base.acquire(sem) do
cancelled[] && return
try
precompile_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level)
catch err
compiled[i] = false
errors[i] = (cfg, err)
end
end
end
end
catch e
cancelled[] = true
e isa InterruptException || rethrow()
@warn "Precompilation interrupted, waiting for in-flight workers…"
# @sync already waits for spawned tasks before propagating,
# but the atomic flag ensures queued ones exit early.
rethrow()
end

first_err = nothing
for e in errors
if e !== nothing
first_err = e
break
end
end

return configs[compiled], first_err
end
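
The concurrency pattern above, a semaphore bounding in-flight work plus an atomic cancellation flag, in isolation (illustrative sketch, not part of the diff; `bounded_foreach` is an invented name):

function bounded_foreach(f, items; workers::Int)
    sem = Base.Semaphore(workers)              # at most `workers` run at once
    cancelled = Threads.Atomic{Bool}(false)    # lets queued tasks bail out early
    try
        @sync for item in items
            Threads.@spawn begin
                cancelled[] && return
                Base.acquire(sem) do
                    cancelled[] || f(item)
                end
            end
        end
    catch
        cancelled[] = true                     # stop tasks still waiting on the semaphore
        rethrow()
    end
end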

function measure_candidates(@nospecialize(f), configs::Vector{Any},
grid_fn::Function, args_fn::Function;
sm_arch::String, opt_level::Int, warmup::Int, reps::Int,
verify::Union{Nothing, Function}=nothing,
reset::Union{Nothing, Function}=nothing)
record = Tuple{Any, Float32}[]
first_error = nothing
for cfg in configs
ms = try
eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify, reset)
catch err
if err isa InterruptException
@warn "Benchmarking interrupted after $(length(record)) configs"
break
end
err isa VerificationError && @warn "Config $cfg failed verification, skipping"
first_error === nothing && (first_error = (cfg, err))
continue
end
push!(record, (cfg, ms))
end
return record, first_error
end

function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG,
grid_fn::Function, args_fn::Function, tuning;
sm_arch::String, opt_level::Int, kernel_key, arg_key,
verify::Union{Nothing, Function}=nothing,
setup::Union{Nothing, Function}=nothing)
if !tuning.force
entry = lock(AUTOTUNE_LOCK) do
per_kernel = get(AUTOTUNE_CACHE, kernel_key, nothing)
per_kernel !== nothing ? get(per_kernel, arg_key, nothing) : nothing
end
entry !== nothing && return entry, true, nothing
end

checker = verify !== nothing ? verify() : nothing
reset = setup !== nothing ? setup() : nothing

trials = Any[cfg for cfg in space]
trials, precompile_error = Base.ScopedValues.with(cuTile._SCOPED_INF_CACHE => CC.InferenceResult[]) do
precompile_candidates(f, trials, grid_fn, args_fn;
sm_arch, opt_level, workers=tuning.precompile_workers)
end

record, first_error = measure_candidates(f, trials, grid_fn, args_fn;
sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker, reset)

if isempty(record)
# Prefer the first benchmark failure (those configs at least compiled); fall back to the precompile error
err_info = first_error !== nothing ? first_error : precompile_error
if err_info === nothing
throw(ArgumentError("No valid config found in search space"))
else
cfg, err = err_info
throw(ArgumentError(
"No valid config found. First failure for cfg=$cfg: $(sprint(showerror, err))"))
end
end

# Refinement: re-benchmark top K with more reps to stabilize the winner
if tuning.refine_topk > 0 && length(record) > 1
sort!(record, by=last)
top_configs = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]]
refined, _ = measure_candidates(f, top_configs, grid_fn, args_fn;
sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, reset)
if !isempty(refined)
record = refined
end
end

_, best_idx = findmin(last, record)
candidate = (; best_config=record[best_idx][1], tuning_record=record)

entry, cache_hit = lock(AUTOTUNE_LOCK) do
per_kernel = get!(Dict{Any,Any}, AUTOTUNE_CACHE, kernel_key)
if !tuning.force && haskey(per_kernel, arg_key)
per_kernel[arg_key], true
else
per_kernel[arg_key] = candidate
candidate, false
end
end
return entry, cache_hit, reset
end
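
The caching here is the classic double-checked pattern: look up under the lock, tune unlocked, then insert under the lock with any concurrently inserted entry winning (unless force is set). Compressed into a standalone sketch (illustrative; `get_or_insert!` is an invented name):

function get_or_insert!(cache::Dict, lk::ReentrantLock, key, compute)
    # First check: return an existing entry without doing any work.
    v = lock(() -> get(cache, key, nothing), lk)
    v !== nothing && return v, true
    candidate = compute()              # the expensive tuning runs unlocked
    lock(lk) do
        if haskey(cache, key)          # second check: a concurrent tuner won
            cache[key], true
        else
            cache[key] = candidate
            candidate, false
        end
    end
end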

function autotune_launch(@nospecialize(f), space::AbstractSearchSpace,
grid_fn::Function, args_fn::Function;
key=nothing,
key_fn::Union{Nothing, Function}=nothing,
launch_args_fn::Union{Nothing, Function}=nothing,
verify::Union{Nothing, Function}=nothing,
setup::Union{Nothing, Function}=nothing,
tuning::NamedTuple=NamedTuple(),
sm_arch::String=default_sm_arch(),
opt_level::Int=3)
tuning = normalize_tuning(tuning)
rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng()

kernel_key = (f, sm_arch, opt_level)
arg_key = key !== nothing ? key : (key_fn !== nothing ? key_fn() : nothing)

entry, cache_hit, reset = find_or_tune(f, space, rng, grid_fn, args_fn, tuning;
sm_arch, opt_level, kernel_key, arg_key, verify, setup)

cfg = entry.best_config
grid = grid_fn(cfg)
args = launch_args_fn !== nothing ? launch_args_fn(cfg) : args_fn(cfg)

# Reset state before the final "real" launch
reset !== nothing && reset()

cuTile.launch(f, grid, args...; sm_arch, opt_level, hints_from_cfg(cfg)...)

return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit)
end
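
An end-to-end usage sketch (not part of the diff; `scale_kernel`, the sizes, and the config axis are invented, the kernel is assumed to write `2a` into `b`, and `Constant` is assumed to be cuTile's constant-argument wrapper seen in `precompile_cfg`):

a = CUDA.rand(Float32, 4096)
b = CUDA.zeros(Float32, 4096)
result = autotune_launch(scale_kernel,
    (BLOCK = (64, 128, 256),),                 # NamedTuple → CartesianSpace
    cfg -> cld(4096, cfg.BLOCK),               # grid_fn: grid size for a config
    cfg -> (a, b, Constant(cfg.BLOCK));        # args_fn: launch args for a config
    key = 4096,                                # one cache entry per problem size
    verify = () -> (() -> Array(b) ≈ 2f0 .* Array(a)),  # builds a zero-arg checker
    tuning = (preset = :fast,))
result.tuned_config, result.cache_hit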

# Convenience: accept plain Vector (→ FixedSpace) or NamedTuple (→ CartesianSpace)
function autotune_launch(@nospecialize(f), configs, grid_fn::Function, args_fn::Function; kwargs...)
space = configs isa NamedTuple ? CartesianSpace(configs) : FixedSpace(configs)
return autotune_launch(f, space, grid_fn, args_fn; kwargs...)
end
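
Equivalently with an explicit config list (same invented names as above):

autotune_launch(scale_kernel, [(BLOCK = 64,), (BLOCK = 128,)],
                cfg -> cld(4096, cfg.BLOCK), cfg -> (a, b, Constant(cfg.BLOCK)))  # Vector → FixedSpace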

function clear_autotune_cache(; kernel=nothing, key=nothing)
lock(AUTOTUNE_LOCK) do
if kernel === nothing
key === nothing || throw(ArgumentError("`key` requires `kernel`"))
empty!(AUTOTUNE_CACHE)
return nothing
end

for kernel_key in collect(keys(AUTOTUNE_CACHE))
kernel_key isa Tuple || continue
kernel_key[1] === kernel || continue
per_kernel = AUTOTUNE_CACHE[kernel_key]
key === nothing ? empty!(per_kernel) : pop!(per_kernel, key, nothing)
isempty(per_kernel) && delete!(AUTOTUNE_CACHE, kernel_key)
end
end
return nothing
end
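
Usage (same invented kernel name as above):

clear_autotune_cache()                                    # drop all tuning results
clear_autotune_cache(kernel = scale_kernel)               # drop one kernel's results
clear_autotune_cache(kernel = scale_kernel, key = 4096)   # drop a single argument key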
40 changes: 40 additions & 0 deletions src/Experimental.jl
@@ -0,0 +1,40 @@
module Experimental

autotune_launch(args...; kwargs...) =
    error("Please load CUDA.jl before using `cuTile.Experimental.autotune_launch`.")
clear_autotune_cache(args...; kwargs...) =
    error("Please load CUDA.jl before using `cuTile.Experimental.clear_autotune_cache`.")

abstract type AbstractSearchSpace end

Base.length(s::AbstractSearchSpace) = count(_ -> true, s)

struct FixedSpace{names,NT<:NamedTuple{names}} <: AbstractSearchSpace
elements::Vector{NT}
end

Base.iterate(space::FixedSpace, args...) = iterate(space.elements, args...)

struct CartesianSpace{names,NT<:NamedTuple{names,<:Tuple{Vararg{Tuple}}}} <: AbstractSearchSpace
constraint::Function
axes::NT
end

CartesianSpace(axes::NamedTuple) = CartesianSpace(Returns(true), axes)
CartesianSpace(; axes...) = CartesianSpace(NamedTuple(axes))
CartesianSpace(constraint::Function; axes...) = CartesianSpace(constraint, NamedTuple(axes))
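
For illustration (not part of the diff; axis names invented), a constrained space and its iteration order, first axis fastest per `Iterators.product`:

space = CartesianSpace(cfg -> cfg.BM >= cfg.BN; BM = (64, 128), BN = (64, 128))
collect(space)  # → [(BM = 64, BN = 64), (BM = 128, BN = 64), (BM = 128, BN = 128)]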

function Base.iterate(space::CartesianSpace{names}, state=nothing) where names
to_cfg = vals -> NamedTuple{names}(vals)
inner = state === nothing ?
Iterators.filter(space.constraint ∘ to_cfg,
Iterators.product(map(Tuple, values(space.axes))...)) :
state.inner
result = isnothing(state) ? iterate(inner) : iterate(inner, state.cursor)
isnothing(result) && return nothing
vals, cursor = result
cfg = to_cfg(vals)
return cfg, (; inner, cursor)
end

end