diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl
index 564aa8e..6cb0af0 100644
--- a/src/compiler/codegen.jl
+++ b/src/compiler/codegen.jl
@@ -1,6 +1,9 @@
 # Codegen: Julia IR -> Tile IR bytecode
 
 include("codegen/utils.jl")
+include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
+include("codegen/alias_analysis.jl") # Defines alias_analysis_pass!
+include("codegen/token_order.jl") # Defines get_alias_set, get_input_token!
 include("codegen/kernel.jl")
 include("codegen/control_flow.jl")
 include("codegen/statements.jl")
diff --git a/src/compiler/codegen/alias_analysis.jl b/src/compiler/codegen/alias_analysis.jl
new file mode 100644
index 0000000..32b607b
--- /dev/null
+++ b/src/compiler/codegen/alias_analysis.jl
@@ -0,0 +1,229 @@
+"""
+    AliasTracker
+
+Tracks alias sets for each SSA value during fixed-point analysis.
+
+`dirty` is set by `setindex!` whenever a stored alias set changes; the
+fixed-point driver clears it before each sweep and stops once it stays unset.
+"""
+mutable struct AliasTracker
+    dirty::Bool
+    aliases::Dict{Any, AliasSet} # SSAValue/Argument/SlotNumber -> AliasSet
+end
+
+AliasTracker() = AliasTracker(false, Dict{Any, AliasSet}())
+
+# Keys that were never analyzed conservatively read as ALIAS_UNIVERSE.
+function Base.getindex(tracker::AliasTracker, key)
+    return get(tracker.aliases, key, ALIAS_UNIVERSE)
+end
+
+# Store `value` under `key`, marking the tracker dirty only on a real change.
+# Comparing by value (not identity) keeps an equal-but-distinct Set from
+# triggering another fixed-point sweep.
+function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
+    current = get(tracker.aliases, key, nothing)
+    if current === nothing || current != value
+        tracker.dirty = true
+        tracker.aliases[key] = value
+    end
+    return
+end
+
+"""
+    alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet}
+
+Perform fixed-point alias analysis on structured IR.
+Returns the mapping from tracked values (arguments and SSA values) to alias sets.
+
+Each argument whose type contains pointers seeds its own singleton alias set;
+the IR is then swept until no alias set changes. A warning is logged when the
+sweep limit is hit first, because the returned sets may then be incomplete.
+"""
+function alias_analysis_pass!(sci::StructuredIRCode)
+    tracker = AliasTracker()
+
+    # Initialize: each argument gets its own alias set
+    for (idx, argtype) in enumerate(sci.argtypes)
+        argtype_unwrapped = CC.widenconst(argtype)
+        if contains_pointers(argtype_unwrapped)
+            arg_ref = Argument(idx)
+            tracker[arg_ref] = Set{Any}([arg_ref])
+        end
+    end
+
+    # Fixed-point iteration
+    iteration = 0
+    max_iterations = 100
+
+    tracker.dirty = true
+    while tracker.dirty && iteration < max_iterations
+        tracker.dirty = false
+        iteration += 1
+
+        analyze_block!(tracker, sci.entry)
+    end
+
+    if tracker.dirty
+        @warn "Alias analysis did not converge after $iteration iterations"
+    else
+        @debug "Alias analysis converged in $iteration iterations"
+    end
+
+    return tracker.aliases
+end
+
+"""
+    propagate!(tracker::AliasTracker, from, to)
+
+Propagate alias set from `from` to `to`.
+Uses direct assignment when `to` is uninitialized, union otherwise.
+"""
+function propagate!(tracker::AliasTracker, from, to)
+    from_aliases = tracker[from]
+
+    if from_aliases === ALIAS_UNIVERSE
+        # Propagating UNIVERSE is always conservative
+        tracker[to] = ALIAS_UNIVERSE
+        return
+    end
+
+    if haskey(tracker.aliases, to)
+        # Target already has an alias set: union with it
+        # (setindex! ignores the write when the union adds nothing)
+        tracker[to] = union(from_aliases, tracker.aliases[to])
+    else
+        # Target not yet analyzed: assign directly
+        tracker[to] = from_aliases
+    end
+    return
+end
+
+"""
+    analyze_block!(tracker::AliasTracker, block)
+
+Analyze all statements in a block, recursing into nested control flow.
+"""
+function analyze_block!(tracker::AliasTracker, block)
+    for (ssa_idx, entry) in block.body
+        if entry.stmt isa ControlFlowOp
+            analyze_control_flow!(tracker, entry.stmt)
+        else
+            analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt)
+        end
+    end
+    return
+end
+
+# Recurse into nested control flow regions
+function analyze_control_flow!(tracker::AliasTracker, op::IfOp)
+    analyze_block!(tracker, op.then_region)
+    return analyze_block!(tracker, op.else_region)
+end
+
+function analyze_control_flow!(tracker::AliasTracker, op::ForOp)
+    return analyze_block!(tracker, op.body)
+end
+
+function analyze_control_flow!(tracker::AliasTracker, op::WhileOp)
+    analyze_block!(tracker, op.before)
+    return analyze_block!(tracker, op.after)
+end
+
+function analyze_control_flow!(tracker::AliasTracker, op::LoopOp)
+    return analyze_block!(tracker, op.body)
+end
+
+# Fallback for unknown control flow ops
+function analyze_control_flow!(::AliasTracker, ::ControlFlowOp)
+    return
+end
+
+"""
+    analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
+
+Analyze a single statement and propagate aliases.
+Handles both `:call` and `:invoke` expression forms.
+
+The callee is resolved to its function value first, so `getfield`, `+`/`-` and
+the view constructors are recognized whether the IR names them through a
+`GlobalRef` or embeds the function object directly.
+"""
+function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
+    if stmt isa Expr && (stmt.head === :call || stmt.head === :invoke)
+        # Normalize :call and :invoke into (func, operands)
+        # :call -> args = [func, operands...]
+        # :invoke -> args = [MethodInstance, func, operands...]
+        if stmt.head === :call
+            func = stmt.args[1]
+            operands = @view stmt.args[2:end]
+        else # :invoke
+            func = stmt.args[2]
+            operands = @view stmt.args[3:end]
+        end
+
+        # Resolve func to its runtime value for intrinsic matching.
+        # In :invoke, func may already be the function object (not a GlobalRef).
+        resolved_func = if func isa GlobalRef
+            # An unbound GlobalRef resolves to `nothing` and matches no case below
+            isdefined(func.mod, func.name) ? getfield(func.mod, func.name) : nothing
+        else
+            func # Direct function value (common in :invoke)
+        end
+
+        # getfield: propagate from parent
+        if resolved_func === Core.getfield && length(operands) >= 1
+            field = length(operands) >= 2 ? operands[2] : nothing
+
+            # For TileArray.ptr field access, propagate pointer alias
+            if field isa QuoteNode && field.value === :ptr
+                propagate!(tracker, operands[1], ssa)
+            else
+                # Conservatively mark as UNIVERSE for non-pointer fields
+                tracker[ssa] = ALIAS_UNIVERSE
+            end
+
+        # Pointer arithmetic: propagate from pointer operand
+        elseif resolved_func === Base.:(+) || resolved_func === Base.:(-)
+            for arg in operands
+                # Find the pointer argument (first operand with a concrete set)
+                if tracker[arg] isa Set
+                    propagate!(tracker, arg, ssa)
+                    break
+                end
+            end
+
+        # View construction: propagate alias from first operand
+        elseif is_view_constructor(resolved_func)
+            if length(operands) >= 1
+                propagate!(tracker, operands[1], ssa)
+            end
+
+        # Default: unknown operation -> UNIVERSE
+        else
+            tracker[ssa] = ALIAS_UNIVERSE
+        end
+
+    elseif stmt isa ReturnNode
+        # No alias propagation needed
+
+    else
+        # Unknown statement type -> conservative
+        tracker[ssa] = ALIAS_UNIVERSE
+    end
+    return
+end
+
+# Helper functions
+contains_pointers(T) = T <: Ptr || T <: TileArray || (T <: Tile && eltype(T) <: Ptr)
+
+"""
+    is_view_constructor(func) -> Bool
+
+Check if a resolved function is a tensor/partition view constructor.
+These propagate alias identity from their first operand.
+"""
+function is_view_constructor(func)
+    return func === Intrinsics.make_tensor_view ||
+        func === Intrinsics.make_partition_view
+end
diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl
index 50eed22..e63b0c7 100644
--- a/src/compiler/codegen/control_flow.jl
+++ b/src/compiler/codegen/control_flow.jl
@@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
     # Save token before branches
     token_before = ctx.token
 
