I've run into a few issues that occur when the index for an op differs from the index of the ops that use it. I think these come up due to my workarounds for #381 and #384. I haven't come up with a minimal reproducer yet, but it shows up in my kernel for the backward pass of flash attention 2 (specifically when computing dk, in this case):
iree-turbine/tests/kernel/wave/attention/backward_attention_test.py, lines 872 to 897 at 4f18a93
You'll see I've got two permutes in there that just permute an mma output to the shape it already had. In the first case, this resolves an error complaining that thread shape discrepancies can only be resolved when one of the shapes is 1:
```
E NotImplementedError: Currently only support resolving discrepancies when one of the shapes is 1.
E binary_op=sub
E lhs=CastOp(graph=<torch.fx.graph.Graph object at 0x74f10134b2e0>, fx_node=cast, tkw_op_name='cast', _tracing_function=<bound method define_op.<locals>.decorator.<locals>.new_function of ...>, arg=mma, dtype=DataType(f16))
E lhs_index={B: BLOCK_B*(Mod($WG2, 2)) : 1 : 1, M: $ARGM*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, K1: $WG1*BLOCK_K1 + 4*floor((Mod($T0, 64))/16) : 4 : 1, K2: $WG0*BLOCK_K2 + Mod($T0, 16) : 1 : 1}
E lhs_dim=K1
E lhs_size=4
E lhs.type.symbolic_shape=(B, M, K2)
E rhs=Read(graph=<torch.fx.graph.Graph object at 0x74f10134b2e0>, fx_node=read_2, tkw_op_name='read', _tracing_function=<bound method define_op.<locals>.decorator.<locals>.new_function of ...>, memory=lse, elements_per_thread=MFMA_OUTPUT_ELS_PER_THREAD, mapping=None, mapping_dynamic_vals=(), _write_dependency=None)
E rhs_index={B: BLOCK_B*(Mod($WG2, 2)) : 1 : 1, M: $ARGM*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16}
E rhs_dim=M
E rhs_size=4
E rhs.type.symbolic_shape=(B, M)
```
Here the cast op inherits the complex index of its mma operand, which includes B, M, K1, and K2, but if I put the permute there, then the permute only has B, M, K2 in its index and compilation succeeds. Note that no actual MLIR is generated for the permute: the cast directly uses the MMA output.
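For concreteness, the workaround looks roughly like this (a simplified fragment, not the actual kernel; register names and shapes are illustrative, and the ops are written with the tkw spellings I use in the kernel linked above):

```python
# Simplified sketch of the workaround; names and shapes are illustrative.
s_acc = tkw.mma(k_reg, q_reg, s_prev)  # result has symbolic shape (B, M, K2)

# No-op permute to the shape the MMA result already has. Without it, the
# cast below inherits the MMA's index over B, M, K1, and K2, and the sub
# against lse (symbolic shape (B, M)) fails with the error above. With it,
# the index only covers B, M, K2 and compilation succeeds.
s_reg = tkw.permute(s_acc, target_shape=[B, M, K2])

p = tkw.cast(s_reg, tkl.f16) - lse_reg
```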
A related but more pernicious bug occurs when the N dimension is 32. Now the MMA that computes dp is expanded along the N dimension, but the expanded cast of dp doesn't have its operand updated and still uses the original MMA, which means it ends up reading the accumulator from in between the two MMA operations. Again, a no-op permute saves the day here: for whatever reason, the permute does get its argument updated. And again, the permute doesn't actually result in any MLIR.
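The failure mode is easy to demonstrate in plain torch.fx, independent of the wave expansion pass (this is just a generic illustration, not iree-turbine code; addmm stands in for the MMA and .to() for the cast):

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, a, b, acc):
        t = torch.addmm(acc, a, b)   # stand-in for the first MMA tile
        return t.to(torch.float16)   # stand-in for the cast that consumes it

gm = fx.symbolic_trace(M())
graph = gm.graph
mma = next(n for n in graph.nodes if n.target is torch.addmm)
cast = next(iter(mma.users))

# "Expand" the MMA with a second tile, chaining through the first result.
with graph.inserting_after(mma):
    mma2 = graph.call_function(torch.addmm, (mma, mma.args[1], mma.args[2]))

# Unless the user's operand is rewritten, the cast silently keeps reading the
# intermediate accumulator -- the stale-operand problem described above.
# cast.replace_input_with(mma, mma2)

print([str(a) for a in cast.args])  # ['addmm', 'torch.float16'], not 'addmm_1'
```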
The bug in expansion looks a lot like #374, where the same thing happened because the write didn't get expanded at all. I wonder if it would be a good idea to poison the originals of nodes that have gotten expanded. Is using them again always an error?
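Something along these lines is what I have in mind (just a sketch on top of torch.fx node metadata, not existing iree-turbine code; the helper names are made up):

```python
import torch.fx as fx

def poison_expanded_original(original: fx.Node, expanded: fx.Node) -> None:
    # Mark the pre-expansion node so later uses of it can be flagged instead
    # of silently reusing the stale value.
    original.meta["expanded_into"] = expanded

def assert_no_stale_uses(graph: fx.Graph) -> None:
    # Run after expansion: any op still consuming a poisoned node means its
    # operands were never rewritten, as happened with the cast above.
    for node in graph.nodes:
        for operand in node.all_input_nodes:
            if "expanded_into" in operand.meta:
                raise RuntimeError(
                    f"{node} still uses {operand}, which was expanded into "
                    f"{operand.meta['expanded_into']}"
                )
```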
FYI @harsh-nod