Introduce an interface for SMEM encoding #5764
I think there is a mismatch with what we discussed: the idea was to have a new trait/parent class representing any generic shared layout, and the existing layout could inherit from it.
The parent layout should not contain anything that doesn't generalize to any shared memory layout.
let methods = [
  InterfaceMethod<"Get 'vec', one of properties of swizzling.",
                  "unsigned",
                  "getVec">,
  InterfaceMethod<"Get 'per phase', one of properties of swizzling.",
                  "unsigned",
                  "getPerPhase">,
  InterfaceMethod<"Get 'max phase', one of properties of swizzling.",
                  "unsigned",
                  "getMaxPhase">,
  InterfaceMethod<"Get the order of this SMEM encoding.",
                  "SmallVector<unsigned>",
                  "getOrder">,
  InterfaceMethod<"Get the shape of the CTAs per CGA.",
                  "SmallVector<unsigned>",
                  "getCTAsPerCGA">,
  InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
                  "SmallVector<unsigned>",
                  "getCTASplitNum">,
  InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
                  "SmallVector<unsigned>",
                  "getCTAOrder">,
  InterfaceMethod<"Get the size of the address alignment in bytes.",
                  "int32_t",
                  "getAlignment">,
  InterfaceMethod<"True if the SMEM layout is in the core-matrices format.",
                  "bool",
                  "hasLeadingOffset">
];
Those methods don't belong in the parent class; they should only exist in the SharedEncoding class below, since they don't generalize to any shared memory layout.
But in practice, SharedEncoding is used by multiple vendors. So I was assuming that these methods generalize to other shared layouts as well. In particular, they do apply to the new BW FP4 layout. Having them in the interface lets the rest of the code work generically with both SharedEncodingAttr and the new encoding.
I thought about making the base interface less fat and adding an intermediate SMEM interface for NV-specific layouts, which is subclassed by SharedEncodingAttr and the new fp4 layout. But AMD also uses the same notion of swizzling, so I was not sure if that's a good idea.
Are you suggesting duplicating these methods in SharedEncodingAttr and the new fp4 layout, and doing an explicit cast to them whenever I need getVec(), getHasLeadingOffset() etc.? I'd rather have them in the common interface shared by the two NV encodings.
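The trade-off raised here can be sketched in plain C++. This is a simplified stand-in and not the MLIR attribute/interface machinery: `dynamic_cast` plays the role of MLIR's `dyn_cast`, and every name below (`ThinTrait`, `FatTrait`, `SwizzledEncoding`, `Fp4PaddedEncoding`, `FatSwizzledEncoding`, and both helper functions) is hypothetical.

```cpp
#include <optional>

// Hypothetical, simplified stand-ins; not Triton's actual MLIR classes.
struct ThinTrait {
  virtual ~ThinTrait() = default;
};

struct SwizzledEncoding final : ThinTrait {
  unsigned vec;
  explicit SwizzledEncoding(unsigned vec) : vec(vec) {}
  unsigned getVec() const { return vec; }
};

struct Fp4PaddedEncoding final : ThinTrait {
  unsigned vec;
  explicit Fp4PaddedEncoding(unsigned vec) : vec(vec) {}
  unsigned getVec() const { return vec; }
};

// Thin interface: generic code has to try each concrete encoding in turn.
std::optional<unsigned> getVecViaCasts(const ThinTrait &enc) {
  if (auto *s = dynamic_cast<const SwizzledEncoding *>(&enc))
    return s->getVec();
  if (auto *f = dynamic_cast<const Fp4PaddedEncoding *>(&enc))
    return f->getVec();
  return std::nullopt; // an encoding with no notion of 'vec'
}

// Fat interface: the accessor is part of the contract, so no casts are needed.
struct FatTrait {
  virtual ~FatTrait() = default;
  virtual unsigned getVec() const = 0;
};

struct FatSwizzledEncoding final : FatTrait {
  unsigned vec;
  explicit FatSwizzledEncoding(unsigned vec) : vec(vec) {}
  unsigned getVec() const override { return vec; }
};

unsigned getVecViaInterface(const FatTrait &enc) { return enc.getVec(); }
```

The cast-based helper grows by one branch per encoding that has a `vec`, while the fat interface forces every future encoding to define `getVec()` whether or not it is meaningful for that layout.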
To elaborate, here is an ideal interface / class hierarchy from NV's perspective:
        SharedEncodingTrait (interface)
          * getOrder
          * getAlignment
                  |
        SharedEncodingNvidia (interface)
          * getVec, getMaxPhase etc
          * hasLeadingOffset
          * getCTALayout
             /              \
SharedEncodingAttr    SharedEncodingMMAv5Fp4PaddedAttr
But I'm now not sure how it would work with AMD if we don't do anything in this PR. Most likely, they would continue using SharedEncodingAttr. Nothing changes functionality-wise, but it is odd in terms of the architecture.
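As a rough illustration of the three-level hierarchy proposed in this comment, here is a simplified C++ model. It is only a sketch: in Triton these would be MLIR attribute interfaces defined in TableGen, `getCTALayout` is omitted for brevity, all returned values are placeholders, and the helpers `orderRank` and `usesCoreMatrices` are hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Vendor-neutral level: only what every shared layout is assumed to provide.
struct SharedEncodingTrait {
  virtual ~SharedEncodingTrait() = default;
  virtual std::vector<unsigned> getOrder() const = 0;
  virtual int getAlignment() const = 0;
};

// Intermediate level: accessors shared by the two NV encodings.
struct SharedEncodingNvidia : SharedEncodingTrait {
  virtual unsigned getVec() const = 0;
  virtual unsigned getMaxPhase() const = 0;
  virtual bool hasLeadingOffset() const = 0;
};

// Leaf classes; every value returned below is a placeholder.
struct SharedEncodingAttr final : SharedEncodingNvidia {
  std::vector<unsigned> getOrder() const override { return {1u, 0u}; }
  int getAlignment() const override { return 16; }
  unsigned getVec() const override { return 8; }
  unsigned getMaxPhase() const override { return 8; }
  bool hasLeadingOffset() const override { return false; }
};

struct SharedEncodingMMAv5Fp4PaddedAttr final : SharedEncodingNvidia {
  std::vector<unsigned> getOrder() const override { return {1u, 0u}; }
  int getAlignment() const override { return 16; }
  unsigned getVec() const override { return 16; }
  unsigned getMaxPhase() const override { return 8; }
  bool hasLeadingOffset() const override { return true; }
};

// Hypothetical vendor-neutral helper: needs only the base trait.
std::size_t orderRank(const SharedEncodingTrait &enc) {
  return enc.getOrder().size();
}

// Hypothetical NV-specific helper: accepts either leaf without a cast.
bool usesCoreMatrices(const SharedEncodingNvidia &enc) {
  return enc.hasLeadingOffset();
}
```

Under this shape, code that is truly generic takes `SharedEncodingTrait`, and NV-specific code takes `SharedEncodingNvidia`; where AMD's use of swizzling would sit is exactly the open question raised above.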
What I have in mind is something like this:
              SharedEncodingTrait (interface)
                * getAlignment // even that probably should go away
                * getCTALayout
               /                        \
SwizzleSharedEncodingAttr          NvidiaMMASharedEncoding
  * getVec, getMaxPhase etc.         * swizzlingByteSize
  * order                            * transpose
                                     * fp4 padding
The vec/phase and all have no meaning, and hasLeadingOffset is a hack which gives a totally different meaning to the rest of the layout, so we should move it out and keep only the original fields in the other layout, as those describe swizzling.
If I have time I can try to send a PR tomorrow to illustrate that more clearly.
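One way to read the proposed structure, sketched in plain C++ rather than MLIR TableGen: a thin trait with two sibling encodings that carry different fields. The `CTALayout` struct, the field types, the `fp4Padded` flag name, and the helpers `numCTADims` and `isFp4Padded` are assumptions made for illustration only.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical placeholder for the CTA layout description.
struct CTALayout {
  std::vector<unsigned> ctasPerCGA;
};

// Thin parent: only what is assumed to generalize to any shared layout.
struct SharedEncodingTrait {
  virtual ~SharedEncodingTrait() = default;
  virtual CTALayout getCTALayout() const = 0;
};

// Classic swizzled layout: keeps vec / perPhase / maxPhase and order.
struct SwizzleSharedEncodingAttr final : SharedEncodingTrait {
  unsigned vec, perPhase, maxPhase;
  std::vector<unsigned> order;
  CTALayout cta;
  SwizzleSharedEncodingAttr(unsigned vec, unsigned perPhase, unsigned maxPhase,
                            std::vector<unsigned> order, CTALayout cta)
      : vec(vec), perPhase(perPhase), maxPhase(maxPhase),
        order(std::move(order)), cta(std::move(cta)) {}
  CTALayout getCTALayout() const override { return cta; }
};

// MMA-oriented layout: described by byte-level swizzling, not vec/phase.
struct NvidiaMMASharedEncoding final : SharedEncodingTrait {
  unsigned swizzlingByteSize;
  bool transpose;
  bool fp4Padded;
  CTALayout cta;
  NvidiaMMASharedEncoding(unsigned swizzlingByteSize, bool transpose,
                          bool fp4Padded, CTALayout cta)
      : swizzlingByteSize(swizzlingByteSize), transpose(transpose),
        fp4Padded(fp4Padded), cta(std::move(cta)) {}
  CTALayout getCTALayout() const override { return cta; }
};

// Generic code touches only the trait.
std::size_t numCTADims(const SharedEncodingTrait &enc) {
  return enc.getCTALayout().ctasPerCGA.size();
}

// Layout-specific code casts to the one child it understands.
bool isFp4Padded(const SharedEncodingTrait &enc) {
  auto *mma = dynamic_cast<const NvidiaMMASharedEncoding *>(&enc);
  return mma != nullptr && mma->fp4Padded;
}
```

In this sketch neither child exposes the other's fields, so `hasLeadingOffset` has no place in the parent and swizzle-specific code has to name the layout it handles.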
I see, I was missing the fact that the vec / phase things are interchangeable with swizzlingByteSize. So having both getVec etc. and hasLeadingOffset in the interface is redundant.
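To make the claimed interchangeability concrete, here is a minimal sketch of one possible mapping from a swizzling byte size to vec / perPhase / maxPhase. The formulas are an assumption, not something stated in this thread: they assume swizzling operates on 16-byte chunks within 128-byte rows and that only 32, 64 and 128 are valid byte sizes. `SwizzleParams` and `fromSwizzlingByteSize` are hypothetical names.

```cpp
#include <stdexcept>

struct SwizzleParams {
  unsigned vec;      // elements per 16-byte chunk
  unsigned perPhase; // rows sharing one phase
  unsigned maxPhase; // number of distinct phases
};

// Assumed mapping (not confirmed by this thread): 16-byte swizzle granule,
// 128-byte reference row, swizzlingByteSize restricted to {32, 64, 128}.
SwizzleParams fromSwizzlingByteSize(unsigned swizzlingByteSize,
                                    unsigned elementBitWidth) {
  if (swizzlingByteSize != 32 && swizzlingByteSize != 64 &&
      swizzlingByteSize != 128)
    throw std::invalid_argument("unsupported swizzling byte size");
  if (elementBitWidth == 0 || 128 % elementBitWidth != 0)
    throw std::invalid_argument("unsupported element bit width");
  SwizzleParams p;
  p.vec = 128 / elementBitWidth;        // 16 bytes = 128 bits per chunk
  p.perPhase = 128 / swizzlingByteSize; // e.g. 128B swizzle -> 1
  p.maxPhase = swizzlingByteSize / 16;  // e.g. 128B swizzle -> 8
  return p;
}
```

If a mapping of this kind holds, the three swizzle fields carry no information beyond the byte size and element width for such layouts, which is the redundancy noted above.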
In this PR, I didn't intend to make such big changes to the SMEM representation. Rather, I tried to make minimal and mostly mechanical changes, to avoid breaking existing code while allowing the introduction of the new fp4 encoding. If we instead want to clean up the existing definitions and introduce such an ideal structure now, it would take more time, but I can attempt that. It still leaves the question of what to do with AMD (maybe Intel as well?), though (ok, they can use SwizzleSharedEncodingAttr).
> In this PR, I didn't intend to make such big changes to the SMEM representation. Rather, I tried to make minimal and mostly mechanical changes, to avoid breaking existing code while allowing the introduction of the new fp4 encoding.
First step could be to do that:
    SharedEncodingTrait (interface)
      * getAlignment // even that probably should go away
      * getCTALayout
           /
SwizzleSharedEncodingAttr
  * getVec, getMaxPhase etc.
  * order
Then the second step will be to clean up hasLeadingOffset, which is really the harder part, but I think it will help a lot as it has many things in common with the new layout you need. If I can free up some time I may be able to help with that.
> It still leaves the question of what to do with AMD (maybe Intel as well?), though (ok, they can use SwizzleSharedEncodingAttr).
Yeah and for the new layout they need they can have a separate child in parallel.
let description = [{
  The interface which can be inherited by various representations of an SMEM encoding.
}];
I'm confused, how is that the right description? That's the other way around, the parent trait is the common one and this one would keep the old definition?
Given the usage of SharedEncodingAttr by NV, AMD and Intel backends, I was assuming that the current definition of SharedEncodingAttr is already general. So I simply moved its content into the parent trait.
This PR prepares for a follow-up PR which will introduce a new SMEM encoding representing the FP4 padded layout for Blackwell mixed-precision scaled dot. The layout is described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
We agreed to implement the new encoding as a separate class from the existing SharedEncodingAttr, rather than adding a new field or subclassing it. But we want the rest of the code to work generically with both SMEM encoding types, so I'm introducing an SMEM encoding interface which is inherited by both encodings.
The PR is mostly mechanical changes:
* DistributedEncoding (use of Trait etc)
* SharedEncodingAttr (getMaxPhase etc). Whether or not hasLeadingOffset should be part of the interface is debatable, since this is NV-specific, but for now it is.
* isa<SharedEncodingAttr>, cast<SharedEncodingAttr> etc, and SharedEncodingAttr in function arguments to work with SharedEncodingTrait instead.
cc @ThomasRaoux @lezcano I hope I'm on the right track with this one