Commit
Imported from GitHub PR tensorflow/tensorflow#58720

Enables scaled GEMMs based on the `F8E4M3FN` and `F8E5M2` [FP8 data types](https://arxiv.org/abs/2209.05433). The pattern described by steps 1 through 6 in [RFC #22](openxla/xla#22) is rewritten into a Custom Call of the form

    (A, B, a_scale, b_scale, d_scale) -> (D, d_amax),

where A, B and D are FP8 matrices and a_scale, b_scale and d_scale are their respective scaling factors. The scalar d_amax gives the maximum of the absolute values in D before rescaling and casting to FP8, and can be used to compute new scaling factors.

Copybara import of the project:

-- f2eb35a9efcaaffdbb7314f99521357840bd49d8 by Philipp Hack <[email protected]>: Support for FP8 GEMMs in XLA.

-- 0afd695b3840417fdb1c00987c8c5e980be0de33 by Philipp Hack <[email protected]>: Support for FP8 GEMMs in XLA.

-- 5aba0882bc624215613c77d73dd23ec3b1d8b0d9 by Philipp Hack <[email protected]>: Support for FP8 GEMMs in XLA.

-- 8d18d22d61b1b440421fc3dd402acdaaf27519b3 by Philipp Hack <[email protected]>: Support for FP8 GEMMs in XLA.

-- 7759e0a5d041c26c632d4e433d5f544e0194ea40 by Philipp Hack <[email protected]>: Support for FP8 GEMMs in XLA.

Merging this change closes #58720

PiperOrigin-RevId: 495806551
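The semantics of the custom call can be sketched in plain NumPy. This is a minimal emulation, not the XLA implementation: the function name is hypothetical, the FP8 operands are stood in by float32 arrays, and the final cast to FP8 is approximated by saturating to the `F8E4M3FN` finite range (±448).

```python
import numpy as np

E4M3FN_MAX = 448.0  # largest finite value representable in F8E4M3FN

def scaled_fp8_gemm(a_fp8, b_fp8, a_scale, b_scale, d_scale):
    """Emulates (A, B, a_scale, b_scale, d_scale) -> (D, d_amax).

    a_fp8 and b_fp8 stand in for FP8 matrices (held here as float32);
    the helper name is illustrative, not an XLA API.
    """
    # Dequantize the inputs and run the matmul in higher precision.
    d_hi = (a_fp8 * a_scale) @ (b_fp8 * b_scale)
    # d_amax: maximum absolute value of D before rescaling and casting,
    # available to the caller for computing new scaling factors.
    d_amax = np.max(np.abs(d_hi))
    # Rescale and saturate to the FP8-representable range
    # (a real implementation would also round to FP8 precision).
    d_fp8 = np.clip(d_hi / d_scale, -E4M3FN_MAX, E4M3FN_MAX)
    return d_fp8, d_amax
```

With unit scales, `d_amax` is simply the largest magnitude in the full-precision product, which the training loop can feed back into its scaling-factor update.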