Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/compiler/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Codegen: Julia IR -> Tile IR bytecode

include("codegen/utils.jl")
include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
include("codegen/alias_analysis.jl") # Defines alias_analysis_pass!
include("codegen/token_order.jl") # Defines get_alias_set, get_input_token!
include("codegen/kernel.jl")
include("codegen/control_flow.jl")
include("codegen/statements.jl")
Expand Down
153 changes: 153 additions & 0 deletions src/compiler/codegen/alias_analysis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
AliasTracker

Tracks alias sets for each SSA value during fixed-point analysis.
"""
mutable struct AliasTracker
    # Raised by `setindex!` whenever a stored alias set is replaced.
    dirty::Bool
    # Alias set recorded so far, keyed by SSAValue / Argument / SlotNumber.
    aliases::Dict{Any, AliasSet}
end

# A fresh tracker starts clean and with no recorded alias sets.
function AliasTracker()
    return AliasTracker(false, Dict{Any, AliasSet}())
end

# Look up the alias set recorded for `key`. A key that has never been stored
# reads as `ALIAS_UNIVERSE`.
Base.getindex(tracker::AliasTracker, key) = get(tracker.aliases, key, ALIAS_UNIVERSE)

# Store `value` as the alias set for `key`, flagging the tracker as dirty only
# when the stored set actually changes.
#
# The comparison is by value (`!=`), not by identity: re-storing a freshly built
# but equal set must not mark the tracker dirty, otherwise a fixed-point loop
# driven by `dirty` would keep iterating without any real change. A key with no
# entry yet always counts as a change.
#
# Returns `tracker`, following the usual `setindex!` convention (the result of
# the `tracker[key] = value` syntax is unaffected by this return value).
function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
    current = get(tracker.aliases, key, nothing)
    if current === nothing || current != value
        tracker.dirty = true
        tracker.aliases[key] = value
    end
    return tracker
end

"""
alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet}

Perform fixed-point alias analysis on structured IR.
Returns mapping from SSA values to alias sets.
"""
function alias_analysis_pass!(sci::StructuredIRCode; max_iterations::Integer = 100)
    # `max_iterations` caps the number of sweeps over the entry block. It
    # defaults to the previously hard-coded limit of 100.
    tracker = AliasTracker()

    # Seed: every pointer-carrying argument starts in an alias set of its own.
    for (idx, argtype) in enumerate(sci.argtypes)
        if contains_pointers(CC.widenconst(argtype))
            arg_ref = Argument(idx)
            tracker[arg_ref] = Set{Any}([arg_ref])
        end
    end

    # Sweep the entry block until one full sweep changes nothing. The flag is
    # forced on first so that at least one sweep runs even when no argument was
    # seeded above.
    iteration = 0
    tracker.dirty = true
    while tracker.dirty && iteration < max_iterations
        tracker.dirty = false
        iteration += 1
        analyze_block!(tracker, sci.entry)
    end

    # A still-dirty tracker means the loop stopped at the cap, not at a fixed
    # point. Report that instead of claiming convergence.
    if tracker.dirty
        @warn "Alias analysis did not converge after $iteration iterations"
    else
        @debug "Alias analysis converged in $iteration iterations"
    end

    return tracker.aliases
end

"""
propagate!(tracker::AliasTracker, from, to)

Propagate alias set from `from` to `to` (union operation).
"""
function propagate!(tracker::AliasTracker, from, to)
    from_aliases = tracker[from]

    # Read `to` straight from the dictionary. `tracker[to]` would report a
    # never-stored key as `ALIAS_UNIVERSE`, and joining against that default
    # would prevent the first propagation into a fresh value from ever recording
    # a precise set.
    # NOTE(review): this assumes `union` with `ALIAS_UNIVERSE` yields the
    # universe; its definition is not visible here - confirm.
    current = get(tracker.aliases, to, nothing)

    if current === nothing
        # First time `to` is seen: it simply inherits `from`'s alias set.
        tracker[to] = from_aliases
        return nothing
    end

    # Otherwise widen the existing set, storing only on a real change so the
    # tracker's dirty flag is not raised needlessly.
    merged = union(from_aliases, current)
    if merged != current
        tracker[to] = merged
    end
    return nothing
end

"""
analyze_block!(tracker::AliasTracker, block)

Analyze all statements in a block.
"""
function analyze_block!(tracker::AliasTracker, block)
    # Each element of `block.body` destructures into `(ssa_idx, entry)`; the
    # statement to analyse is `entry.stmt`.
    #
    # NOTE(review): a reviewer asked why there is no recursion into nested ops.
    # The author replied that the flat traversal was an intentional first pass to
    # establish alias propagation at the top level, because nested blocks raise
    # the question of how loop-carried pointer SSA values inherit alias sets
    # across iterations. The planned follow-up is to descend into nested blocks
    # from `analyze_statement!`.
    foreach(block.body) do (ssa_idx, entry)
        analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt)
    end
    return
end

"""
analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)

