Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ steps:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
coverage: false
commands: |
unset LD_LIBRARY_PATH
agents:
queue: "juliagpu"
cuda: "*"
gpu: "a100"
timeout_in_minutes: 90
timeout_in_minutes: 15
matrix:
setup:
julia:
- "1.11"
- "1.12"
- "1.13"
43 changes: 37 additions & 6 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module CUDAExt

using cuTile
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code, sanitize_name,
constant_eltype, constant_value, is_ghost_type
using cuTile: TileArray, Constant, CGOpts, CuTileResults, DEFAULT_BYTECODE_VERSION,
emit_code, sanitize_name, constant_eltype, constant_value, is_ghost_type

using CompilerCaching: CacheView, method_instance, results

Expand All @@ -13,6 +13,16 @@ using CUDA_Compiler_jll

public launch

function run_and_collect(cmd)
stdout = Pipe()
proc = run(pipeline(ignorestatus(cmd); stdout, stderr=stdout), wait=false)
close(stdout.in)
reader = Threads.@spawn String(read(stdout))
Base.wait(proc)
log = strip(fetch(reader))
return proc, log
end

"""
check_tile_ir_support()

Expand All @@ -38,6 +48,9 @@ function check_tile_ir_support()
else
error("Tile IR is not supported on compute capability $cap ($sm_arch)")
end

# Return bytecode version matching the toolkit
return VersionNumber(cuda_ver.major, cuda_ver.minor)
end

"""
Expand All @@ -58,12 +71,29 @@ function emit_binary(cache::CacheView, mi::Core.MethodInstance;
# Run tileiras to produce CUBIN
input_path = tempname() * ".tile"
output_path = tempname() * ".cubin"
compiled = false
try
write(input_path, bytecode)
run(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $(opts.sm_arch) -O$(opts.opt_level)`)
cmd = addenv(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $(opts.sm_arch) -O$(opts.opt_level)`,
"CUDA_ROOT" => CUDA_Compiler_jll.artifact_dir)
proc, log = run_and_collect(cmd)
if !success(proc)
reason = proc.termsignal > 0 ? "tileiras received signal $(proc.termsignal)" :
"tileiras exited with code $(proc.exitcode)"
msg = "Failed to compile Tile IR ($reason)"
if !isempty(log)
msg *= "\n" * log
end
msg *= "\nIf you think this is a bug, please file an issue and attach $(input_path)"
if parse(Bool, get(ENV, "BUILDKITE", "false"))
run(`buildkite-agent artifact upload $(input_path)`)
end
error(msg)
end
compiled = true
res.cuda_bin = read(output_path)
finally
rm(input_path, force=true)
compiled && rm(input_path, force=true)
rm(output_path, force=true)
end

Expand Down Expand Up @@ -135,7 +165,7 @@ function cuTile.launch(@nospecialize(f), grid, args...;
opt_level::Int=3,
num_ctas::Union{Int, Nothing}=nothing,
occupancy::Union{Int, Nothing}=nothing)
check_tile_ir_support()
bytecode_version = check_tile_ir_support()

# Convert CuArray -> TileArray (and other conversions)
tile_args = map(to_tile_arg, args)
Expand Down Expand Up @@ -166,7 +196,8 @@ function cuTile.launch(@nospecialize(f), grid, args...;
end

# Create cache view with compilation options as sharding keys
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
bytecode_version=bytecode_version)
cache = CacheView{CuTileResults}((:cuTile, opts), world)

# Run cached compilation
Expand Down
17 changes: 14 additions & 3 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1122,9 +1122,14 @@ Example:
function encode_ForOp!(body::Function, cb::CodeBuilder,
result_types::Vector{TypeId}, iv_type::TypeId,
lower::Value, upper::Value, step::Value,
init_values::Vector{Value})
init_values::Vector{Value};
unsigned_cmp::Bool=false)
encode_varint!(cb.buf, Opcode.ForOp)
encode_typeid_seq!(cb.buf, result_types)
# Flags
if cb.version >= v"13.2"
encode_varint!(cb.buf, unsigned_cmp ? 1 : 0)
end
# Operands: lower, upper, step, init_values...
encode_varint!(cb.buf, 3 + length(init_values))
encode_operand!(cb.buf, lower)
Expand Down Expand Up @@ -1558,7 +1563,9 @@ function encode_NegIOp!(cb::CodeBuilder, result_type::TypeId, source::Value;
overflow::IntegerOverflow=OverflowNone)
encode_varint!(cb.buf, Opcode.NegIOp)
encode_typeid!(cb.buf, result_type)
encode_enum!(cb.buf, overflow)
if cb.version >= v"13.2"
encode_enum!(cb.buf, overflow)
end
encode_operand!(cb.buf, source)
return new_op!(cb)
end
Expand Down Expand Up @@ -1956,9 +1963,13 @@ end
Element-wise hyperbolic tangent.
Opcode: 106
"""
function encode_TanHOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
function encode_TanHOp!(cb::CodeBuilder, result_type::TypeId, source::Value;
rounding_mode::RoundingMode=RoundingFull)
encode_varint!(cb.buf, Opcode.TanHOp)
encode_typeid!(cb.buf, result_type)
if cb.version >= v"13.2"
encode_enum!(cb.buf, rounding_mode)
end
encode_operand!(cb.buf, source)
return new_op!(cb)
end
Expand Down
38 changes: 22 additions & 16 deletions src/bytecode/writer.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Bytecode file writer - handles sections and overall structure

# Bytecode version
const BYTECODE_VERSION = (13, 1, 0)
const DEFAULT_BYTECODE_VERSION = v"13.1"

# Magic number
const MAGIC = UInt8[0x7f, 0x54, 0x69, 0x6c, 0x65, 0x49, 0x52, 0x00] # "\x7fTileIR\x00"
Expand Down Expand Up @@ -97,9 +97,11 @@ mutable struct CodeBuilder
next_value_id::Int
cur_debug_attr::DebugAttrId
num_ops::Int
version::VersionNumber
end

function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, type_table::TypeTable)
function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, type_table::TypeTable;
version::VersionNumber=DEFAULT_BYTECODE_VERSION)
CodeBuilder(
UInt8[],
string_table,
Expand All @@ -108,7 +110,8 @@ function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, t
DebugAttrId[],
0,
DebugAttrId(0), # No debug info
0
0,
version
)
end

Expand Down Expand Up @@ -374,9 +377,10 @@ mutable struct BytecodeWriter
debug_attr_table::DebugAttrTable
debug_info::Vector{Vector{DebugAttrId}}
num_functions::Int
version::VersionNumber
end

function BytecodeWriter()
function BytecodeWriter(; version::VersionNumber=DEFAULT_BYTECODE_VERSION)
string_table = StringTable()
BytecodeWriter(
UInt8[],
Expand All @@ -385,21 +389,21 @@ function BytecodeWriter()
TypeTable(),
DebugAttrTable(string_table),
Vector{Vector{DebugAttrId}}[],
0
0,
version
)
end

"""
Write the bytecode header.
"""
function write_header!(buf::Vector{UInt8})
function write_header!(buf::Vector{UInt8}, version::VersionNumber)
append!(buf, MAGIC)
major, minor, tag = BYTECODE_VERSION
push!(buf, UInt8(major))
push!(buf, UInt8(minor))
# Tag as 2-byte little-endian
push!(buf, UInt8(tag & 0xff))
push!(buf, UInt8((tag >> 8) & 0xff))
push!(buf, UInt8(version.major))
push!(buf, UInt8(version.minor))
# Patch as 2-byte little-endian
push!(buf, UInt8(version.patch & 0xff))
push!(buf, UInt8((version.patch >> 8) & 0xff))
end

"""
Expand Down Expand Up @@ -486,8 +490,9 @@ end
Write complete bytecode to a buffer.
Returns the buffer with all sections.
"""
function write_bytecode!(f::Function, num_functions::Int)
writer = BytecodeWriter()
function write_bytecode!(f::Function, num_functions::Int;
version::VersionNumber=DEFAULT_BYTECODE_VERSION)
writer = BytecodeWriter(; version)

# Function section content
func_buf = UInt8[]
Expand All @@ -502,7 +507,7 @@ function write_bytecode!(f::Function, num_functions::Int)

# Build final output
buf = UInt8[]
write_header!(buf)
write_header!(buf, version)

# Sections in order: Func, Global (if any), Constant, Debug, Type, String, End
write_section!(buf, Section.Func, func_buf, 8)
Expand Down Expand Up @@ -574,7 +579,8 @@ function add_function!(writer::BytecodeWriter, func_buf::Vector{UInt8},
end

# Create code builder for function body
cb = CodeBuilder(writer.string_table, writer.constant_table, writer.type_table)
cb = CodeBuilder(writer.string_table, writer.constant_table, writer.type_table;
version=writer.version)

return cb
end
Expand Down
9 changes: 6 additions & 3 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ const CGOpts = @NamedTuple{
sm_arch::Union{String, Nothing},
opt_level::Int,
num_ctas::Union{Int, Nothing},
occupancy::Union{Int, Nothing}
occupancy::Union{Int, Nothing},
bytecode_version::VersionNumber
}

# Results struct for caching compilation phases
Expand Down Expand Up @@ -394,7 +395,7 @@ function emit_code(cache::CacheView, mi::Core.MethodInstance;
opts = cache.owner[2]

# Generate Tile IR bytecode
bytecode = write_bytecode!(1) do writer, func_buf
bytecode = write_bytecode!(1; version=opts.bytecode_version) do writer, func_buf
emit_kernel!(writer, func_buf, sci, rettype;
name = sanitize_name(string(mi.def.name)),
sm_arch = opts.sm_arch,
Expand Down Expand Up @@ -508,6 +509,7 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
opt_level::Int=3,
num_ctas::Union{Int, Nothing}=nothing,
occupancy::Union{Int, Nothing}=nothing,
bytecode_version::VersionNumber=DEFAULT_BYTECODE_VERSION,
world::UInt=Base.get_world_counter())
# Strip Constant types from argtypes for MI lookup, build const_argtypes
stripped, const_argtypes = process_const_argtypes(f, argtypes)
Expand All @@ -518,7 +520,8 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
mi = @something(method_instance(f, stripped; world, method_table=cuTileMethodTable),
method_instance(f, stripped; world),
throw(MethodError(f, stripped)))
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
bytecode_version=bytecode_version)
cache = CacheView{CuTileResults}((:cuTile, opts), world)
bytecode = emit_code(cache, mi; const_argtypes)
print(io, disassemble_tileir(bytecode))
Expand Down
15 changes: 15 additions & 0 deletions src/language/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ end
S === () ? Intrinsics.to_scalar(result) : result
end

# Convert mismatched scalar/tile types to match array element type
@inline function atomic_cas(array::TileArray{T}, indices,
expected::TileOrScalar, desired::TileOrScalar;
memory_order::Int=MemoryOrder.AcqRel,
memory_scope::Int=MemScope.Device) where {T}
atomic_cas(array, indices, T(expected), T(desired); memory_order, memory_scope)
end

# ============================================================================
# Atomic RMW operations (atomic_add, atomic_xchg)
# ============================================================================
Expand Down Expand Up @@ -150,4 +158,11 @@ for op in (:add, :xchg)
result = Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
S === () ? Intrinsics.to_scalar(result) : result
end

# Convert mismatched scalar/tile types to match array element type
@eval @inline function $fname(array::TileArray{T}, indices, val::TileOrScalar;
memory_order::Int=MemoryOrder.AcqRel,
memory_scope::Int=MemScope.Device) where {T}
$fname(array, indices, T(val); memory_order, memory_scope)
end
end
3 changes: 2 additions & 1 deletion test/execution/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ end
# Test atomic_xchg: each thread exchanges, last one wins
function atomic_xchg_kernel(arr::ct.TileArray{Int,1})
bid = ct.bid(1)
ct.atomic_xchg(arr, 1, bid + 1;
# bid is 1-indexed (1..n_blocks), val is auto-converted from Int32 to Int
ct.atomic_xchg(arr, 1, bid;
memory_order=ct.MemoryOrder.AcqRel)
return
end
Expand Down
18 changes: 12 additions & 6 deletions test/execution/hints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ using CUDA
b = CUDA.ones(Float32, n) .* 2
c = CUDA.zeros(Float32, n)

ct.launch(vadd_kernel_num_ctas, 64, a, b, c; num_ctas=2)

@test Array(c) ≈ ones(Float32, n) .* 3
if capability(device()) >= v"10"
ct.launch(vadd_kernel_num_ctas, 64, a, b, c; num_ctas=2)
@test Array(c) ≈ ones(Float32, n) .* 3
else
@test_throws "num_cta_in_cga" ct.launch(vadd_kernel_num_ctas, 64, a, b, c; num_ctas=2)
end
end

@testset "launch with occupancy" begin
Expand Down Expand Up @@ -60,9 +63,12 @@ end
b = CUDA.ones(Float32, n) .* 2
c = CUDA.zeros(Float32, n)

ct.launch(vadd_kernel_both_hints, 64, a, b, c; num_ctas=4, occupancy=8)

@test Array(c) ≈ ones(Float32, n) .* 3
if capability(device()) >= v"10"
ct.launch(vadd_kernel_both_hints, 64, a, b, c; num_ctas=4, occupancy=8)
@test Array(c) ≈ ones(Float32, n) .* 3
else
@test_throws "num_cta_in_cga" ct.launch(vadd_kernel_both_hints, 64, a, b, c; num_ctas=4, occupancy=8)
end
end

end
Expand Down
Loading