+    # Save token_map before branches
+    token_map_before = copy(ctx.token_map)
+
     # Emit IfOp with callback-based region building
     then_body = function(_)
         saved_block_args = copy(ctx.block_args)
         ctx.token = token_before  # Reset to pre-branch token
+        ctx.token_map = copy(token_map_before)  # Reset token_map too
         emit_block!(ctx, then_blk)
         if then_blk.terminator === nothing
             encode_YieldOp!(ctx.cb, [ctx.token])
@@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
     else_body = function(_)
         saved_block_args = copy(ctx.block_args)
         ctx.token = token_before  # Reset to pre-branch token
+        ctx.token_map = copy(token_map_before)  # Reset token_map too
         emit_block!(ctx, else_blk)
         if else_blk.terminator === nothing
             encode_YieldOp!(ctx.cb, [ctx.token])
@@ -114,6 +119,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
     # Last result is the merged token from both branches
     ctx.token = results[end]
 
+    # Merge token_map from both branches:
+    # conservatively point every key at the merged post-branch token.
+    # NOTE(review): the branches yield ctx.token, but loads/stores now advance only
+    # ctx.token_map; confirm the merged token still orders after in-branch memory ops.
+    for key in keys(ctx.token_map)
+        ctx.token_map[key] = results[end]
+    end
+
     # Store results at IfOp's SSA index (may be empty for void-returning ifs)
     ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
 end
