Split Hopper MMA by warp-tile before instruction tile #3642
base: main
Conversation
!test
The bank conflict came from the stmatrix scheduling, which needs to be updated. I will do that in a separate PR. For now, I've disabled smem epilogue in the included test.
!test
When I manually disable stmatrix but keep the TMA store, I still hit a bank conflict and a misaligned address in the smem read when doing the TMA store. The epilogue looks like this:

```cuda
asm volatile("wgmma.commit_group.sync.aligned;\n");
asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
__syncthreads();
#pragma unroll
for(nvfuser_index_t i50 = 0; i50 < 16; ++i50) {
  nvfuser_index_t i51;
  i51 = 4 * i50;
  #pragma unroll
  for(nvfuser_index_t i52 = 0; i52 < 2; ++i52) {
    nvfuser_index_t i53;
    i53 = i51 + (2 * i52);
    Array<__half, 2, 2> T6;
    #pragma unroll
    for(nvfuser_index_t i54 = 0; i54 < 2; ++i54) {
      T6[i54] = __float2half(T2[(i53 + i54)]);
    }
    loadGeneric<__half, 2>(&T7[(i17 + (128 * i52))], &T6[0]);
  }
  __syncthreads();
  asm volatile("fence.proxy.async;\n");
  if (b24) {
    Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr19, (Array<nvfuser_index_t, 2, 1>{(i20 + (8 * i50)), i21}) }), i18);
  }
  __syncthreads();
  asm volatile("cp.async.bulk.commit_group;\n");
  asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
}
asm volatile("cp.async.bulk.commit_group;\n");
asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
```

The misaligned read happens with `threadIdx.y = 3`:

```cuda
i11 = ((nvfuser_index_t)threadIdx.y) / 2; // = 1
i12 = 2048 * i11;                         // = 2048
i14 = ((nvfuser_index_t)threadIdx.y) % 2; // = 1
i18 = (toSmem(T7) + i12) + (16 * i14);    // = toSmem(T7) + 2064
```
mma result before this PR: (image)

And after this PR: (image)
Note that I can enable smem epilogue and the test passes if I use […]
I think this covers the motivation for #3616
```cpp
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
```
@rdspring1 is this enough or is #3616 still needed?
It is all that is required for scheduler changes.
```cpp
// size
// Original: [..., M, N(, K)]
// We split this into warp tiles then instruction tiles
if (is_mma_result) {
```
TODO: since there is no code in common between these branches, we should split this into two separate functions.
Do we need to remove this limitation to handle all matmul parameter configurations?
```
CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]
```
I see […]
Currently we ignore the warp tile parameter when scheduling Hopper matmuls (see #3636). This PR introduces a test with different CTA, warp, and instruction tiles and modifies the Hopper scheduler to split by warp tile in addition to instruction tile. Note that the instruction tile split results in two serial loop domains, so we wind up executing multiple mma instructions in each main loop. In the included example, `warp_tile` is 64, 128, 16 and the macro is `Hopper_64_8_16`. In this case, there are 128/8 = 16 instruction tiles per warp tile, so the generated main loop looks like this: […]

Fixes #3636