diff --git a/iree/turbine/kernel/wave/templates/extend_attention.py b/iree/turbine/kernel/wave/templates/extend_attention.py index 6daea5f71..d87471582 100644 --- a/iree/turbine/kernel/wave/templates/extend_attention.py +++ b/iree/turbine/kernel/wave/templates/extend_attention.py @@ -17,6 +17,7 @@ import math import torch from typing import Optional +import sympy def get_extend_attention_kernel( @@ -231,6 +232,11 @@ def first_loop( res_max, res_sum, res_mm = first_loop + if is_causal: + seq_len_extend = tkw.apply_expr( + seq_len_extend, + lambda x: sympy.Min(x, (WORKGROUP_0 + 1) * SEQ_TILE_SIZE), + ) tkw.set_symbol(N_KV, seq_len_extend) @tkw.reduction(N_KV, init_args=[res_max, res_sum, res_mm])