@@ -164,6 +177,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
     # Number of user result types (excluding token)
     n_user_results = n_carries
 
+    # Save token_map before loop
+    token_map_before = copy(ctx.token_map)
+
     # Emit ForOp with callback-based region building
     body_builder = function(block_args)
         saved_block_args = copy(ctx.block_args)
@@ -196,6 +212,12 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
     # Last result is the token
     ctx.token = results[end]
 
+    # Update token_map after loop
+    # Conservatively update all keys to the merged token
+    for key in keys(token_map_before)
+        ctx.token_map[key] = results[end]
+    end
+
     # Store results at the loop's SSA index (may be empty for void-returning loops)
     ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
 end
@@ -230,6 +252,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
     # Number of user result types (excluding token)
     n_user_results = n_carries
 
+    # Save token_map before loop
+    token_map_before = copy(ctx.token_map)
+
     # Emit LoopOp with callback-based region building
     body_builder = function(block_args)
         saved_block_args = copy(ctx.block_args)
@@ -266,6 +291,12 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
     # Last result is the token
     ctx.token = results[end]
 
+    # Update token_map after loop
+    # Conservatively update all keys to the merged token
+    for key in keys(token_map_before)
+        ctx.token_map[key] = results[end]
+    end
+
     # Store results at the loop's SSA index (may be empty for void-returning loops)
     ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
 end
@@ -301,6 +332,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
     # Number of user result types (excluding token)
     n_user_results = n_carries
 
+    # Save token_map before loop
+    token_map_before = copy(ctx.token_map)
+
     # Emit WhileOp as cuda_tile.loop with conditional break pattern
     # MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals }
     # Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue }
@@ -396,6 +430,12 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
     # Last result is the token
     ctx.token = results[end]
 
+    # Update token_map after loop
+    # Conservatively update all keys to the merged token
+    for key in keys(token_map_before)
+        ctx.token_map[key] = results[end]
+    end
+
     # Store results at the loop's SSA index (may be empty for void-returning loops)
     ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
 end
diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl
index 203bcbe..3fb09a0 100644
--- a/src/compiler/codegen/kernel.jl
+++ b/src/compiler/codegen/kernel.jl
@@ -152,10 +152,33 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
         cache_tensor_view!(ctx, arg_idx)
     end
 
+    # Run alias analysis FIRST
+    alias_result = alias_analysis_pass!(sci)
+    ctx.alias_result = alias_result
+
     # Create memory ordering token
     token_type = Token(tt)
     ctx.token_type = token_type
-    ctx.token = encode_MakeTokenOp!(cb, token_type)
+    root_token = encode_MakeTokenOp!(cb, token_type)
+
+    ctx.global_token = root_token
+    ctx.token = root_token
+
+    # Initialize token map with root token for all alias sets
+    # Default: all tokens start at root
+    ctx.token_map = Dict{Any, Value}()
+
+    # ALIAS_UNIVERSE is always registered: get_alias_set falls back to it for
+    # values the analysis did not track, so its keys must exist up front.
+    unique_alias_sets = Set{AliasSet}(values(alias_result))
+    push!(unique_alias_sets, ALIAS_UNIVERSE)
+    for alias_set in unique_alias_sets
+        ctx.token_map[last_op_key(alias_set)] = root_token
+        ctx.token_map[last_store_key(alias_set)] = root_token
+    end
+
+    # ACQUIRE token also starts at root
+    ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token
 
     # Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
     hoist_returns!(ctx.sci.entry)
diff --git a/src/compiler/codegen/token_keys.jl b/src/compiler/codegen/token_keys.jl
new file mode 100644
index 0000000..07c448e
--- /dev/null
+++ b/src/compiler/codegen/token_keys.jl
@@ -0,0 +1,38 @@
+# Token role enum: which "last access" token of an alias set a key refers to
+@enum TokenRole LAST_OP LAST_STORE
+
+# Acquire token key (singleton)
+struct AcquireTokenKey end
+const ACQUIRE_TOKEN_KEY = AcquireTokenKey()
+
+# Alias token key (per alias set and role)
+struct AliasTokenKey
+    alias_set::AliasSet
+    role::TokenRole
+end
+
+# Union type for all token keys
+const TokenKey = Union{AliasTokenKey, AcquireTokenKey}
+
+# Helper constructors
+"""
+    last_op_key(alias_set::AliasSet) -> AliasTokenKey
+
+Create a TokenKey for the last operation (load or store) on an alias set.
+"""
+last_op_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_OP)
+
+"""
+    last_store_key(alias_set::AliasSet) -> AliasTokenKey
+
+Create a TokenKey for the last store operation on an alias set.
+"""
+last_store_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_STORE)
+
+# Hash/compare keys by content so equal alias sets share one Dict entry
+Base.hash(key::AliasTokenKey, h::UInt) = hash((key.alias_set, key.role), h)
+Base.:(==)(a::AliasTokenKey, b::AliasTokenKey) =
+    a.alias_set == b.alias_set && a.role == b.role
+
+Base.hash(::AcquireTokenKey, h::UInt) = hash(:ACQUIRE_TOKEN_KEY, h)
+Base.:(==)(::AcquireTokenKey, ::AcquireTokenKey) = true
diff --git a/src/compiler/codegen/token_order.jl b/src/compiler/codegen/token_order.jl
new file mode 100644
index 0000000..2ab0432
--- /dev/null
+++ b/src/compiler/codegen/token_order.jl
@@ -0,0 +1,145 @@
+"""
+    get_input_var(args) -> Any
+
+Extract the pointer/array variable from memory operation arguments.
+"""
+function get_input_var(args)
+    return args[1]
+end
+
+"""
+    get_alias_set(ctx::CGCtx, var) -> AliasSet
+
+Get the alias set for a variable from analysis results.
+Values the analysis did not track map to `ALIAS_UNIVERSE`.
+"""
+function get_alias_set(ctx::CGCtx, var)
+    # Trace to source
+    source = trace_to_source(ctx, var)
+
+    # Lookup in alias results
+    return get(ctx.alias_result, source, ALIAS_UNIVERSE)
+end
+
+"""
+    alias_sets_overlap(a::AliasSet, b::AliasSet) -> Bool
+
+Return `true` when two alias sets share at least one member.
+`ALIAS_UNIVERSE` overlaps every non-empty set, including itself.
+"""
+alias_sets_overlap(a::Set, b::Set) = !isdisjoint(a, b)
+alias_sets_overlap(::AliasUniverse, b::Set) = !isempty(b)
+alias_sets_overlap(a::Set, ::AliasUniverse) = !isempty(a)
+alias_sets_overlap(::AliasUniverse, ::AliasUniverse) = true
+
+"""
+    collect_join_tokens(ctx::CGCtx, token_key::TokenKey, memory_order=nothing) -> Vector{Value}
+
+Collect all tokens that need to be joined for synchronization.
+Based on Python's `_collect_join_tokens`.
+
+The result always starts with the token stored for `token_key`; the ACQUIRE
+token and the same-role tokens of overlapping alias sets are appended, each
+token at most once.
+"""
+function collect_join_tokens(ctx::CGCtx, token_key::TokenKey, memory_order = nothing)
+    tokens_to_join = [ctx.token_map[token_key]]
+
+    for (other_key, other_token) in ctx.token_map
+        should_join = false
+
+        # Join with ACQUIRE token
+        if other_key isa AcquireTokenKey
+            should_join = true
+
+        # Join if alias sets overlap
+        elseif other_key isa AliasTokenKey && token_key isa AliasTokenKey
+            # Release memory order: join with all LAST_OP tokens
+            if memory_order !== nothing && has_release_order(memory_order)
+                should_join = other_key.role == LAST_OP
+            end
+
+            # Alias set overlap: join if same role and sets overlap.
+            # ALIAS_UNIVERSE may alias anything, so it overlaps every set.
+            if other_key.role == token_key.role
+                should_join = should_join ||
+                    alias_sets_overlap(other_key.alias_set, token_key.alias_set)
+            end
+        end
+
+        if should_join && !(other_token in tokens_to_join)
+            push!(tokens_to_join, other_token)
+        end
+    end
+
+    return tokens_to_join
+end
+
+"""
+    get_input_token!(ctx::CGCtx, token_key::TokenKey, memory_order=nothing)
+        -> (Value, Union{Nothing, JoinOp})
+
+Get the input token for an operation, potentially creating a join operation.
+
+The second tuple element is always `nothing`: when a join is needed it has
+already been emitted into `ctx.cb` by the time this returns.
+"""
+function get_input_token!(ctx::CGCtx, token_key::TokenKey, memory_order = nothing)
+    if !haskey(ctx.token_map, token_key)
+        @warn "Token key not found in token_map!" token_key available_keys=keys(ctx.token_map)
+        # Fallback to the current ACQUIRE token
+        return (ctx.token_map[ACQUIRE_TOKEN_KEY], nothing)
+    end
+    tokens_to_join = collect_join_tokens(ctx, token_key, memory_order)
+
+    if length(tokens_to_join) == 1
+        return (tokens_to_join[1], nothing)
+    end
+
+    # Join multiple tokens
+    result_token = encode_JoinTokensOp!(ctx.cb, ctx.token_type, tokens_to_join)
+
+    return (result_token, nothing) # join_op is `nothing`: the join is already emitted
+end
+
+"""
+    trace_to_source(ctx::CGCtx, var) -> Any
+
+Trace a value back to its original source (Argument, SSAValue).
+"""
+function trace_to_source(ctx::CGCtx, var)
+    # Arguments and SSA values are already sources
+    if var isa Argument || var isa SSAValue
+        return var
+    end
+
+    # Resolve a SlotNumber through the slot table
+    if var isa SlotNumber
+        tv = get(ctx.slots, var.id, nothing)
+        if tv !== nothing && is_arg_ref(tv)
+            arg_idx, _ = tv.arg_ref
+            return Argument(arg_idx)
+        end
+    end
+
+    # Generic emit_value resolution
+    tv = emit_value!(ctx, var)
+    if tv !== nothing && is_arg_ref(tv)
+        arg_idx, _ = tv.arg_ref
+        return Argument(arg_idx)
+    end
+
+    # Return as is for unknown
+    return var
+end
+
+"""
+    has_release_order(memory_order) -> Bool
+
+Check if memory order has release semantics.
+For now, returns false (no memory order support yet).
+"""
+function has_release_order(memory_order)
+    # TODO: Implement proper memory order checking when needed
+    return false
+end
diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl
index 117418a..05713b1 100644
--- a/src/compiler/codegen/utils.jl
+++ b/src/compiler/codegen/utils.jl
@@ -2,6 +2,39 @@
 #
 # Core types (CGVal, CGCtx) and helper functions for Tile IR code generation.
 