Analyze a single statement and propagate aliases.
"""
function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
    # Summary of the cases handled below:
    #   * `ReturnNode`                      -> nothing recorded
    #   * `Core.getfield(parent, :ptr)`     -> inherit `parent`'s alias set
    #   * `Core.getfield` of another field  -> `ALIAS_UNIVERSE`
    #   * `Base.:+` / `Base.:-`             -> inherit from the first operand
    #                                          whose alias set is known
    #   * `is_tile_array_constructor(func)` -> inherit from the first call argument
    #   * anything else                     -> `ALIAS_UNIVERSE`
    # Always returns `nothing`.

    # No alias propagation is needed for a return.
    stmt isa ReturnNode && return nothing

    # Unknown statement kind: stay conservative.
    # NOTE(review): nested control-flow ops presumably land here too - confirm.
    if !(stmt isa Expr && stmt.head === :call)
        tracker[ssa] = ALIAS_UNIVERSE
        return nothing
    end

    func = stmt.args[1]
    nargs = length(stmt.args)

    if func === GlobalRef(Core, :getfield) && nargs >= 2
        parent = stmt.args[2]
        field = nargs >= 3 ? stmt.args[3] : nothing

        if field isa QuoteNode && field.value === :ptr
            # `.ptr` field access: the result aliases whatever `parent` aliases.
            propagate!(tracker, parent, ssa)
        else
            # Conservatively mark non-pointer fields as the universe.
            tracker[ssa] = ALIAS_UNIVERSE
        end

    elseif func === GlobalRef(Base, :+) || func === GlobalRef(Base, :-)
        # Pointer arithmetic: inherit from the first operand with a known alias
        # set, without copying the argument list.
        # The earlier `|| arg_aliases isa Set` clause was dropped: a reviewer
        # asked what else the value could be, and the author agreed it was
        # redundant.
        for arg in Iterators.drop(stmt.args, 1)
            if tracker[arg] !== ALIAS_UNIVERSE
                propagate!(tracker, arg, ssa)
                break
            end
        end

    elseif is_tile_array_constructor(func)
        # The first call argument is typically the pointer.
        if nargs >= 2
            propagate!(tracker, stmt.args[2], ssa)
        end

    else
        # Unknown call target: conservative default.
        tracker[ssa] = ALIAS_UNIVERSE
    end

    return nothing
end

# Helper functions
# A type carries pointers when it is a `Ptr`, a `TileArray`, or a `Tile` whose
# element type is a `Ptr`. The checks run in that order and stop at the first hit.
function contains_pointers(T)
    T <: Ptr && return true
    T <: TileArray && return true
    return T <: Tile && eltype(T) <: Ptr
end

# Placeholder predicate meant to recognise a TileArray constructor call.
# TODO: implement by detecting the specific GlobalRef for TileArray. Until then
# no call is recognised and the result is always `false`.
is_tile_array_constructor(func) = false
Comment on lines +149 to +153
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TileArrays are never constructed in the kernel. Or do you mean tensor and partition views?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, misnaming on my part. Renaming this to is_partition_or_tensor_view and implementing it to detect partition/tensor view call sites. The intent was to identify the point where a new SSA value gets a distinct alias set rooted at a specific base argument.

38 changes: 38 additions & 0 deletions src/compiler/codegen/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
# Save token before branches
token_before = ctx.token

# Save token_map before branches
token_map_before = copy(ctx.token_map)

# Emit IfOp with callback-based region building
then_body = function(_)
saved_block_args = copy(ctx.block_args)
ctx.token = token_before # Reset to pre-branch token
ctx.token_map = copy(token_map_before) # Reset token_map too
emit_block!(ctx, then_blk)
if then_blk.terminator === nothing
encode_YieldOp!(ctx.cb, [ctx.token])
Expand All @@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
else_body = function(_)
saved_block_args = copy(ctx.block_args)
ctx.token = token_before # Reset to pre-branch token
ctx.token_map = copy(token_map_before) # Reset token_map too
emit_block!(ctx, else_blk)
if else_blk.terminator === nothing
encode_YieldOp!(ctx.cb, [ctx.token])
Expand All @@ -114,6 +119,12 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
# Last result is the merged token from both branches
ctx.token = results[end]

# Merge token_map from both branches
# Conservatively point every key at the merged token (the last IfOp result)
for key in keys(ctx.token_map)
ctx.token_map[key] = results[end]
end

# Store results at IfOp's SSA index (may be empty for void-returning ifs)
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
end
Expand Down Expand Up @@ -164,6 +175,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
# Number of user result types (excluding token)
n_user_results = n_carries

# Save token_map before loop
token_map_before = copy(ctx.token_map)

# Emit ForOp with callback-based region building
body_builder = function(block_args)
saved_block_args = copy(ctx.block_args)
Expand Down Expand Up @@ -196,6 +210,12 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
# Last result is the token
ctx.token = results[end]

# Update token_map after loop
# Conservatively update all keys to the merged token
for key in keys(token_map_before)
ctx.token_map[key] = results[end]
end

# Store results at the loop's SSA index (may be empty for void-returning loops)
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
end
Expand Down Expand Up @@ -230,6 +250,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
# Number of user result types (excluding token)
n_user_results = n_carries

# Save token_map before loop
token_map_before = copy(ctx.token_map)

# Emit LoopOp with callback-based region building
body_builder = function(block_args)
saved_block_args = copy(ctx.block_args)
Expand Down Expand Up @@ -266,6 +289,12 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
# Last result is the token
ctx.token = results[end]

# Update token_map after loop
# Conservatively update all keys to the merged token
for key in keys(token_map_before)
ctx.token_map[key] = results[end]
end

# Store results at the loop's SSA index (may be empty for void-returning loops)
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
end
Expand Down Expand Up @@ -301,6 +330,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
# Number of user result types (excluding token)
n_user_results = n_carries

# Save token_map before loop
token_map_before = copy(ctx.token_map)

# Emit WhileOp as cuda_tile.loop with conditional break pattern
# MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals }
# Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue }
Expand Down Expand Up @@ -396,6 +428,12 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
# Last result is the token
ctx.token = results[end]

# Update token_map after loop
# Conservatively update all keys to the merged token
for key in keys(token_map_before)
ctx.token_map[key] = results[end]
end

# Store results at the loop's SSA index (may be empty for void-returning loops)
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
end
Expand Down
22 changes: 21 additions & 1 deletion src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,30 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
cache_tensor_view!(ctx, arg_idx)
end

# Run alias analysis FIRST
alias_result = alias_analysis_pass!(sci)
ctx.alias_result = alias_result

# Create memory ordering token
token_type = Token(tt)
ctx.token_type = token_type
ctx.token = encode_MakeTokenOp!(cb, token_type)
root_token = encode_MakeTokenOp!(cb, token_type)

ctx.global_token = root_token
ctx.token = root_token

# Initialize token map with root token for all alias sets
# Default: all tokens start at root
ctx.token_map = Dict{TokenKey, Value}()

unique_alias_sets = Set(values(alias_result))
for alias_set in unique_alias_sets
ctx.token_map[last_op_key(alias_set)] = root_token
ctx.token_map[last_store_key(alias_set)] = root_token
end

# ACQUIRE token also starts at root
ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token

# Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
hoist_returns!(ctx.sci.entry)
Expand Down
38 changes: 38 additions & 0 deletions src/compiler/codegen/token_keys.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Role of a token within an alias set: the last operation of any kind, or the
# last store.
@enum TokenRole begin
    LAST_OP
    LAST_STORE
end

# Key for the acquire token. The type has no fields, so a single shared
# instance is all that is ever needed.
struct AcquireTokenKey end
const ACQUIRE_TOKEN_KEY = AcquireTokenKey()

# Key for a token tied to one alias set in one role.
struct AliasTokenKey
    alias_set::AliasSet
    role::TokenRole
end

# Any key accepted by the token map.
const TokenKey = Union{AliasTokenKey, AcquireTokenKey}

# Helper constructors
"""
last_op_key(alias_set::AliasSet) -> AliasTokenKey

Create a TokenKey for the last operation (load or store) on an alias set.
"""
function last_op_key(alias_set::AliasSet)
    return AliasTokenKey(alias_set, LAST_OP)
end

"""
last_store_key(alias_set::AliasSet) -> AliasTokenKey

Create a TokenKey for the last store operation on an alias set.
"""
function last_store_key(alias_set::AliasSet)
    return AliasTokenKey(alias_set, LAST_STORE)
end

# `AliasTokenKey` serves as a `Dict` key, so hashing and equality are both
# defined over the same pair of fields and stay consistent with each other.
function Base.hash(key::AliasTokenKey, h::UInt)
    return hash((key.alias_set, key.role), h)
end

function Base.:(==)(a::AliasTokenKey, b::AliasTokenKey)
    return a.alias_set == b.alias_set && a.role == b.role
end

# Every `AcquireTokenKey` is interchangeable: all instances hash alike and
# compare equal.
function Base.hash(::AcquireTokenKey, h::UInt)
    return hash(:ACQUIRE_TOKEN_KEY, h)
end

function Base.:(==)(::AcquireTokenKey, ::AcquireTokenKey)
    return true
end
Loading