
[QST] Weight Format & GEMM #18

Open
jeromeku opened this issue Apr 1, 2024 · 2 comments

jeromeku commented Apr 1, 2024

@efrantar

Awesome work -- always enjoy your research on and implementation of efficient model inference.

I was hoping you could shed some light on the logic of the packing step:

  • My understanding is that the individual int4 values need to be rearranged in order to use the fast unpack / convert functions from FasterTransformer (a minimal sketch of the plain nibble packing I have in mind follows after this list).

  • Is the subsequent interleaving done so that ldmatrix can be used on these packed values, with each thread ending up holding the values it needs for mma.sync? Typically ldmatrix is used on fp16 / bf16 types, but here the weights are sub-byte types, hence the additional preprocessing required for an efficient shared -> register copy. I know FasterTransformer has its own formatting logic as a workaround for this issue; I have yet to find a general solution for efficiently leveraging tensor-core primitives on sub-byte types without preprocessing the weights into a custom format.

  • Theoretically, if I were to preprocess the weights of a non-GPTQ int4 model using the packing function -- i.e., any groupwise quantization method that yields 4-bit weights along with group scales and zeros -- would I be able to use the Marlin kernel on such a model? If not, what changes would need to be made?
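For concreteness, here is a minimal PyTorch sketch of the plain column-wise nibble packing I have in mind as the starting point -- eight 4-bit codes per int32 word -- before any Marlin-specific interleaving. The function name and layout are mine for illustration only and are not taken from the repo:

```python
import torch

def pack_int4_plain(q: torch.Tensor) -> torch.Tensor:
    """q: (k, n) integer codes in [0, 15]; returns (k // 8, n) int32 words."""
    assert q.shape[0] % 8 == 0
    q = q.to(torch.int32).reshape(q.shape[0] // 8, 8, q.shape[1])
    packed = torch.zeros(q.shape[0], q.shape[2], dtype=torch.int32)
    for i in range(8):
        # Nibble i lands in bits [4*i, 4*i + 4); wrap into the sign bit is
        # intentional, only the bit pattern matters for packing.
        packed |= q[:, i, :] << (4 * i)
    return packed

# 16x4 matrix of 4-bit codes -> 2x4 packed int32 words.
q = torch.randint(0, 16, (16, 4))
print(pack_int4_plain(q).shape)  # torch.Size([2, 4])
```

If I understand correctly, the extra permutation applied on top of something like this is what makes the fast int4 -> fp16 conversion and the mma fragment ownership line up, and that is the part I am trying to understand.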

Many thanks!

efrantar (Member) commented Apr 2, 2024

Hi, Marlin only uses ldmatrix for the activations, since the weights are already preshuffled into a layout that is optimal for both dequantization and the tensor core fragment layouts. You can find a more detailed description of how this format works in #12.

Marlin is completely independent of GPTQ: the model just needs to be quantized symmetrically, either with group size 128 or row-wise (how you produced the model doesn't matter to Marlin); then you can preprocess the weights and use the Marlin kernels. Zero-points are currently not supported; the reasons for this are discussed in #5 (comment).
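For illustration, a rough sketch of what "symmetric, group size 128 (or row-wise), no zero-points" means in practice; the function name and the exact signed int4 code convention below are illustrative, not part of the Marlin API, and the repo's own packing/preprocessing step would still be applied to the resulting (codes, scales) pair:

```python
import torch

def quantize_sym_int4(w: torch.Tensor, groupsize: int = 128):
    """w: (k, n) weights. Returns int4-range codes in [-8, 7] and per-group scales."""
    k, n = w.shape
    # groupsize == -1 -> one group spanning the whole input dimension (row-wise scale).
    g = k if groupsize == -1 else groupsize
    assert k % g == 0
    w_g = w.float().reshape(k // g, g, n)
    # Symmetric: scale from the per-group max magnitude, zero-point fixed at 0.
    scales = (w_g.abs().amax(dim=1, keepdim=True) / 7).clamp(min=1e-8)
    q = torch.clamp(torch.round(w_g / scales), -8, 7).to(torch.int32)
    return q.reshape(k, n), scales.reshape(k // g, n).half()

w = torch.randn(1024, 256, dtype=torch.float16)
q, s = quantize_sym_int4(w, groupsize=128)
print(q.shape, s.shape)  # torch.Size([1024, 256]) torch.Size([8, 256])
```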

jeromeku (Author) commented Apr 3, 2024

@efrantar

Thank you for taking the time to explain.

Have you looked into CUTLASS, specifically the 3.x API that introduced the CuTe abstractions for tensor thread-value manipulation / mapping? I'm wondering whether it could help generalize or extend the handcrafted code in Marlin without sacrificing performance.