+
+#=============================================================================
+ Alias Analysis Types
+=============================================================================#
+
+"""
+    AliasUniverse
+
+Singleton type representing the universal alias set (everything may alias everything).
+"""
+struct AliasUniverse end
+const ALIAS_UNIVERSE = AliasUniverse()
+
+# Universe behaves specially in set operations
+Base.union(::AliasUniverse, ::AliasUniverse) = ALIAS_UNIVERSE
+Base.union(::AliasUniverse, other) = ALIAS_UNIVERSE
+Base.union(other, ::AliasUniverse) = ALIAS_UNIVERSE
+Base.intersect(::AliasUniverse, ::AliasUniverse) = ALIAS_UNIVERSE
+Base.intersect(::AliasUniverse, other) = other
+Base.intersect(other, ::AliasUniverse) = other
+Base.:(==)(::AliasUniverse, ::AliasUniverse) = true
+Base.:(==)(::AliasUniverse, other) = false
+Base.:(==)(other, ::AliasUniverse) = false
+
+"""
+    AliasSet
+
+Union type representing either a concrete set of values that may alias,
+or the universal alias set (ALIAS_UNIVERSE).
+"""
+const AliasSet = Union{Set{Any}, AliasUniverse}
+
+
 #=============================================================================
  IRError: Exception type for IR compilation errors
 =============================================================================#
