Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen] Add support to emulate unsupported float type #19943

Merged
merged 3 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ iree_compiler_cc_library(
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"ConvertUnsupportedFloatArithPass.cpp",
"ConvolutionToIGEMM.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
Expand Down Expand Up @@ -201,6 +202,8 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AMDGPUDialect",
"@llvm-project//mlir:AMDGPUUtils",
"@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineTransforms",
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ iree_cc_library(
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvertUnsupportedFloatArithPass.cpp"
"ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
Expand Down Expand Up @@ -160,6 +161,8 @@ iree_cc_library(
::PassesIncGen
IREELinalgTransformDialect
LLVMSupport
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRAffineAnalysis
MLIRAffineDialect
MLIRAffineTransforms
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===--------------- ConvertUnsupportedFloatArithPass.cpp ----------------===//
//
// Emulate arith and vector floating point operations that use float types
// which are unsupported on a target by inserting extf/truncf pairs around all
// such operations in order to produce arithmetic that can be performed while
// preserving the original rounding behavior.
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

#define DEBUG_TYPE "iree-convert-unsupported-float-arith"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Function-level pass that emulates float arithmetic the target cannot
/// perform natively. All of the work happens in `runOnOperation`, which is
/// defined out of line below.
struct ConvertUnsupportedFloatArithPass final
    : public impl::ConvertUnsupportedFloatArithPassBase<
          ConvertUnsupportedFloatArithPass> {
  // Reuse the constructors of the base class pulled in from Passes.h.inc.
  using Base::Base;

  void runOnOperation() override;
};

} // namespace

/// Populates the source and target conversion types based on the target
/// architecture of `op`.
///
/// On success `sourceTypes` is extended with the float types that have to be
/// emulated and `targetType` is set to the type the arithmetic is carried out
/// in. Both outputs are left untouched when:
///   * `op` has no GPU target attribute,
///   * the target's arch string does not parse as an AMDGPU chipset, or
///   * the chipset is not part of the gfx94{*} series.
/// Callers are expected to treat untouched outputs as "nothing to emulate".
// TODO(pashu123): Refine the patterns based on the target arch.
static void populateSourceAndTargetType(MLIRContext *ctx, Operation *op,
                                        SmallVectorImpl<Type> &sourceTypes,
                                        Type &targetType) {
  auto gpuAttr = getGPUTargetAttr(op);
  if (!gpuAttr) {
    return;
  }
  StringRef chipset = gpuAttr.getArch();
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    // Terminate the message with a newline so later debug output does not get
    // glued onto the same line, and report which arch string was rejected.
    LLVM_DEBUG(llvm::dbgs() << "Invalid chip name: " << chipset << "\n");
    return;
  }
  // Add source and target conversion types for gfx94{*} series: both FNUZ f8
  // flavours are emulated through f32.
  if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) {
    sourceTypes.append(
        {Float8E4M3FNUZType::get(ctx), Float8E5M2FNUZType::get(ctx)});
    targetType = Float32Type::get(ctx);
  }
}

void ConvertUnsupportedFloatArithPass::runOnOperation() {
MLIRContext *context = &getContext();
FunctionOpInterface funcOp = getOperation();
SmallVector<Type> sourceTypes;
Type targetType = nullptr;

populateSourceAndTargetType(context, funcOp, sourceTypes, targetType);

if (sourceTypes.empty() || !targetType) {
LLVM_DEBUG(llvm::dbgs() << "no source or target type specified, float "
"emulation will do nothing\n");
return;
}

if (llvm::is_contained(sourceTypes, targetType)) {
funcOp->emitError() << " target type cannot be an unsupported source type";
return signalPassFailure();
}

TypeConverter converter;
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
targetType);
RewritePatternSet patterns(context);
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
ConversionTarget target(*context);
arith::populateEmulateUnsupportedFloatsLegality(target, converter);

if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace mlir::iree_compiler
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def ConvertBf16ToUInt16BuffersPass :
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
}

// Declares the pass registered under the command-line name
// `iree-convert-unsupported-float-arith`. It is an interface pass anchored on
// `mlir::FunctionOpInterface` and takes no options; per the summary, the
// source and target float types are derived from the target architecture.
def ConvertUnsupportedFloatArithPass
: InterfacePass<"iree-convert-unsupported-float-arith",
"mlir::FunctionOpInterface"> {
let summary = "Convert arith operations on unsupported(source types) float "
"types to the target type. Populates the source and target "
"based on the target architecture.";
}

def ConvertToDestinationPassingStylePass :
InterfacePass<"iree-codegen-convert-to-destination-passing-style", "mlir::FunctionOpInterface"> {
let summary =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
"convert_unsupported_float_arith.mlir",
"convolution_to_igemm.mlir",
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
"convert_unsupported_float_arith.mlir"
"convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith))" %s | FileCheck %s

// CHECK-LABEL: func.func @negf_f8_unsupported
// CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ
// CHECK: %[[EXT:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E4M3FNUZ to f32
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
// Input: a single negf on f8E4M3FNUZ inside a function tagged with the
// gfx942 ROCm executable target. The expected output above wraps the negf in
// an extf/truncf pair so the negation itself runs on f32.
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}

// -----

// CHECK-LABEL: func.func @expand_f8(
// CHECK-SAME: %[[ARG0:.*]]: f8E5M2FNUZ
// CHECK: %[[EXT0:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E5M2FNUZ to f32
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f8E5M2FNUZ
// CHECK: %[[EXT1:.*]] = arith.extf %[[CST]] {{.*}} : f8E5M2FNUZ to f32
// CHECK: %[[SUM:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[SUM]] {{.*}} : f32 to f8E5M2FNUZ
// CHECK: return %[[TRUNC]] : f8E5M2FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
// Input: an addf of a function argument and a constant, both f8E5M2FNUZ, on
// the gfx942 ROCm executable target. The expected output above keeps the
// constant in f8E5M2FNUZ, extends both operands to f32, adds in f32 and
// truncates the sum back.
func.func @expand_f8(%x: f8E5M2FNUZ) -> f8E5M2FNUZ attributes
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
%c = arith.constant 1.0 : f8E5M2FNUZ
%y = arith.addf %x, %c : f8E5M2FNUZ
func.return %y : f8E5M2FNUZ
}

// -----

// CHECK-LABEL: func.func @dont_expand_cpu_target
// CHECK: %[[NEG:.*]] = arith.negf {{.*}} : f8E4M3FNUZ
// Negative case: the same negf on f8E4M3FNUZ, but the function carries an
// llvm-cpu executable target. The expected output above requires the negf to
// still operate on f8E4M3FNUZ, i.e. no emulation is applied.
func.func @dont_expand_cpu_target(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
{ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>}{
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,

if (forROCDL) {
// convert to ROCDL.
funcPassManager.addPass(createConvertUnsupportedFloatArithPass);
modulePassManager.addPass(createConvertToROCDLPass());
modulePassManager.addNestedPass<LLVM::LLVMFuncOp>(
createROCDLAnnotateKernelForTranslationPass());
Expand Down
Loading