LinCombPerRowBias & EVT Changes #152
base: sycl-develop
4720c36 to bdcb4bb
bdcb4bb to a704294
8d7f032 to 600cbf6
As a precursor to properly testing #152, I have adapted the gemm_testbed_3x.hpp infrastructure slightly for Xe and added a single test case for LinCombEltAct, as that's something we know should work. --------- Co-authored-by: Finlay <[email protected]>
Not totally sure these are correct yet, but it's compiling...
17df726 to b5db222
@@ -262,7 +262,7 @@ class GemmUniversal<
       CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
       epilogue(
         problem_shape_MNKL,
-        subgroup_shape,
+        subgroup_shape, // TODO(joe): Inconsistency here w/ blk_coord_mnkl
@aacostadiaz the subgroup shape is being passed in here instead of the WorkgroupTileShape. However, the coordinate blk_coord_mnkl is the work-group coordinate. This doesn't matter for this PR because I'm computing things inside xe_epilogue.hpp, but it feels inconsistent. Potentially something to address as part of your ongoing work on the coordinates?
Yes, that makes sense. I can look into that as part of the fix for coordinates.
Maybe change to TODO(codeplay)?
Done 👍
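For readers following the thread above, here is a minimal sketch of why pairing a subgroup shape with a work-group coordinate is inconsistent. It uses plain integers rather than CuTe types; `Extent2D`, `subgroup_origin`, and the row-major ordering of subgroups inside a work-group are all assumptions made for illustration, not the code in this PR.

```cpp
#include <cassert>

// Hypothetical 2-D extent/coordinate type; not a CUTLASS or CuTe type.
struct Extent2D {
  int m;
  int n;
};

// Origin (row, col) of one subgroup's tile in the global M x N output.
// blk is the work-group tile coordinate (what blk_coord_mnkl carries in M/N).
// sg_id is the linear subgroup index inside the work-group, assumed here to
// be row-major over the grid of subgroups.
inline Extent2D subgroup_origin(Extent2D wg_tile, Extent2D sg_tile,
                                Extent2D blk, int sg_id) {
  int sgs_along_n = wg_tile.n / sg_tile.n;  // subgroups per work-group row
  int sg_m = sg_id / sgs_along_n;
  int sg_n = sg_id % sgs_along_n;
  // The work-group coordinate must be scaled by the work-group tile shape;
  // scaling it by the subgroup tile shape would give the wrong origin.
  return Extent2D{blk.m * wg_tile.m + sg_m * sg_tile.m,
                  blk.n * wg_tile.n + sg_n * sg_tile.n};
}
```

With a 256x256 work-group tile, a 32x64 subgroup tile, work-group coordinate (1, 2) and subgroup 5, this gives origin (288, 576), whereas multiplying the work-group coordinate by the subgroup shape alone would give (32, 128).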
Great!!!
I left some minor comments. I'm not sure if the verification in the example is checking the right thing. Could you please double check that?
    // Check if output from CUTLASS kernel and reference kernel are equal or not
    bool passed = cutlass::reference::device::BlockCompareEqual(
      block_ref_D.get(), block_D.get(), block_D.size());
Is this checking the output with the epilogue or just the output of GEMM?
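To make the question concrete: for the comparison to cover the epilogue, the reference has to include the bias term and not only the GEMM result. Below is a minimal host-side sketch of such a reference, assuming the fused operation is D = alpha * (A * B) + beta * C + bias[m], with row-major matrices, a single batch, and `float` throughout. The function name and layouts are assumptions for illustration; this is not the reference kernel used by the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host reference for D = alpha * (A * B) + beta * C + bias[m].
// Assumed for this sketch: A is M x K, B is K x N, C and D are M x N,
// all row-major; bias holds one value per output row (M entries).
inline std::vector<float> reference_lincomb_per_row_bias(
    int M, int N, int K, float alpha, float beta,
    const std::vector<float>& A, const std::vector<float>& B,
    const std::vector<float>& C, const std::vector<float>& bias) {
  std::vector<float> D(static_cast<std::size_t>(M) * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      // The bias is indexed by the row only, so it is broadcast along N.
      D[m * N + n] = alpha * acc + beta * C[m * N + n] + bias[m];
    }
  }
  return D;
}
```

If the device output matches a reference of this form but not a plain GEMM reference (or the other way round), that indicates which of the two the verification is really exercising.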
test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp
  using ElementOutput = typename CollectiveEpilogue::ElementOutput;
  using ElementCompute = typename CollectiveEpilogue::ElementCompute;
  using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
  using ElementBias = typename CollectiveEpilogue::ThreadEpilogueOp::ElementBias; //TODO(joe) Is this right?
Joe TODO.
I'm not sure but I don't see why not.
Removed - cheers
    // cutlass::reference::device::TensorScaleBiasGemm<
    //   ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
    //   ElementCompute, typename B2bGemm::LayoutC
    // > (
    //   {M, N, K},
    //   ref_Z.at(i),
    //   ref_ref_D0.at(i),
    //   alpha0,
    //   ref_Scale0.at(i),
    //   ref_Bias0.at(i)
    // );
dead code
    syclcompat::wait();

    // TODO(joe): Add the right epilogue here (PerRowBias) for testing
joe TODO
    block_C.reset(M * N * L);
    block_D.reset(M * N * L);
    block_ref_D.reset(M * N * L);
    bias.reset(N * L); //TODO(joe) or N?
joe TODO
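On the sizing question in that TODO: if "per-row" means one bias value per output row m, as the name LinCombPerRowBias suggests, then the vector has M entries per batch, not N. A small sketch of that sizing and indexing follows. Both helpers are hypothetical, and the assumption that each batch has its own contiguous bias vector is mine; a single bias shared across batches would need only M entries.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helpers for illustration; not part of CUTLASS.

// Number of bias elements when every output row of every batch has its own
// bias value: M per batch, L batches.
inline std::size_t per_row_bias_size(int M, int L) {
  return static_cast<std::size_t>(M) * static_cast<std::size_t>(L);
}

// Linear index of the bias for row m of batch l, assuming the per-batch
// vectors are stored contiguously one after another.
inline std::size_t per_row_bias_index(int M, int m, int l) {
  return static_cast<std::size_t>(l) * static_cast<std::size_t>(M) +
         static_cast<std::size_t>(m);
}
```

Under these assumptions a problem with M = 4 and L = 3 needs 12 bias values, and the value for row 2 of batch 1 sits at index 6.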
    // cutlass::reference::device::TensorScaleBiasGemm<
    //   ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
    //   ElementCompute, typename B2bGemm::LayoutC
    // > (
    //   {M, N, K},
    //   ref_Z.at(i),
    //   ref_ref_D0.at(i),
    //   alpha0,
    //   ref_Scale0.at(i),
    //   ref_Bias0.at(i)
    // );
Commented code
    syclcompat::wait();

    // TODO(joe): Add the right epilogue here (PerRowBias) for testing
TODO
    block_C.reset(M * N * L);
    block_D.reset(M * N * L);
    block_ref_D.reset(M * N * L);
    bias.reset(N * L); //TODO(joe) or N?
TODO
@@ -341,7 +362,7 @@ class CollectiveEpilogue<
         auto acc_frag_mn = acc_frag(_, epi_m, epi_n);

         CUTLASS_PRAGMA_UNROLL
-        for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) {
+        for (int epi_v = 0; epi_v < size(trD_frag); ++epi_v) {
Can you explain the point of the changes you are making in this file? Are you changing it to apply the epilogue per subgroup instead of per workgroup?
    // Get the fusion callbacks
    constexpr bool RefSrc = true;
    auto residue_mn = make_coord(M, N);
    auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
      problem_shape_mnkl,
      TileShapeMNK{},
      tile_coord_mnkl,
      SubgroupTileShape{},
I don't think SubgroupTileShape is quite the same as tile_shape_mnk, so maybe we should have a comment there. I guess sg_coord is there to match that change too.
This is a very rough first attempt to get this working. It compiles and runs, but the validation fails.
Testing issues:
Impl issues:
EpilogueTile shape (currently arbitrarily set inside CollectiveEpilogue)
Alignment template arg which is more appropriate for XE