@@ -163,7 +196,9 @@ mutable struct CGCtx
     tt::TypeTable
     sci::StructuredIRCode
 
-    # Memory ordering token
+    # Memory ordering token (kept for backward compatibility)
+    tokens::Dict{UInt64, Value}
+    global_token::Union{Value, Nothing}
     token::Union{Value, Nothing}
     token_type::Union{TypeId, Nothing}
 
@@ -175,6 +210,10 @@ mutable struct CGCtx
 
     # Compilation cache (needed for combiner compilation)
     cache::CacheView
+
+    # Alias-aware token system
+    alias_result::Dict{Any, AliasSet} # From alias analysis
+    token_map::Dict{Any, Value} # TokenKey -> current token Value
 end
 
 function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode,
@@ -191,7 +230,18 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode,
         Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}(),
         Dict{Int, Type}(),
         Dict{Int, Tuple{Value, TypeId}}(),
-        cb, tt, sci, token, token_type, type_cache, sm_arch, cache,
+        cb,
+        tt,
+        sci,
+        Dict{UInt64, Value}(), # tokens (old system)
+        nothing, # global_token (old system)
+        token, # token (old system)
+        token_type,
+        type_cache,
+        sm_arch,
+        cache,
+        Dict{Any, AliasSet}(), # alias_result
+        Dict{Any, Value}(), # token_map
     )
 end
 
diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl
index 0be636f..2305674 100644
--- a/src/compiler/intrinsics/memory.jl
+++ b/src/compiler/intrinsics/memory.jl
@@ -41,6 +41,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
     # Check if mask is provided (arg 3 is not nothing)
     has_mask = length(args) >= 3 && get_constant(ctx, args[3]) !== nothing
 
+    # Get alias set and token key
+    alias_set = get_alias_set(ctx, args[1])
+    last_store_key_val = last_store_key(alias_set)
+
+    # Load depends on LAST_STORE (read after write)
+    input_token, _ = get_input_token!(ctx, last_store_key_val, nothing)
+
     if has_mask
         # Get mask tile (arg 3)
         mask_tv = emit_value!(ctx, args[3])
@@ -56,15 +63,22 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
 
         tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers;
                                                    mask=mask, padding_value=padding,
-                                                   token=ctx.token,
+                                                   token=input_token,
                                                    optimization_hints)
     else
         # Load without mask
         tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers;
-                                                   token=ctx.token,
+                                                   token=input_token,
                                                    optimization_hints)
     end
