Introduce an interface for SMEM encoding #5764

Closed
wants to merge 15 commits into from

Conversation

masahi
Collaborator

@masahi masahi commented Jan 30, 2025

This PR prepares for a follow-up PR that will introduce a new SMEM encoding representing the FP4 padded layout for Blackwell mixed-precision scaled dot. The layout is described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory

We agreed to implement the new encoding as a separate class from the existing SharedEncodingAttr, rather than adding a new field or subclassing it. But we want the rest of the code to work generically with both SMEM encoding types, so I'm introducing an SMEM encoding interface which is inherited by both encodings.

The PR is mostly mechanical changes:

  • Follows the design of DistributedEncoding (use of Trait etc)
  • The interface consists of methods that used to be accessors on SharedEncodingAttr (getMaxPhase etc). Whether hasLeadingOffset should be part of the interface is debatable, since it is NV-specific, but for now it is.
  • Change isa<SharedEncodingAttr>, cast<SharedEncodingAttr> etc, and SharedEncodingAttr in function arguments, to work with SharedEncodingTrait instead.
  • AMD code is untouched. But given the discussion in [WIP] Support shared encoding defined with linear layout #5720, I'm curious if this could be useful there too. cc @binarman
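
As a rough illustration of the bullets above (not code from this PR), here is a plain C++ model of two unrelated encoding classes sharing one interface. The real interface is an MLIR attribute interface declared in TableGen. The class `Fp4PaddedEncoding`, its constant values, and the helper `swizzlePeriodInRows` are hypothetical names invented for this sketch.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Plain C++ stand-in for the SMEM encoding interface. The real interface is
// an MLIR attribute interface declared in TableGen, not an abstract class.
struct SharedEncodingTrait {
  virtual ~SharedEncodingTrait() = default;
  virtual unsigned getVec() const = 0;
  virtual unsigned getPerPhase() const = 0;
  virtual unsigned getMaxPhase() const = 0;
  virtual std::vector<unsigned> getOrder() const = 0;
};

// Stand-in for the existing SharedEncodingAttr: it stores the parameters.
class SharedEncoding final : public SharedEncodingTrait {
public:
  SharedEncoding(unsigned vec, unsigned perPhase, unsigned maxPhase,
                 std::vector<unsigned> order)
      : vec(vec), perPhase(perPhase), maxPhase(maxPhase),
        order(std::move(order)) {}
  unsigned getVec() const override { return vec; }
  unsigned getPerPhase() const override { return perPhase; }
  unsigned getMaxPhase() const override { return maxPhase; }
  std::vector<unsigned> getOrder() const override { return order; }

private:
  unsigned vec, perPhase, maxPhase;
  std::vector<unsigned> order;
};

// Hypothetical stand-in for the planned FP4 padded encoding. It is a separate
// class, not a subclass of SharedEncoding. The constants are placeholders,
// not the real layout parameters.
class Fp4PaddedEncoding final : public SharedEncodingTrait {
public:
  unsigned getVec() const override { return 16; }
  unsigned getPerPhase() const override { return 1; }
  unsigned getMaxPhase() const override { return 8; }
  std::vector<unsigned> getOrder() const override { return {1, 0}; }
};

// Hypothetical helper. Before the change it would have taken the concrete
// encoding; now it accepts anything that implements the interface.
// The phase of a row is (row / perPhase) % maxPhase, so the swizzle pattern
// repeats every perPhase * maxPhase rows.
unsigned swizzlePeriodInRows(const SharedEncodingTrait &enc) {
  return enc.getPerPhase() * enc.getMaxPhase();
}
```

Callers such as `swizzlePeriodInRows` never name a concrete encoding, which is the property the mechanical `isa`/`cast` changes are meant to establish.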

cc @ThomasRaoux @lezcano I hope I'm on the right track with this one

Collaborator

@ThomasRaoux ThomasRaoux left a comment

I think there is a mismatch with what we discussed: the idea was to have a new trait/parent class representing any generic shared layout, which the existing layout could inherit from.
The parent layout should not contain anything that doesn't generalize to any shared memory layout.

Comment on lines +237 to +265
let methods = [
  InterfaceMethod<"Get 'vec', one of properties of swizzling.",
                  "unsigned",
                  "getVec">,
  InterfaceMethod<"Get 'per phase', one of properties of swizzling.",
                  "unsigned",
                  "getPerPhase">,
  InterfaceMethod<"Get 'max phase', one of properties of swizzling.",
                  "unsigned",
                  "getMaxPhase">,
  InterfaceMethod<"Get the order of this SMEM encoding.",
                  "SmallVector<unsigned>",
                  "getOrder">,
  InterfaceMethod<"Get the shape of the CTAs per CGA.",
                  "SmallVector<unsigned>",
                  "getCTAsPerCGA">,
  InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
                  "SmallVector<unsigned>",
                  "getCTASplitNum">,
  InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
                  "SmallVector<unsigned>",
                  "getCTAOrder">,
  InterfaceMethod<"Get the size of the address alignment in bytes.",
                  "int32_t",
                  "getAlignment">,
  InterfaceMethod<"True if the SMEM layout is in the core-matrices format.",
                  "bool",
                  "hasLeadingOffset">
];
Collaborator

Those methods don't belong in the parent class; they should only exist in the SharedEncoding class below, since they don't generalize to any shared memory layout.

Collaborator Author

@masahi masahi Jan 31, 2025

But in practice, SharedEncoding is used by multiple vendors. So I was assuming that these methods generalize to other shared layouts as well. In particular, they do apply to the new BW FP4 layout. Having them in the interface lets the rest of the code work generically with both SharedEncodingAttr and the new encoding.

I thought about making the base interface less fat and adding an intermediate SMEM interface for NV-specific layout, which is subclassed by SharedEncodingAttr and the new fp4 layout. But AMD also uses the same notion of swizzling, so I was not sure if that's a good idea.

Are you suggesting that I duplicate these methods in SharedEncodingAttr and the new fp4 layout, and do an explicit cast to them whenever I need to call getVec(), getHasLeadingOffset() etc? I'd rather have them in the common interface shared by the two NV encodings.

Collaborator Author

@masahi masahi Jan 31, 2025

To elaborate, here is an ideal interface / class hierarchy from NV's perspective:

           SharedEncodingTrait (interface)
                * getOrder
                * getAlignment
                     |
           SharedEncodingNvidia (interface)
                * getVec, getMaxPhase etc
                * hasLeadingOffset
                * getCTALayout
              /               \
SharedEncodingAttr         SharedEncodingMMAv5Fp4PaddedAttr

But I'm now not sure how it would work with AMD if we don't do anything in this PR. Most likely, they would continue using SharedEncodingAttr. Nothing changes functionality-wise, but it is odd in terms of the architecture.

Collaborator

@ThomasRaoux ThomasRaoux Jan 31, 2025

What I have in mind is something like this:

           SharedEncodingTrait (interface)
                * getAlignment // even that probably should go away
                * getCTALayout
              /                \
SwizzleSharedEncodingAttr         NvidiaMMASharedEncoding
 * getVec, getMaxPhase etc.        * swizzlingByteSize
 * order                           * transpose
                                   * fp4 padding

The vec/phase and so on have no meaning, and hasLeadingOffset is a hack that gives a totally different meaning to the rest of the layout, so we should move it out and keep only the original fields in the other layout, as those describe swizzling.

If I have time I can try to send a PR tomorrow to illustrate that more clearly.
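
Until such a PR exists, the shape of that hierarchy can be approximated in plain C++. This is only a sketch under stated assumptions: the real classes would be MLIR attributes, `dynamic_cast` stands in for `dyn_cast`, the `getCTALayout` accessor is omitted for brevity, and the alignment defaults, `LayoutKind`, and `classify` are placeholders invented here.

```cpp
#include <cassert>
#include <cstdint>

// Common parent: only what every shared memory layout can answer.
// (The CTA layout accessor from the diagram is omitted to keep this short.)
struct SharedEncodingTrait {
  virtual ~SharedEncodingTrait() = default;
  virtual int32_t getAlignment() const = 0;
};

// Swizzled layout: keeps the original vec / perPhase / maxPhase fields.
struct SwizzleSharedEncoding final : SharedEncodingTrait {
  unsigned vec = 1, perPhase = 1, maxPhase = 1;
  int32_t alignment = 8; // placeholder value, an assumption of this sketch
  int32_t getAlignment() const override { return alignment; }
};

// MMA layout: described by its own fields instead of a hasLeadingOffset flag.
struct NvidiaMMASharedEncoding final : SharedEncodingTrait {
  unsigned swizzlingByteSize = 0;
  bool transposed = false;
  bool fp4Padded = false;
  int32_t alignment = 128; // placeholder value, an assumption of this sketch
  int32_t getAlignment() const override { return alignment; }
};

enum class LayoutKind { Swizzle, NvidiaMMA, Unknown };

// Code that needs layout-specific fields first checks the concrete type.
LayoutKind classify(const SharedEncodingTrait &enc) {
  if (dynamic_cast<const SwizzleSharedEncoding *>(&enc))
    return LayoutKind::Swizzle;
  if (dynamic_cast<const NvidiaMMASharedEncoding *>(&enc))
    return LayoutKind::NvidiaMMA;
  return LayoutKind::Unknown;
}
```

Generic code can still call `getAlignment()` through the parent, while anything that reads `maxPhase` or `swizzlingByteSize` has to commit to one child first.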

Collaborator Author

@masahi masahi Jan 31, 2025

I see, I was missing the fact that vec / phase things are interchangeable with swizzlingByteSize. So having both getVec etc and hasLeadingOffset in the interface is redundant.
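
To make the redundancy concrete, here is one plausible way the swizzle parameters could be derived from a swizzling byte size. The mapping is an assumption of this sketch and is not stated in this thread: it supposes 128-bit vectors and 128/64/32-byte swizzle patterns spanning 8/4/2 phases, with 0 meaning no swizzling. `SwizzleParams` and `fromSwizzlingByteSize` are hypothetical names.

```cpp
#include <cassert>
#include <stdexcept>

struct SwizzleParams {
  unsigned vec, perPhase, maxPhase;
};

// Assumed mapping (not taken from this thread): a vector is 128 bits wide,
// and a 128/64/32-byte swizzle pattern has 8/4/2 phases respectively.
// swizzlingByteSize == 0 is treated as "no swizzling".
SwizzleParams fromSwizzlingByteSize(unsigned swizzlingByteSize,
                                    unsigned elemBitWidth) {
  if (elemBitWidth == 0 || 128 % elemBitWidth != 0)
    throw std::invalid_argument("element width must divide 128 bits");
  if (swizzlingByteSize == 0)
    return {1, 1, 1};
  if (swizzlingByteSize != 32 && swizzlingByteSize != 64 &&
      swizzlingByteSize != 128)
    throw std::invalid_argument("unsupported swizzling byte size");
  return {128 / elemBitWidth,       // elements per 128-bit vector
          128 / swizzlingByteSize,  // rows sharing one phase
          swizzlingByteSize / 16};  // number of distinct phases
}
```

Under these assumptions, a 128-byte swizzle with 16-bit elements gives vec = 8, perPhase = 1, maxPhase = 8, so storing the byte size alone is enough to recover all three values.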

In this PR, I didn't intend to make such big changes to the SMEM representation. Rather, I tried to make minimal, mostly mechanical changes to avoid breaking existing code while allowing the introduction of the new fp4 encoding. If we instead want to clean up the existing definitions and introduce such an ideal structure now, it would take more time, but I can attempt that. It still leaves the question of what to do with AMD (maybe Intel as well?), though (ok, they can use SwizzleSharedEncodingAttr).

Collaborator

> In this PR, I didn't intend to make such big changes to the SMEM representation. Rather, I tried to make minimal, mostly mechanical changes to avoid breaking existing code while allowing the introduction of the new fp4 encoding.

First step could be to do that:

           SharedEncodingTrait (interface)
                * getAlignment // even that probably should go away
                * getCTALayout
              /             
SwizzleSharedEncodingAttr 
 * getVec, getMaxPhase etc. 
 * order

Then the second step will be to clean up hasLeadingOffset, which is really the harder part, but I think it will help a lot as it has many things in common with the new layout you need. If I can free up some time I may be able to help with that.

> It still leaves the question of what to do with AMD (maybe Intel as well?), though (ok, they can use SwizzleSharedEncodingAttr).

Yeah, and for the new layout they need, they can have a separate child in parallel.

Comment on lines +272 to +274
let description = [{
  The interface which can be inherited by various representations of an SMEM encoding.
}];
Collaborator

I'm confused: how is that the right description? Isn't it the other way around, where the parent trait is the common one and this one would keep the old definition?

Collaborator Author

Given the usage of SharedEncodingAttr by NV, AMD and Intel backends, I was assuming that the current definition of SharedEncodingAttr is already general. So I simply moved its content into the parent trait.

@masahi
Collaborator Author

masahi commented Feb 1, 2025

#5786

@masahi masahi closed this Feb 1, 2025