diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 525c844f6..1e04b4eec 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -17,7 +17,7 @@ import torch.utils._pytree as pytree from collections import namedtuple -from .symbolic_constraints import SymbolicAlias +from .symbolic_constraints import SymbolicConstraint from ..compiler.ir import ( Attribute, DenseElementsAttr, diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index cdf737b42..065447530 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -28,7 +28,7 @@ WorkgroupConstraint, ) from .assumptions import Assumption -from .symbolic_constraints import SymbolicAlias +from .symbolic_constraints import SymbolicConstraint from .._support.tracing import CapturedTrace, IndexingContext from .._support.indexing import IndexSymbol, IndexSequence from ..lang.global_symbols import * @@ -469,7 +469,7 @@ def set_thread_independent_index( constraints = [ c for c in constraints - if not isinstance(c, (HardwareConstraint, Assumption, SymbolicAlias)) + if not isinstance(c, (HardwareConstraint, Assumption, SymbolicConstraint)) ] index = {} @@ -628,7 +628,7 @@ def should_update_index( source: CustomOp, source_index: dict[IndexSymbol, IndexSequence], source_vector_shapes: dict[IndexSymbol, int], - symbolic_constraints: list[SymbolicAlias], + symbolic_constraints: list[SymbolicConstraint], ): # Get symbolic shape without any aliased variables. 
aliased_dims = [x.source for x in symbolic_constraints] @@ -656,7 +656,9 @@ def should_update_index( return True -def append_aliased_shapes(source: CustomOp, symbolic_constraints: list[SymbolicAlias]): +def append_aliased_shapes( + source: CustomOp, symbolic_constraints: list[SymbolicConstraint] +): """ Append the aliased shapes to the vector shapes of the source, if they are present in the source index. @@ -677,7 +679,7 @@ def propagate_index( workgroup_constraints: list[WorkgroupConstraint], mma_index: dict[MMA, dict[IndexSymbol, int]], visited: set[CustomOp], - symbolic_constraints: list[SymbolicAlias], + symbolic_constraints: list[SymbolicConstraint], ): """ Propagate the index and vector shapes through the graph @@ -736,7 +738,7 @@ def set_thread_dependent_index( workgroup_constraints = [ c for c in constraints if isinstance(c, WorkgroupConstraint) ] - symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)] + symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicConstraint)] for source in sources: visited = visited.union(set([x for x in sources])) visited.remove(source) diff --git a/iree/turbine/kernel/wave/symbolic_constraints.py b/iree/turbine/kernel/wave/symbolic_constraints.py index b51fe7350..9963c3e4f 100644 --- a/iree/turbine/kernel/wave/symbolic_constraints.py +++ b/iree/turbine/kernel/wave/symbolic_constraints.py @@ -17,13 +17,13 @@ @dataclass -class SymbolicAlias: +class SymbolicConstraint: """ A constraint of the form `tkw.SymbolicConstraint(K, SYMBOLIC_K)` specifies that the relationship between the source and target symbols is given by source = source_to_target(target). - SymbolicAliases are modeled in the compiler as additional workgroup, wave, + SymbolicConstraints are modeled in the compiler as additional workgroup, wave, and tiling constraints that are derived from the source. They are ignored during expansion and utilize the same workgroup and wave ids as the target symbol.
diff --git a/iree/turbine/kernel/wave/templates/attention_common.py b/iree/turbine/kernel/wave/templates/attention_common.py new file mode 100644 index 000000000..eac363f3a --- /dev/null +++ b/iree/turbine/kernel/wave/templates/attention_common.py @@ -0,0 +1,46 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +from typing import Optional + +import iree.turbine.kernel.lang as tkl + + +@dataclass +class AttentionShape: + num_query_heads: int + num_kv_heads: int + head_size: int + head_size_kv: int + # ----------------------- + # Prefill specific + num_seqs: Optional[int] = None + max_seq_len: Optional[int] = None + total_seq_len: Optional[int] = None + # ----------------------- + # Vanilla attention + query_seq_len: Optional[int] = None + kv_seq_len: Optional[int] = None + + +# Commonly-used attention symbols +H = tkl.sym.H +H_Q = tkl.sym.H_Q +H_KV = tkl.sym.H_KV +N_Q = tkl.sym.N_Q +N_KV = tkl.sym.N_KV +D_Q = tkl.sym.D_Q +D_KV = tkl.sym.D_KV + +BLOCK_H = tkl.sym.BLOCK_H +BLOCK_H_Q = tkl.sym.BLOCK_H_Q +BLOCK_H_KV = tkl.sym.BLOCK_H_KV +BLOCK_N_Q = tkl.sym.BLOCK_N_Q +BLOCK_N_KV = tkl.sym.BLOCK_N_KV +BLOCK_D_Q = tkl.sym.BLOCK_D_Q +BLOCK_D_KV = tkl.sym.BLOCK_D_KV + diff --git a/iree/turbine/kernel/wave/templates/decode_attention.py b/iree/turbine/kernel/wave/templates/decode_attention.py index e6648c0b8..ae24589f9 100644 --- a/iree/turbine/kernel/wave/templates/decode_attention.py +++ b/iree/turbine/kernel/wave/templates/decode_attention.py @@ -13,7 +13,7 @@ get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, ) -from ..symbolic_constraints import SymbolicAlias +from ..symbolic_constraints import SymbolicConstraint import sympy from enum import Enum import math @@ -64,7 +64,7 @@ def phase_0_constraints(): constraints += [tkw.WaveConstraint(K2, BLOCK_K2 /
K_WAVES)] constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 3)] constraints += [ - SymbolicAlias(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES))) + SymbolicConstraint(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES))) ] vector_shapes = {B: 0} waves_per_block = (M_WAVES, N_WAVES, K_WAVES) diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 1eb1f231b..7040518fe 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -8,7 +8,7 @@ import torch.fx as fx import inspect -from .symbolic_constraints import SymbolicAlias +from .symbolic_constraints import SymbolicConstraint from ..compiler import builder, dispatch_codegen, kernel_codegen, host_codegen from ..compiler.ir import Context, Operation @@ -141,7 +141,7 @@ def symbolic_constraints(self) -> list[HardwareConstraint]: return [ constraint for constraint in self.constraints - if isinstance(constraint, SymbolicAlias) + if isinstance(constraint, SymbolicConstraint) ] def _trace(self) -> CapturedTrace: @@ -230,7 +230,7 @@ def get_workgroup_dims(self) -> list[int]: """ # Ignore aliased variables. They will be handled separately. aliased_dims = [ - x.source for x in self.constraints if isinstance(x, SymbolicAlias) + x.source for x in self.constraints if isinstance(x, SymbolicConstraint) ] workgroup_dims = { x.workgroup_dim: x @@ -246,7 +246,7 @@ def update_aliased_workgroup_constraints( This function updates the wg_dim for aliased workgroup constraints. """ aliased_dims = [ - x.source for x in self.constraints if isinstance(x, SymbolicAlias) + x.source for x in self.constraints if isinstance(x, SymbolicConstraint) ] # Update the workgroup constraints for aliases sources. for constraint in self.workgroup_constraints: @@ -388,7 +388,9 @@ def _trace_and_get_kernel_signature( # Determine grid shape. 
self.grid_type.dims = [1, 1, 1] max_workgroup_dim = 2 - aliases = [x.source for x in self.constraints if isinstance(x, SymbolicAlias)] + aliases = [ + x.source for x in self.constraints if isinstance(x, SymbolicConstraint) + ] for constraint in self.workgroup_constraints: if constraint.dim in aliases: continue