-    ctx.token = new_token
+
+    # Eagerly join with LAST_OP token
+    last_op_key_val = last_op_key(alias_set)
+    last_op_token = get(ctx.token_map, last_op_key_val, new_token)
+    new_last_op_token = encode_JoinTokensOp!(ctx.cb, token_type, [last_op_token, new_token])
+
+    # Update token map
+    ctx.token_map[last_op_key_val] = new_last_op_token
 
     result_jltype = Tile{elem_type, Tuple{tile_shape...}}
     CGVal(tile_val, result_tile_type, result_jltype, tile_shape)
@@ -105,6 +119,14 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)
     # Check if mask is provided (arg 4 is not nothing)
     has_mask = length(args) >= 4 && get_constant(ctx, args[4]) !== nothing
 
+    # Get alias set and token key
+    alias_set = get_alias_set(ctx, args[1])
+    last_op_key_val = last_op_key(alias_set)
+    last_store_key_val = last_store_key(alias_set)
+
+    # Store depends on LAST_OP (write after read/write)
+    input_token, _ = get_input_token!(ctx, last_op_key_val, nothing)
+
     if has_mask
         # Get mask tile (arg 4)
         mask_tv = emit_value!(ctx, args[4])
@@ -114,15 +136,18 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)
         # Store with mask
         new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values;
                                           mask=mask,
-                                          token=ctx.token,
+                                          token=input_token,
                                           optimization_hints)
     else
         # Store without mask
         new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values;
-                                          token=ctx.token,
+                                          token=input_token,
                                           optimization_hints)
     end
-    ctx.token = new_token
+
+    # Update both LAST_OP and LAST_STORE
+    ctx.token_map[last_op_key_val] = new_token
+    ctx.token_map[last_store_key_val] = new_token
 
     nothing
 end
diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl
index fff19b1..bca574a 100644
--- a/src/compiler/intrinsics/views.jl
+++ b/src/compiler/intrinsics/views.jl
@@ -132,10 +132,26 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a
     # Create optimization hints if provided
     optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val)
 
+    # Get alias set and token key
+    alias_set = get_alias_set(ctx, args[1]) # Use partition view as source
+    last_store_key_val = last_store_key(alias_set)
+
+    # Load depends on LAST_STORE (read after write)
+    input_token, _ = get_input_token!(ctx, last_store_key_val, nothing)
+
     # Load tile with token
-    tile_val, new_token = encode_LoadViewTkoOp!(cb, tile_type, token_type, pv_arg.v, index_vals;
-                                                token=ctx.token, optimization_hints)
-    ctx.token = new_token
+    tile_val, result_token = encode_LoadViewTkoOp!(
+        cb, tile_type, token_type, pv_arg.v, index_vals;
+        token = input_token, optimization_hints
+    )
+
+    # Eagerly join with LAST_OP token
+    last_op_key_val = last_op_key(alias_set)
+    last_op_token = get(ctx.token_map, last_op_key_val, result_token)
+    new_last_op_token = encode_JoinTokensOp!(ctx.cb, token_type, [last_op_token, result_token])
+
+    # Update token map
+    ctx.token_map[last_op_key_val] = new_last_op_token
 
     CGVal(tile_val, tile_type, Tile{elem_type, Tuple{tile_shape...}}, tile_shape)
 end
@@ -404,11 +420,22 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view),
     # Create optimization hints if provided
     optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val)
 
-    # Store tile with token
-    token_type = Token(tt)
-    new_token = encode_StoreViewTkoOp!(cb, token_type, tile_val, pv_arg.v, index_vals;
-                                       token=ctx.token, optimization_hints)
-    ctx.token = new_token
+    # Get alias set and token key
+    alias_set = get_alias_set(ctx, args[1]) # Use partition view as source
+    last_op_key_val = last_op_key(alias_set)
+    last_store_key_val = last_store_key(alias_set)
+
+    # Store depends on LAST_OP (write after read/write)
+    input_token, _ = get_input_token!(ctx, last_op_key_val, nothing)
+
+    # Store tile with token
+    token_type = Token(tt)
+    result_token = encode_StoreViewTkoOp!(cb, token_type, tile_val, pv_arg.v, index_vals;
+                                          token=input_token, optimization_hints)
+
+    # Update both LAST_OP and LAST_STORE
+    ctx.token_map[last_op_key_val] = result_token
+    ctx.token_map[last_store_key_val] = result_token
 
     nothing
 end