Debug RPE with MMA 32x32x8 #502

Draft: wants to merge 1 commit into base: main
5 changes: 4 additions & 1 deletion iree/turbine/kernel/ops/wave_ops.py
@@ -1257,11 +1257,14 @@ def transform_index_backwards(
             subs = {
                 k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
             }
-            return {
+            print(f"subs {subs}")
+            d = {
                 k: IndexSequence.from_expr(mapping[k], subs)
                 for k in get_custom(arg).type.symbolic_shape
                 if k in mapping
             }
+            print(f"d {d}")
+            return d

         return index

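Annotation (not part of the diff): the `d` dict printed above is built by substituting the mapping's iterator symbols with the node's index expressions and then resolving each mapped dimension. A minimal sketch of that substitution pattern in plain sympy; the symbols `d0`, `i0`, `K2` and the expressions are hypothetical stand-ins, not the real wave symbols.

```python
import sympy

d0, i0 = sympy.symbols("d0 i0")  # hypothetical iterator / index symbols
K2 = sympy.Symbol("K2")          # hypothetical mapped dimension

mapping = {K2: d0 + 3}  # dimension expressed in terms of the iterator
subs = {d0: i0 * 4}     # iterator -> index expression, like the `subs` dict above

# Resolve each mapped dimension, analogous to IndexSequence.from_expr(mapping[k], subs).
resolved = {k: expr.subs(subs) for k, expr in mapping.items()}
print(resolved)  # {K2: 4*i0 + 3}
```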
2 changes: 2 additions & 0 deletions iree/turbine/kernel/wave/codegen/emitter.py
@@ -390,12 +390,14 @@ def _get_const(val):
     # Substitute in frozen vars to simplify expression.
     if not isinstance(expr, sympy.Expr):
         expr = sympy.sympify(expr)
+    print(f"\n\nexpr {expr} dynamics {dynamics}")
     expr = expr.subs(idxc.subs)
     # Why affine, for now simply create indexing expressions.
     # This can easily be adapted to affine expressions later.
     select_stack = []
     if isinstance(expr, sympy.Piecewise):
         assert len(expr.args) == 2 and expr.args[1][1], f"Unsupported piecewise {expr}"
+        print(f"\n\nexpr {expr} dynamics {dynamics}")
     for term in sympy.postorder_traversal(expr):
         match term:
             case sympy.Symbol():
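Annotation (not part of the diff): the assertion above only accepts a two-branch `sympy.Piecewise` whose second condition is the catch-all `True`, i.e. `expr.args[1][1]` is truthy. A small self-contained illustration with a hypothetical symbol `x`:

```python
import sympy

x = sympy.Symbol("x")  # hypothetical index symbol
# Two branches; the fallback condition is True, so expr.args[1][1] is sympy.true.
expr = sympy.Piecewise((x + 1, x < 8), (x, True))
assert len(expr.args) == 2 and expr.args[1][1]

# The emitter then visits the expression bottom-up, as in the traversal loop above.
for term in sympy.postorder_traversal(expr):
    print(type(term).__name__, term)
```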
1 change: 1 addition & 0 deletions iree/turbine/kernel/wave/codegen/read_write.py
@@ -80,6 +80,7 @@ def _build_start_indices(
     src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
     dynamic_values: dict[IndexExpr, Any] = {},
 ) -> list[OpResult]:
+    print(f"dynamic_values {dynamic_values}")
     return [
         gen_sympy_index(add_emitter_subs(emitter, dynamic_values), i)
         for i in _get_start_indices(src_indices)
1 change: 1 addition & 0 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -210,6 +210,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
         ]
         assert len(dim_with_gpr_offsets) == 1, "Expected only 1-Dim has gpr offsets"
         gpr_offset_dim, gpr_offset_expr = dim_with_gpr_offsets[0]
+        print(f"elements_per_thread {elements_per_thread}")
         gpr_offsets = [
             gpr_offset_expr.subs({GPR_NUM: i}) for i in range(elements_per_thread)
         ]
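Annotation (not part of the diff): the `elements_per_thread` value printed above determines how many concrete offsets are expanded from the per-register offset expression by substituting `GPR_NUM`. A minimal sympy sketch; `lane` and the offset expression are hypothetical stand-ins.

```python
import sympy

GPR_NUM = sympy.Symbol("GPR_NUM")
lane = sympy.Symbol("lane")
gpr_offset_expr = 4 * lane + GPR_NUM  # hypothetical per-thread offset expression
elements_per_thread = 4

# One concrete offset per register slot, mirroring the comprehension above.
gpr_offsets = [gpr_offset_expr.subs({GPR_NUM: i}) for i in range(elements_per_thread)]
print(gpr_offsets)  # [4*lane, 4*lane + 1, 4*lane + 2, 4*lane + 3]
```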
8 changes: 3 additions & 5 deletions iree/turbine/kernel/wave/templates/t5_rpe_attention.py
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple

 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
@@ -22,7 +22,7 @@

 def get_t5_rpe_attention_kernel(
     shape: AttentionShape,
-    mfma_variant: MMAType,
+    mfma_variant: Tuple[MMAType],
     dynamic_dims: bool,
     max_context_length: int,
 ):
@@ -123,9 +123,7 @@ def repeat(
             # the partial softmax should be equivalent.
             i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
             i = tkw.broadcast(i, target_shape=[B, M, K2])
-            j = tkw.self_index(
-                K2, tkl.i64, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK
-            )
+            j = tkw.self_index(K2, tkl.i64, elements_per_thread=4)
             rpe_reg = tkw.read(
                 rpe,
                 mapping=offset_mapping,
59 changes: 32 additions & 27 deletions tests/kernel/wave/attention/t5_rpe_attention_test.py
@@ -24,11 +24,12 @@
 from iree.turbine.kernel.wave.templates.t5_rpe_attention import (
     get_t5_rpe_attention_kernel,
 )
-from ..common.utils import (
-    require_e2e,
-    require_cdna3,
-    enable_scheduling_barriers,
-)
+# from ..common.utils import (
+#     require_e2e,
+#     require_cdna3,
+#     enable_scheduling_barriers,
+# )
+from typing import Tuple

 shapes = [(128, 128, 128, 128, 128, 128)]
@@ -83,19 +84,19 @@ def create_inputs(


 # TODO: Debug why failing numerics on MI250.
-@require_e2e
-@require_cdna3
-@pytest.mark.parametrize("shape", shapes)
-@pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize(
-    "mfma_variant",
-    [(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)],
-)
+# @require_e2e
+# @require_cdna3
+# @pytest.mark.parametrize("shape", shapes)
+# @pytest.mark.parametrize("dtype", [torch.float16])
+# @pytest.mark.parametrize(
+#     "mfma_variant",
+#     [(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)],
+# )
 def test_t5_rpe_attention(
     shape: Tuple[int],
     dtype: torch.dtype,
-    mfma_variant: MMAType,
-    request,
+    mfma_variant: Tuple[MMAType],
+    # request,
 ):
     torch.manual_seed(0)
     shape = AttentionShape(
@@ -119,16 +120,15 @@

     hyperparams.update(get_default_scheduling_params())
     config = get_default_run_config()
-    run_bench = request.config.getoption("--runperf")
-    dump_perf = request.config.getoption("--dump-perf-files-path")
-    if run_bench:
-        config["benchmark_batch_size"] = 10
-        config["benchmark_repetitions"] = 3
-    if dump_perf is not None:
-        perf_filename = request.node.name + ".json"
-        config["benchmark_results_file"] = os.path.join(
-            dump_perf, "tk_" + perf_filename
-        )
+    # run_bench = request.config.getoption("--runperf")
+    # dump_perf = request.config.getoption("--dump-perf-files-path")
+    # if run_bench:
+    #     config["benchmark_batch_size"] = 10
+    #     config["benchmark_repetitions"] = 3
+    # if dump_perf is not None:
+    #     perf_filename = request.node.name + ".json"
+    #     config["benchmark_results_file"] = os.path.join(
+    #         dump_perf, "tk_" + perf_filename)

     log2e = 1.44269504089
     dk_sqrt = math.sqrt(1.0 / shape.head_size)
@@ -143,9 +143,9 @@
         hyperparams,
         canonicalize=True,
         run=True,
-        run_bench=run_bench,
+        # run_bench=run_bench,
         run_config=config,
-        use_scheduling_barriers=enable_scheduling_barriers,
+        # use_scheduling_barriers=enable_scheduling_barriers,
     ):
         output = device_zeros(output_shape, dtype=torch.float32)
         # TODO: Add scaling of QK and t5_rpe as part of kernel.
@@ -161,3 +161,8 @@
     )

     validate_accuracy(query, key, value, rpe, output)
+
+
+test_t5_rpe_attention(
+    shapes[0], torch.float16, (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16)
+)
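Note (not part of the diff): with the pytest markers commented out and the module-level call added above, the test is presumably run directly as a script during debugging (e.g. `python tests/kernel/wave/attention/t5_rpe_attention_test.py`), which executes a single pass with shape (128, 128, 128, 128, 128, 128) and the F32_32x32x8_F16 MFMA variants.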