Skip to content

Commit

Permalink
[Codegen] Add support to emulate unsupported float type (#19943)
Browse files Browse the repository at this point in the history
This change enables the conversion of types such as f8E4M3FNUZ and
f8E5M2FNUZ (emulated via the existing APIs) into f32 operations. The
conversion logic is now tightly coupled with the executable target
attribute, so that it is applied only for the gfx942 target. This
removes the need for manual pass configuration to specify source types
and aligns the behaviour with the target’s capabilities. For any new
conversion, just populate the conversion target with source and target
types.

FIX:
#19921 (comment)

---------

Signed-off-by: Chi Liu<[email protected]>
Co-authored-by: AmosLewis <[email protected]>
  • Loading branch information
pashu123 and AmosLewis authored Feb 13, 2025
1 parent 8fab35c commit 0ff26a7
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 0 deletions.
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ iree_compiler_cc_library(
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"ConvertUnsupportedFloatArithPass.cpp",
"ConvolutionToIGEMM.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
Expand Down Expand Up @@ -203,6 +204,8 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AMDGPUDialect",
"@llvm-project//mlir:AMDGPUUtils",
"@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineTransforms",
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ iree_cc_library(
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvertUnsupportedFloatArithPass.cpp"
"ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
Expand Down Expand Up @@ -162,6 +163,8 @@ iree_cc_library(
::PassesIncGen
IREELinalgTransformDialect
LLVMSupport
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRAffineAnalysis
MLIRAffineDialect
MLIRAffineTransforms
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===--------------- ConvertUnsupportedFloatArithPass.cpp ----------------===//
//
// Emulate arith and vector floating point operations that use float types
// which are unsupported on a target by inserting extf/truncf pairs around all
// such operations in order to produce arithmetic that can be performed while
// preserving the original rounding behavior.
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

#define DEBUG_TYPE "iree-convert-unsupported-float-arith"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

// Pass class for `iree-convert-unsupported-float-arith`. It derives from the
// tablegen-generated base, reuses the base constructors, and supplies the
// conversion logic through runOnOperation().
struct ConvertUnsupportedFloatArithPass final
    : public impl::ConvertUnsupportedFloatArithPassBase<
          ConvertUnsupportedFloatArithPass> {
  using Base::Base;

  void runOnOperation() override;
};

} // namespace

// Populates the source (to-be-emulated) float types and the target type used
// for emulation, based on the GPU target attribute found for `op`.
//
// `sourceTypes` is appended to and `targetType` is assigned only when the
// target's arch string parses as an AMDGPU chipset in the gfx94{*} series.
// Both outputs are left untouched when `op` has no GPU target attribute, when
// the arch string fails to parse, or when the chipset is outside that series.
// TODO(pashu123): Refine the patterns based on the target arch.
static void populateSourceAndTargetType(MLIRContext *ctx, Operation *op,
                                        SmallVectorImpl<Type> &sourceTypes,
                                        Type &targetType) {
  auto gpuAttr = getGPUTargetAttr(op);
  if (!gpuAttr) {
    return;
  }
  StringRef arch = gpuAttr.getArch();
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(arch);
  if (failed(maybeChipset)) {
    // Name the arch that failed to parse and end the line so consecutive
    // debug messages do not run together.
    LLVM_DEBUG(llvm::dbgs() << "Invalid chip name: " << arch << "\n");
    return;
  }
  // Add source and target conversion types for gfx94{*} series.
  if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) {
    sourceTypes.append(
        {Float8E4M3FNUZType::get(ctx), Float8E5M2FNUZType::get(ctx)});
    targetType = Float32Type::get(ctx);
  }
}

void ConvertUnsupportedFloatArithPass::runOnOperation() {
MLIRContext *context = &getContext();
FunctionOpInterface funcOp = getOperation();
SmallVector<Type> sourceTypes;
Type targetType = nullptr;

populateSourceAndTargetType(context, funcOp, sourceTypes, targetType);

if (sourceTypes.empty() || !targetType) {
LLVM_DEBUG(llvm::dbgs() << "no source or target type specified, float "
"emulation will do nothing\n");
return;
}

if (llvm::is_contained(sourceTypes, targetType)) {
funcOp->emitError() << " target type cannot be an unsupported source type";
return signalPassFailure();
}

TypeConverter converter;
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
targetType);
RewritePatternSet patterns(context);
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
ConversionTarget target(*context);
arith::populateEmulateUnsupportedFloatsLegality(target, converter);

if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace mlir::iree_compiler
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def ConvertBf16ToUInt16BuffersPass :
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
}

// Pass registered under the `iree-convert-unsupported-float-arith` flag and
// anchored on operations implementing mlir::FunctionOpInterface. It declares
// no options here.
def ConvertUnsupportedFloatArithPass
: InterfacePass<"iree-convert-unsupported-float-arith",
"mlir::FunctionOpInterface"> {
let summary = "Convert arith operations on unsupported(source types) float "
"types to the target type. Populates the source and target "
"based on the target architecture.";
}

def ConvertToDestinationPassingStylePass :
InterfacePass<"iree-codegen-convert-to-destination-passing-style", "mlir::FunctionOpInterface"> {
let summary =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
"convert_unsupported_float_arith.mlir",
"convolution_to_igemm.mlir",
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
"convert_unsupported_float_arith.mlir"
"convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith))" %s | FileCheck %s

// CHECK-LABEL: func.func @negf_f8_unsupported
// CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ
// CHECK: %[[EXT:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E4M3FNUZ to f32
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
// Scalar negf on f8E4M3FNUZ in a function tagged with the gfx942 ROCm
// executable target; the expectations above require it to be computed in f32
// between an extf and a truncf.
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}

// -----

// CHECK-LABEL: func.func @expand_f8(
// CHECK-SAME: %[[ARG0:.*]]: f8E5M2FNUZ
// CHECK: %[[EXT0:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E5M2FNUZ to f32
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f8E5M2FNUZ
// CHECK: %[[EXT1:.*]] = arith.extf %[[CST]] {{.*}} : f8E5M2FNUZ to f32
// CHECK: %[[SUM:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[SUM]] {{.*}} : f32 to f8E5M2FNUZ
// CHECK: return %[[TRUNC]] : f8E5M2FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
// addf of an f8E5M2FNUZ argument and an f8E5M2FNUZ constant under the gfx942
// ROCm executable target; the expectations above keep the constant in f8 and
// require both operands to be extended to f32 before the add.
func.func @expand_f8(%x: f8E5M2FNUZ) -> f8E5M2FNUZ attributes
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
%c = arith.constant 1.0 : f8E5M2FNUZ
%y = arith.addf %x, %c : f8E5M2FNUZ
func.return %y : f8E5M2FNUZ
}

// -----

// CHECK-LABEL: func.func @dont_expand_cpu_target
// CHECK: %[[NEG:.*]] = arith.negf {{.*}} : f8E4M3FNUZ
// Negative case: the function carries an llvm-cpu executable target, and the
// expectation above requires the negf to remain on f8E4M3FNUZ.
func.func @dont_expand_cpu_target(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
{ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>}{
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,

if (forROCDL) {
// convert to ROCDL.
funcPassManager.addPass(createConvertUnsupportedFloatArithPass);
modulePassManager.addPass(createConvertToROCDLPass());
modulePassManager.addNestedPass<LLVM::LLVMFuncOp>(
createROCDLAnnotateKernelForTranslationPass());
Expand Down

0 comments on commit 0ff26a7

Please sign in to comment.