debug: gather is wrong
ftynse committed Feb 10, 2025
1 parent 8882a50 commit a5a1818
Showing 2 changed files with 48 additions and 17 deletions.
29 changes: 26 additions & 3 deletions playground/attention_with_rpe_template.py
@@ -86,6 +86,14 @@ def get_vanilla_attention_kernel(
outputs={K2: i},
)

# d = tkw.IndexMapping.dynamic_val(0)
# dynamic_mapping = tkw.IndexMapping(
# num_iterators=3,
# inputs = {B: d},
# outputs = {B: i, M: j, K2: k},
# dynamic_val_mappings = {B: i, M: j, K2: k}
# )
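# (The commented-out mapping above is an alternative way to gather from rpe via a
# dynamic-value index mapping; it corresponds to the commented "alternative" read below.)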

use_t5_rpe = max_context_length is not None
if use_t5_rpe:
rpe_layout = tkl.MemoryLayout(
@@ -103,6 +111,7 @@ def base_attention(
# TODO: if not use_t5_rpe, this will DCE; atm DCE on blockargs crashes.
rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout],
c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
debug_out: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
init_sum = tkl.Register[B, M, tkl.f32](0.0)
@@ -152,11 +161,22 @@ def repeat(
# to do bucketing; atm it is bucketing of size 1.

# min/max variant
# idx = tkw.maximum(i - j, ZERO)
# idx = tkw.minimum(idx, MAX)
idx = tkw.maximum(i - j, ZERO)
idx = tkw.minimum(idx, MAX)

# select variant.
idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO)
# idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO)
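# Note: the active min/max variant clamps i - j into [0, MAX]; the commented select
# variant would instead map any out-of-range distance to ZERO rather than clamping it.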

idx = tkw.broadcast(idx, target_shape=[B, M, K2])

### alternative
# rpe_reg = tkw.read(
# rpe,
# mapping=dynamic_mapping,
# mapping_dynamic_vals=(idx,),
# elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
# )
###

# 3. Read indirectly into the 1-D rpe array via offset_mapping.
tkw.set_symbol(OFFSET, idx) # offset will have shape [M, K2]
@@ -165,6 +185,9 @@ def repeat(
mapping=offset_mapping,
elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
)
rpe_reg = tkw.broadcast(rpe_reg, target_shape=[B, M, K2])
tkw.write(rpe_reg, debug_out, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
# tkw.write(tkw.cast(idx, tkl.f32), debug_out, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
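# The write above dumps the gathered rpe values (broadcast to [B, M, K2]) into debug_out
# for host-side inspection; the commented variant would dump the computed indices instead.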

# 4. Tadaaaa.
x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.f32)
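For reference, a minimal host-side sketch (not part of this commit) of what debug_out should contain if the gather were correct, assuming i and j index query and key positions and that the kernel's MAX symbol is the last valid index into the 1-D rpe table; expected_debug_out and max_idx are hypothetical names introduced here, not identifiers from the repository.

import torch

def expected_debug_out(rpe: torch.Tensor, num_heads: int, q_len: int, kv_len: int,
                       max_idx: int) -> torch.Tensor:
    # Mirror the kernel's min/max variant: idx = clamp(i - j, 0, max_idx).
    i = torch.arange(q_len, device=rpe.device).unsqueeze(1)   # query positions, [M, 1]
    j = torch.arange(kv_len, device=rpe.device).unsqueeze(0)  # key positions,   [1, K2]
    idx = (i - j).clamp(min=0, max=max_idx)                   # clamped distances,  [M, K2]
    gathered = rpe[idx]                                       # 1-D gather,         [M, K2]
    return gathered.unsqueeze(0).expand(num_heads, -1, -1)    # broadcast over B -> [B, M, K2]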
36 changes: 22 additions & 14 deletions playground/attention_with_rpe_test.py
@@ -37,14 +37,14 @@

### TKW Harness
def run(fun: Callable, hparams, *args) -> Any:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA]
) as prof:
# with torch.profiler.profile(
# activities=[torch.profiler.ProfilerActivity.CUDA]
# ) as prof:
with torch.no_grad(): # Disable gradient calculations
with TestLaunchContext(
hparams,
canonicalize=True,
compile_config={"print-ir-after": "all"},
# compile_config={"print-ir-after": "all"},
run=True,
run_config=get_default_run_config(),
run_bench=False,
@@ -53,11 +53,11 @@ def run(fun: Callable, hparams, *args) -> Any:
):
fun(*args)

print(
prof.key_averages(group_by_input_shape=True).table(
sort_by="self_cuda_time_total", row_limit=10
)
)
# print(
# prof.key_averages(group_by_input_shape=True).table(
# sort_by="self_cuda_time_total", row_limit=10
# )
# )


#################################################################################
@@ -90,7 +90,7 @@ def run(fun: Callable, hparams, *args) -> Any:
# T5 RPE INIT VALS
#################################################################################
# T5 RPE parameter
max_context_length = 33
max_context_length = 30

# Applied pre-softmax on the MMA'ed result so f32.
# Provision more room for clipping and adding 0 at the boundaries.
@@ -100,6 +100,7 @@ def run(fun: Callable, hparams, *args) -> Any:
rpe[0] = 0
rpe[max_context_length + 1] = 0

tmp_out = device_zeros(shape.num_query_heads, shape.query_seq_len, shape.kv_seq_len, dtype=torch.float32)
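# Host-side buffer for the kernel's new debug_out argument: gathered rpe values of shape [B, M, K2].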

def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype):
positions = to_default_device(torch.arange(sequence_length))
@@ -118,6 +119,10 @@ def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype
dtype=tkw_attention_with_rpe_output.dtype,
)

# print(rpe)
print(rpe_cond.shape)
print(rpe_cond)

#################################################################################
# TKW BASE ATTENTION
#################################################################################
@@ -135,21 +140,24 @@ def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype
)


def attention_with_rpe(tq, tk, tv, trpe, toutput):
mb = tkw_attention_with_rpe(tq, tk, tv, trpe, toutput)
def attention_with_rpe(*args):
mb = tkw_attention_with_rpe(*args)
print(mb.module_op)


run(
attention_with_rpe,
hyperparams,
q * dk_sqrt * log2e,
k,
v.permute([0, 2, 1]),
rpe * log2e,
rpe, # * log2e,
tkw_attention_with_rpe_output,
tmp_out
)

print(tmp_out.shape)
print(tmp_out[0, :, :])
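
# Hypothetical check (not part of this commit): compare the kernel's debug dump with the
# expected gather from the sketch above, assuming MAX corresponds to max_context_length + 1
# (the index of the zeroed boundary entry of rpe).
expected = expected_debug_out(rpe, shape.num_query_heads, shape.query_seq_len,
                              shape.kv_seq_len, max_context_length + 1)
print((tmp_out - expected).abs().max())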

### Reference version
(
tkw_attention,