Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinCombPerRowBias & EVT Changes #152

Open
wants to merge 15 commits into
base: sycl-develop
Choose a base branch
from

Conversation

joeatodd
Copy link
Collaborator

@joeatodd joeatodd commented Nov 8, 2024

This is a very rough first attempt to get this working. It compiles and runs, but the validation fails.

Testing issues:

  • The validation in the example doesn't do a per-row-bias so of course it fails...
  • Need to investigate using existing unit testing infrastructure

Impl issues:

  • Need to think more about the EpilogueTile shape (currently arbitrarily set inside CollectiveEpilogue)
  • Consider passing through an Alignment template arg which is more appropriate for XE

@joeatodd joeatodd force-pushed the jtodd/per-row-bias branch 2 times, most recently from 4720c36 to bdcb4bb Compare November 19, 2024 12:09
@joeatodd joeatodd added the incremental Incremental changes label Nov 19, 2024
aacostadiaz pushed a commit that referenced this pull request Dec 4, 2024
As a precursor to properly testing #152, I have adapted the gemm_testbed_3x.hpp infrastructure slightly for Xe and added a single test case for LinCombEltAct, as that's something we know should work.

---------

Co-authored-by: Finlay <[email protected]>
@@ -262,7 +262,7 @@ class GemmUniversal<
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(
problem_shape_MNKL,
subgroup_shape,
subgroup_shape, // TODO(joe): Inconsistency here w/ blk_coord_mnkl
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aacostadiaz the subgroup shape is being passed in here instead of the WorkgroupTileShape. However, the coordinate blk_coord_mnkl is the work-group coordinate. This doesn't matter for this PR because I'm computing things inside xe_epilogue.hpp but it feels inconsistent. Potentially something to address as part of your ongoing work on the coordinates?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that makes sense. I can look into that as part of the fix for coordinates.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe change to TODO(codeplay): ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

@joeatodd joeatodd marked this pull request as ready for review December 5, 2024 11:49
Copy link
Collaborator

@aacostadiaz aacostadiaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!!!

I left some minor comments. I'm not sure if the verification in the example is checking the right thing. Could you please double check that?

examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
Comment on lines +201 to +203
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareEqual(
block_ref_D.get(), block_D.get(), block_D.size());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this checking the output with the epilogue or just the output of GEMM?

examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
@@ -262,7 +262,7 @@ class GemmUniversal<
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(
problem_shape_MNKL,
subgroup_shape,
subgroup_shape, // TODO(joe): Inconsistency here w/ blk_coord_mnkl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that makes sense. I can look into that as part of the fix for coordinates.

using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
using ElementBias = typename CollectiveEpilogue::ThreadEpilogueOp::ElementBias; //TODO(joe) Is this right?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Joe TODO.
I'm not sure but I don't see why not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed - cheers

Comment on lines +167 to +177
// cutlass::reference::device::TensorScaleBiasGemm<
// ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
// ElementCompute, typename B2bGemm::LayoutC
// > (
// {M, N, K},
// ref_Z.at(i),
// ref_ref_D0.at(i),
// alpha0,
// ref_Scale0.at(i),
// ref_Bias0.at(i)
// );
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dead code


syclcompat::wait();

// TODO(joe): Add the right epilogue here (PerRowBias) for testing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

joe TODO

block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
bias.reset(N * L); //TODO(joe) or N?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

joe TODO

examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
@@ -262,7 +262,7 @@ class GemmUniversal<
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(
problem_shape_MNKL,
subgroup_shape,
subgroup_shape, // TODO(joe): Inconsistency here w/ blk_coord_mnkl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe change to TODO(codeplay): ?

include/cutlass/epilogue/collective/xe_epilogue.hpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
Comment on lines +167 to +177
// cutlass::reference::device::TensorScaleBiasGemm<
// ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
// ElementCompute, typename B2bGemm::LayoutC
// > (
// {M, N, K},
// ref_Z.at(i),
// ref_ref_D0.at(i),
// alpha0,
// ref_Scale0.at(i),
// ref_Bias0.at(i)
// );
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented code


syclcompat::wait();

// TODO(joe): Add the right epilogue here (PerRowBias) for testing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO

block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
bias.reset(N * L); //TODO(joe) or N?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO

examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
examples/sycl/pvc/pvc_gemm_with_per_row_bias.cpp Outdated Show resolved Hide resolved
include/cutlass/epilogue/collective/xe_epilogue.hpp Outdated Show resolved Hide resolved
@@ -341,7 +362,7 @@ class CollectiveEpilogue<
auto acc_frag_mn = acc_frag(_, epi_m, epi_n);

CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) {
for (int epi_v = 0; epi_v < size(trD_frag); ++epi_v) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what is the point to the changes you are doing in this file? Are you changing to apply epilogue per subgroup instead of workgroup?

// Get the fusion callbacks
constexpr bool RefSrc = true;
auto residue_mn = make_coord(M, N);
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl,
TileShapeMNK{},
tile_coord_mnkl,
SubgroupTileShape{},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think SubgroupTileShape is quite the same as tile_shape_mnk, so maybe we should have a comment there. I guess sg_coord is there to match that change too.

@joeatodd joeatodd removed the incremental Incremental changes label Dec 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants