From a5a181897e893a238f62f1896730e911dd9ee59e Mon Sep 17 00:00:00 2001
From: Alex Zinenko
Date: Mon, 10 Feb 2025 14:37:16 -0800
Subject: [PATCH] debug: gather is wrong

---
 playground/attention_with_rpe_template.py | 29 ++++++++++++++++--
 playground/attention_with_rpe_test.py     | 36 ++++++++++++++---------
 2 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py
index a4c78c6aa..5c7fd72da 100644
--- a/playground/attention_with_rpe_template.py
+++ b/playground/attention_with_rpe_template.py
@@ -86,6 +86,14 @@ def get_vanilla_attention_kernel(
         outputs={K2: i},
     )

+    # d = tkw.IndexMapping.dynamic_val(0)
+    # dynamic_mapping = tkw.IndexMapping(
+    #     num_iterators=3,
+    #     inputs = {B: d},
+    #     outputs = {B: i, M: j, K2: k},
+    #     dynamic_val_mappings = {B: i, M: j, K2: k}
+    # )
+
     use_t5_rpe = max_context_length is not None
     if use_t5_rpe:
         rpe_layout = tkl.MemoryLayout(
@@ -103,6 +111,7 @@ def base_attention(
         # TODO: if not use_t5_rpe, this will DCE; atm DCE on blockargs crashes.
         rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout],
         c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+        debug_out: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32],
     ):
         c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
         init_sum = tkl.Register[B, M, tkl.f32](0.0)
@@ -152,11 +161,22 @@ def repeat(
             # to do bucketing; atm it is bucketing of size 1.

             # min/max variant
-            # idx = tkw.maximum(i - j, ZERO)
-            # idx = tkw.minimum(idx, MAX)
+            idx = tkw.maximum(i - j, ZERO)
+            idx = tkw.minimum(idx, MAX)

             # select variant.
-            idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO)
+            # idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO)
+
+            idx = tkw.broadcast(idx, target_shape=[B, M, K2])
+
+            ### alternative
+            # rpe_reg = tkw.read(
+            #     rpe,
+            #     mapping=dynamic_mapping,
+            #     mapping_dynamic_vals=(idx,),
+            #     elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
+            # )
+            ###

             # 3. Read indirect into the 1-D rpe array via offset_mapping.
             tkw.set_symbol(OFFSET, idx)  # offset will have shape [M, K2]
@@ -165,6 +185,9 @@ def repeat(
                 mapping=offset_mapping,
                 elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
             )
+            rpe_reg = tkw.broadcast(rpe_reg, target_shape=[B, M, K2])
+            tkw.write(rpe_reg, debug_out, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            # tkw.write(tkw.cast(idx, tkl.f32), debug_out, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)

             # 4. Tadaaaa.
             x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.f32)
diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py
index f9fc3ebef..7f166aedc 100644
--- a/playground/attention_with_rpe_test.py
+++ b/playground/attention_with_rpe_test.py
@@ -37,14 +37,14 @@

 ### TKW Harness
 def run(fun: Callable, hparams, *args) -> Any:
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA]
-    ) as prof:
+    # with torch.profiler.profile(
+    #     activities=[torch.profiler.ProfilerActivity.CUDA]
+    # ) as prof:
         with torch.no_grad():  # Disable gradient calculations
             with TestLaunchContext(
                 hparams,
                 canonicalize=True,
-                compile_config={"print-ir-after": "all"},
+                # compile_config={"print-ir-after": "all"},
                 run=True,
                 run_config=get_default_run_config(),
                 run_bench=False,
@@ -53,11 +53,11 @@ def run(fun: Callable, hparams, *args) -> Any:
             ):
                 fun(*args)

-    print(
-        prof.key_averages(group_by_input_shape=True).table(
-            sort_by="self_cuda_time_total", row_limit=10
-        )
-    )
+    # print(
+    #     prof.key_averages(group_by_input_shape=True).table(
+    #         sort_by="self_cuda_time_total", row_limit=10
+    #     )
+    # )


 #################################################################################
@@ -90,7 +90,7 @@ def run(fun: Callable, hparams, *args) -> Any:
 # T5 RPE INIT VALS
 #################################################################################
 # T5 RPE parameter
-max_context_length = 33
+max_context_length = 30

 # Applied pre-softmax on the MMA'ed result so f32.
 # Provision more room for clipping and adding 0 at the boundaries.
@@ -100,6 +100,7 @@ def run(fun: Callable, hparams, *args) -> Any:
 rpe[0] = 0
 rpe[max_context_length + 1] = 0

+tmp_out = device_zeros(shape.num_query_heads, shape.query_seq_len, shape.kv_seq_len, dtype=torch.float32)

 def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype):
     positions = to_default_device(torch.arange(sequence_length))
@@ -118,6 +119,10 @@ def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype
     dtype=tkw_attention_with_rpe_output.dtype,
 )

+# print(rpe)
+print(rpe_cond.shape)
+print(rpe_cond)
+
 #################################################################################
 # TKW BASE ATTENTION
 #################################################################################
@@ -135,21 +140,24 @@ def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype
 )


-def attention_with_rpe(tq, tk, tv, trpe, toutput):
-    mb = tkw_attention_with_rpe(tq, tk, tv, trpe, toutput)
+def attention_with_rpe(*args):
+    mb = tkw_attention_with_rpe(*args)
     print(mb.module_op)

-
 run(
     attention_with_rpe,
     hyperparams,
     q * dk_sqrt * log2e,
     k,
     v.permute([0, 2, 1]),
-    rpe * log2e,
+    rpe,  # * log2e,
     tkw_attention_with_rpe_output,
+    tmp_out
 )

+print(tmp_out.shape)
+print(tmp_out[0, :, :])
+
 ### Reference version
 (
     tkw_attention,