PR #58720: FP8 GEMMs in XLA
Imported from GitHub PR tensorflow/tensorflow#58720

Enables scaled GEMMs based on `F8E4M3FN` and `F8E5M2` [FP8 data types](https://arxiv.org/abs/2209.05433). The pattern described by steps 1 through 6 in [RFC #22](openxla/xla#22) is rewritten into a Custom Call of the form

`(A, B, a_scale, b_scale, d_scale) -> (D, d_amax)`,

where `A`, `B`, and `D` are FP8 matrices and `a_scale`, `b_scale`, and `d_scale` are their respective scaling factors. The scalar `d_amax` is the maximum of the absolute values in `D` before rescaling and casting to FP8; it can be used to compute updated scaling factors.
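
Numerically, the custom call amounts to a matmul on dequantized operands followed by rescaling and saturation back into FP8. Below is a minimal NumPy sketch of that contract; it is illustrative only (float32 stands in for the FP8 storage types, and the convention that `a_scale`/`b_scale` dequantize while `d_scale` quantizes is an assumption modeled on cuBLASLt, not taken from this PR):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude of F8E4M3FN

def fp8_scaled_gemm(a_f8, b_f8, a_scale, b_scale, d_scale):
    # Dequantize the FP8 operands and accumulate the matmul in float32.
    d_wide = (a_f8 * a_scale) @ (b_f8 * b_scale)
    # amax of D before rescaling/casting; reported so the framework can
    # derive updated scaling factors for later steps.
    d_amax = np.max(np.abs(d_wide))
    # Rescale and saturate into the representable FP8 range.
    d_f8 = np.clip(d_wide * d_scale, -E4M3_MAX, E4M3_MAX)
    return d_f8, d_amax
```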
Copybara import of the project:

--
f2eb35a9efcaaffdbb7314f99521357840bd49d8 by Philipp Hack <[email protected]>:

Support for FP8 GEMMs in XLA.

--
0afd695b3840417fdb1c00987c8c5e980be0de33 by Philipp Hack <[email protected]>:

Support for FP8 GEMMs in XLA.

--
5aba0882bc624215613c77d73dd23ec3b1d8b0d9 by Philipp Hack <[email protected]>:

Support for FP8 GEMMs in XLA.

--
8d18d22d61b1b440421fc3dd402acdaaf27519b3 by Philipp Hack <[email protected]>:

Support for FP8 GEMMs in XLA.

--
7759e0a5d041c26c632d4e433d5f544e0194ea40 by Philipp Hack <[email protected]>:

Support for FP8 GEMMs in XLA.

Merging this change closes #58720

PiperOrigin-RevId: 495806551
philipphack authored and TensorFlow MLIR Team committed Dec 16, 2022
1 parent eea9ce1 commit 21ad757
Showing 1 changed file with 20 additions and 0 deletions.
lhlo_gpu/IR/lhlo_gpu_ops.td
```diff
@@ -156,6 +156,26 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS
     I64Attr:$algorithm);
 }
 
+def LHLOGPU_CublasLtMatmulF8Op : LHLOGPU_Op<"cublas.lt.matmul.f8"> {
+  let arguments = (ins
+    Arg<LHLO_Buffer, "", [MemRead]>:$a,
+    Arg<LHLO_Buffer, "", [MemRead]>:$b,
+    Arg<LHLO_Buffer, "", [MemRead]>:$c,
+    Arg<LHLO_Buffer, "", [MemRead]>:$a_scale,
+    Arg<LHLO_Buffer, "", [MemRead]>:$b_scale,
+    Arg<LHLO_Buffer, "", [MemRead]>:$c_scale,
+    Arg<LHLO_Buffer, "", [MemRead]>:$d_scale,
+    Arg<LHLO_Buffer, "", [MemWrite]>:$d,
+    Arg<Optional<LHLO_Buffer>, "", [MemWrite]>:$d_amax,
+    MHLO_DotDimensionNumbers:$dot_dimension_numbers,
+    MHLO_PrecisionConfigAttr:$precision_config,
+    F64Attr:$alpha_real,
+    F64Attr:$alpha_imag,
+    F64Attr:$beta,
+    CublasLtMatmulEpilogueAttr:$epilogue,
+    I64Attr:$algorithm);
+}
+
 def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$input,
```
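
The optional `d_amax` output is what enables scale management on the framework side. As a loose illustration only (this recipe is not part of the PR; delayed-scaling schemes such as NVIDIA's Transformer Engine work along these lines), a subsequent output scale could be derived from the observed amax:

```python
E4M3_MAX = 448.0  # largest finite magnitude of F8E4M3FN

def updated_scale(d_amax, margin=2.0):
    """Illustrative only: derive the next d_scale from the observed amax.

    Picks a scale that maps the observed maximum near the top of the FP8
    range, with `margin` as headroom against overflow on the next step.
    """
    if d_amax == 0.0:
        return 1.0
    return E4M3_MAX / (d_amax * margin)
```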
