Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC] Rename SymbolicAlias -> SymbolicConstraints #428

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.utils._pytree as pytree
from collections import namedtuple

from .symbolic_constraints import SymbolicAlias
from .symbolic_constraints import SymbolicConstraint
from ..compiler.ir import (
Attribute,
DenseElementsAttr,
Expand Down
14 changes: 8 additions & 6 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
WorkgroupConstraint,
)
from .assumptions import Assumption
from .symbolic_constraints import SymbolicAlias
from .symbolic_constraints import SymbolicConstraint
from .._support.tracing import CapturedTrace, IndexingContext
from .._support.indexing import IndexSymbol, IndexSequence
from ..lang.global_symbols import *
Expand Down Expand Up @@ -469,7 +469,7 @@ def set_thread_independent_index(
constraints = [
c
for c in constraints
if not isinstance(c, (HardwareConstraint, Assumption, SymbolicAlias))
if not isinstance(c, (HardwareConstraint, Assumption, SymbolicConstraint))
]

index = {}
Expand Down Expand Up @@ -628,7 +628,7 @@ def should_update_index(
source: CustomOp,
source_index: dict[IndexSymbol, IndexSequence],
source_vector_shapes: dict[IndexSymbol, int],
symbolic_constraints: list[SymbolicAlias],
symbolic_constraints: list[SymbolicConstraint],
):
# Get symbolic shape without any aliased variables.
aliased_dims = [x.source for x in symbolic_constraints]
Expand Down Expand Up @@ -656,7 +656,9 @@ def should_update_index(
return True


def append_aliased_shapes(source: CustomOp, symbolic_constraints: list[SymbolicAlias]):
def append_aliased_shapes(
source: CustomOp, symbolic_constraints: list[SymbolicConstraint]
):
"""
Append the aliased shapes to the vector shapes of the source, if they
are present in the source index.
Expand All @@ -677,7 +679,7 @@ def propagate_index(
workgroup_constraints: list[WorkgroupConstraint],
mma_index: dict[MMA, dict[IndexSymbol, int]],
visited: set[CustomOp],
symbolic_constraints: list[SymbolicAlias],
symbolic_constraints: list[SymbolicConstraint],
):
"""
Propagate the index and vector shapes through the graph
Expand Down Expand Up @@ -736,7 +738,7 @@ def set_thread_dependent_index(
workgroup_constraints = [
c for c in constraints if isinstance(c, WorkgroupConstraint)
]
symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)]
symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicConstraint)]
for source in sources:
visited = visited.union(set([x for x in sources]))
visited.remove(source)
Expand Down
4 changes: 2 additions & 2 deletions iree/turbine/kernel/wave/symbolic_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@


@dataclass
class SymbolicAlias:
class SymbolicConstraint:
"""
A constraint of the form `tkw.SymbolicConstraint(K, SYMBOLIC_K)` specifies
that the relationship between the source and target symbols is given by
source = source_to_target(target).

SymbolicAliases are modeled in the compiler as additional workgroup, wave,
SymbolicConstraintes are modeled in the compiler as additional workgroup, wave,
and tiling constraints that are derived from the source. They are ignored
during expansion and utilize the same workgroup and wave ids as the
target symbol.
Expand Down
46 changes: 46 additions & 0 deletions iree/turbine/kernel/wave/templates/attention_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from dataclasses import dataclass
from typing import Optional

import iree.turbine.kernel.lang as tkl


@dataclass
class AttentionShape:
num_query_heads: int
num_kv_heads: int
head_size: int
head_size_kv: int
# -----------------------
# Prefill specific
num_seqs: Optional[int] = None
max_seq_len: Optional[int] = None
total_seq_len: Optional[int] = None
# -----------------------
# Vanilla attention
query_seq_len: Optional[int] = None
kv_seq_len: Optional[int] = None


# Commonly-used attention symbols
H = tkl.sym.H
H_Q = tkl.sym.H
H_KV = tkl.sym.H
N_Q = tkl.sym.N_D
N_KV = tkl.sym.N_KV
D_Q = tkl.sym.D_Q
D_KV = tkl.sym.D_KV

BLOCK_H = tkl.sym.BLOCK_H
BLOCK_H_Q = tkl.sym.BLOCK_H
BLOCK_H_KV = tkl.sym.BLOCK_H
BLOCK_N_Q = tkl.sym.BLOCK_N_D
BLOCK_N_KV = tkl.sym.BLOCK_N_KV
BLOCK_D_Q = tkl.sym.BLOCK_D_Q
BLOCK_D_KV = tkl.sym.BLOCK_D_KV

4 changes: 2 additions & 2 deletions iree/turbine/kernel/wave/templates/decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
get_mfma_load_elems_per_thread,
get_mfma_store_elems_per_thread,
)
from ..symbolic_constraints import SymbolicAlias
from ..symbolic_constraints import SymbolicConstraint
import sympy
from enum import Enum
import math
Expand Down Expand Up @@ -64,7 +64,7 @@ def phase_0_constraints():
constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / K_WAVES)]
constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 3)]
constraints += [
SymbolicAlias(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES)))
SymbolicConstraint(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES)))
]
vector_shapes = {B: 0}
waves_per_block = (M_WAVES, N_WAVES, K_WAVES)
Expand Down
12 changes: 7 additions & 5 deletions iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.fx as fx
import inspect

from .symbolic_constraints import SymbolicAlias
from .symbolic_constraints import SymbolicConstraint

from ..compiler import builder, dispatch_codegen, kernel_codegen, host_codegen
from ..compiler.ir import Context, Operation
Expand Down Expand Up @@ -141,7 +141,7 @@ def symbolic_constraints(self) -> list[HardwareConstraint]:
return [
constraint
for constraint in self.constraints
if isinstance(constraint, SymbolicAlias)
if isinstance(constraint, SymbolicConstraint)
]

def _trace(self) -> CapturedTrace:
Expand Down Expand Up @@ -230,7 +230,7 @@ def get_workgroup_dims(self) -> list[int]:
"""
# Ignore aliased variables. They will be handled separately.
aliased_dims = [
x.source for x in self.constraints if isinstance(x, SymbolicAlias)
x.source for x in self.constraints if isinstance(x, SymbolicConstraint)
]
workgroup_dims = {
x.workgroup_dim: x
Expand All @@ -246,7 +246,7 @@ def update_aliased_workgroup_constraints(
This function updates the wg_dim for aliased workgroup constraints.
"""
aliased_dims = [
x.source for x in self.constraints if isinstance(x, SymbolicAlias)
x.source for x in self.constraints if isinstance(x, SymbolicConstraint)
]
# Update the workgroup constraints for aliases sources.
for constraint in self.workgroup_constraints:
Expand Down Expand Up @@ -388,7 +388,9 @@ def _trace_and_get_kernel_signature(
# Determine grid shape.
self.grid_type.dims = [1, 1, 1]
max_workgroup_dim = 2
aliases = [x.source for x in self.constraints if isinstance(x, SymbolicAlias)]
aliases = [
x.source for x in self.constraints if isinstance(x, SymbolicConstraint)
]
for constraint in self.workgroup_constraints:
if constraint.dim in aliases:
continue
Expand Down
Loading