
[TKW] Bug: Index mismatch between op and users creates multiple problems #410

Open
GMNGeoffrey opened this issue Jan 23, 2025 · 1 comment



GMNGeoffrey commented Jan 23, 2025

I've run into a few issues that occur when the index of an op differs from the index of the ops that use it. I think these come up due to my workarounds for #381 and #384. I haven't come up with a minimal reproducer yet, but the problem shows up in my kernel for the backward pass of flash attention 2 (specifically when computing dk).

s_acc = tkl.Register[B, M_qs, K2_kvs, tkl.f32](0.0)
log2e = tkl.Register[B, M_qs, K2_kvs, tkl.f16](1.44269504089)
s_ij = tkw.mma(q_i, k_j, s_acc)
tkw.write(s_ij, s, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD)
# a no-op permute here gets past a compiler error resolving node
# indices. I think it just hides the K1 dimension from the index.
s_ij = tkw.permute(s_ij, [B, M_qs, K2_kvs])
lse_i = tkw.read(lse, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD)
p_ij = tkw.exp2(log2e * (tkw.cast(s_ij, tkl.f16) - lse_i))
tkw.write(p_ij, p, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD)
p_ij = tkw.permute(p_ij, [B, K2_kvs, M_qs])
dp_acc = tkl.Register[B, M_qs, K2_kvs, tkl.f32](0.0)
v_j = tkw.read(v, elements_per_thread=MFMA_INPUT_ELS_PER_THREAD)
do_i = tkw.read(do, elements_per_thread=MFMA_INPUT_ELS_PER_THREAD)
dp_ij = tkw.mma(do_i, v_j, dp_acc)
# This no-op permute fixes a compiler error by hiding the N index of
# the mma from the cast that uses it. Otherwise, the cast operation
# fails to update the op it uses during expansion.
dp_ij = tkw.permute(dp_ij, [B, M_qs, K2_kvs])
tkw.write(dp_ij, dp, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD)
D_i = tkw.read(D, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD)
dp_ij_sub = tkw.cast(dp_ij, tkl.f16) - D_i
tkw.write(dp_ij_sub, dp_sub, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD)
dp_ij_sub = tkw.permute(dp_ij_sub, [B, K2_kvs, M_qs])

You'll see I've got two permutes in there that just permute an mma output to the shape it already has. In the first case, this resolves an error complaining that thread shape discrepancies can only be resolved when one of the shapes is 1:

E   NotImplementedError: Currently only support resolving discrepancies when one of the shapes is 1.
E   binary_op=sub
E   lhs=CastOp(graph=<torch.fx.graph.Graph object at 0x74f10134b2e0>, fx_node=cast, tkw_op_name='cast', _tracing_function=<bound method define_op.<locals>.decorator.<locals>.new_function of ...>, arg=mma, dtype=DataType(f16))
E   lhs_index={B: BLOCK_B*(Mod($WG2, 2)) : 1 : 1, M: $ARGM*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, K1: $WG1*BLOCK_K1 + 4*floor((Mod($T0, 64))/16) : 4 : 1, K2: $WG0*BLOCK_K2 + Mod($T0, 16) : 1 : 1}
E   lhs_dim=K1
E   lhs_size=4
E   lhs.type.symbolic_shape=(B, M, K2)
E   rhs=Read(graph=<torch.fx.graph.Graph object at 0x74f10134b2e0>, fx_node=read_2, tkw_op_name='read', _tracing_function=<bound method define_op.<locals>.decorator.<locals>.new_function of ...>, memory=lse, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD, mapping=None, mapping_dynamic_vals=(), _write_dependency=None)
E   rhs_index={B: BLOCK_B*(Mod($WG2, 2)) : 1 : 1, M: $ARGM*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16}
E   rhs_dim=M
E   rhs_size=4
E   rhs.type.symbolic_shape=(B, M)

Here the cast op inherits the complex index of its mma operand, which includes B, M, K1, and K2. If I insert the permute, the permute's index contains only B, M, and K2, and compilation succeeds. Note that no actual MLIR is generated for this operation: the cast directly uses the MMA output.
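
For reference, here is the workaround pattern in isolation (same ops and names as the kernel above). My understanding is that the permute's output shape matches the mma's symbolic shape, so it's a no-op, but it hides the K1 reduction dimension from the index that its users inherit:

s_ij = tkw.mma(q_i, k_j, s_acc)              # index covers B, M, K1, K2
s_ij = tkw.permute(s_ij, [B, M_qs, K2_kvs])  # index now covers only B, M, K2
p_ij = tkw.exp2(log2e * (tkw.cast(s_ij, tkl.f16) - lse_i))  # thread shape resolution now succeeds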

A related but more pernicious bug occurs if the N dimension is 32. Now the MMA that computes dp is expanded along the N dimension, but the expansion of the cast operation on dp doesn't update the cast's operand and just uses the original MMA, which means it ends up using the intermediate accumulator between the two MMA operations. Again, a no-op permute saves the day here: the permute, for whatever reason, does update its argument during expansion. And again, the permute doesn't actually result in any MLIR.
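
To illustrate what I mean (this is my mental model of the expanded graph, not actual compiler output): with N = 32 the dp MMA expands into two MMAs chained through the accumulator, and the expanded cast keeps pointing at the first, intermediate one:

dp_ij_0 = tkw.mma(do_i_0, v_j_0, dp_acc)   # first expansion step along N
dp_ij_1 = tkw.mma(do_i_1, v_j_1, dp_ij_0)  # second step, chained through the accumulator
dp_f16 = tkw.cast(dp_ij_0, tkl.f16)        # bug: still uses dp_ij_0, should use dp_ij_1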

The bug in expansion looks a lot like #374, where the same thing happened because the write was not expanded at all. I wonder if it would be a good idea to poison the originals of nodes that have been expanded. Is using them again always an error?
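
A rough sketch of what I mean by poisoning, assuming expansion runs over a torch.fx graph and that expanded copies could record which node they came from (the helper and meta keys below are hypothetical, not existing TKW APIs):

def poison_expanded_originals(graph):
    # Mark every node that has been superseded by an expanded copy.
    for node in graph.nodes:
        original = node.meta.get("expanded_from")  # hypothetical meta key
        if original is not None:
            original.meta["poisoned"] = True
    # Any remaining use of a poisoned node is (presumably) a bug like this one.
    for node in graph.nodes:
        for arg in node.all_input_nodes:
            if arg.meta.get("poisoned"):
                raise RuntimeError(
                    f"{node} still uses {arg}, which was replaced during expansion"
                )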

FYI @harsh-nod

@GMNGeoffrey GMNGeoffrey changed the title Index mismatch between op and users creates multiple problems [TKW] Bug: Index mismatch between op and users creates multiple problems Jan 23, 2025

GMNGeoffrey commented Jan 23, 2025

I suspected that Reshapes work here where casts don't because of this special casing, but never mind: Reshape != Permute at this point.
