-
Notifications
You must be signed in to change notification settings - Fork 663
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen] Add support to emulate unsupported float type (#19943)
This change enables the conversion of types such as f8E4M3FNUZ and f8E5M2FNUZ (emulated via the existing APIs) into f32 operations. The conversion logic is now tightly coupled with the executable target attribute, so that it is applied only for the gfx942 target. This removes the need for manual pass configuration to specify source types and aligns the behaviour with the target’s capabilities. For any new conversion, just populate the conversion target with source and target types. FIX: #19921 (comment) --------- Signed-off-by: Chi Liu<[email protected]> Co-authored-by: AmosLewis <[email protected]>
- Loading branch information
Showing
8 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
100 changes: 100 additions & 0 deletions
100
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===--------------- ConvertUnsupportedFloatArithPass.cpp ----------------===// | ||
// | ||
// Emulate arith and vector floating point operations that use float types | ||
// which are unsupported on a target by inserting extf/truncf pairs around all | ||
// such operations in order to produce arithmetic that can be performed while | ||
// preserving the original rounding behavior. | ||
// | ||
//===---------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Common/Transforms.h" | ||
#include "iree/compiler/Codegen/Utils/GPUUtils.h" | ||
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
|
||
#define DEBUG_TYPE "iree-convert-unsupported-float-arith" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct ConvertUnsupportedFloatArithPass final | ||
: public impl::ConvertUnsupportedFloatArithPassBase< | ||
ConvertUnsupportedFloatArithPass> { | ||
void runOnOperation() override; | ||
using Base::Base; | ||
}; | ||
|
||
} // namespace | ||
|
||
// Populates source and target conversion types based on the target
// architecture attached to `op`.
//
// On return, `sourceTypes` holds the float types whose arithmetic should be
// emulated and `targetType` holds the type to emulate them in. Both outputs
// are left untouched when:
//   * `op` has no GPU target attribute,
//   * the architecture string does not parse as an AMDGPU chipset, or
//   * the chipset is not part of the gfx94{*} series.
// TODO(pashu123): Refine the patterns based on the target arch.
static void populateSourceAndTargetType(MLIRContext *ctx, Operation *op,
                                        SmallVectorImpl<Type> &sourceTypes,
                                        Type &targetType) {
  auto gpuAttr = getGPUTargetAttr(op);
  if (!gpuAttr) {
    return;
  }
  StringRef chipset = gpuAttr.getArch();
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    // Report the offending arch string and terminate the line so the message
    // does not run into subsequent debug output.
    LLVM_DEBUG(llvm::dbgs() << "Invalid chip name: " << chipset << "\n");
    return;
  }
  // Add source and target conversion types for gfx94{*} series: the FNUZ f8
  // variants are emulated through f32.
  if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) {
    sourceTypes.append(
        {Float8E4M3FNUZType::get(ctx), Float8E5M2FNUZType::get(ctx)});
    targetType = Float32Type::get(ctx);
  }
}
|
||
void ConvertUnsupportedFloatArithPass::runOnOperation() { | ||
MLIRContext *context = &getContext(); | ||
FunctionOpInterface funcOp = getOperation(); | ||
SmallVector<Type> sourceTypes; | ||
Type targetType = nullptr; | ||
|
||
populateSourceAndTargetType(context, funcOp, sourceTypes, targetType); | ||
|
||
if (sourceTypes.empty() || !targetType) { | ||
LLVM_DEBUG(llvm::dbgs() << "no source or target type specified, float " | ||
"emulation will do nothing\n"); | ||
return; | ||
} | ||
|
||
if (llvm::is_contained(sourceTypes, targetType)) { | ||
funcOp->emitError() << " target type cannot be an unsupported source type"; | ||
return signalPassFailure(); | ||
} | ||
|
||
TypeConverter converter; | ||
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, | ||
targetType); | ||
RewritePatternSet patterns(context); | ||
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); | ||
ConversionTarget target(*context); | ||
arith::populateEmulateUnsupportedFloatsLegality(target, converter); | ||
|
||
if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
compiler/src/iree/compiler/Codegen/Common/test/convert_unsupported_float_arith.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith))" %s | FileCheck %s | ||
|
||
// On a gfx942 target, a negf on f8E4M3FNUZ is expected to be wrapped in an
// extf to f32 and a truncf back to the original type.

// CHECK-LABEL: func.func @negf_f8_unsupported
// CHECK-SAME:    (%[[SRC:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ
// CHECK:         %[[WIDE:.*]] = arith.extf %[[SRC]] {{.*}} : f8E4M3FNUZ to f32
// CHECK:         %[[NEGATED:.*]] = arith.negf %[[WIDE]] : f32
// CHECK:         %[[NARROW:.*]] = arith.truncf %[[NEGATED]] {{.*}} : f32 to f8E4M3FNUZ
// CHECK:         return %[[NARROW]] : f8E4M3FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
func.func @negf_f8_unsupported(%input : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
  { hal.executable.target = #executable_target_rocm_hsaco_fb } {
  %negated = arith.negf %input : f8E4M3FNUZ
  return %negated : f8E4M3FNUZ
}
|
||
// ----- | ||
|
||
// On a gfx942 target, an addf on f8E5M2FNUZ is expected to have both operands
// extended to f32, the sum computed in f32, and the result truncated back.
// The constant itself keeps its f8E5M2FNUZ type.

// CHECK-LABEL: func.func @expand_f8(
// CHECK-SAME:    %[[SRC:.*]]: f8E5M2FNUZ
// CHECK:         %[[WIDE_SRC:.*]] = arith.extf %[[SRC]] {{.*}} : f8E5M2FNUZ to f32
// CHECK:         %[[ONE:.*]] = arith.constant 1.000000e+00 : f8E5M2FNUZ
// CHECK:         %[[WIDE_ONE:.*]] = arith.extf %[[ONE]] {{.*}} : f8E5M2FNUZ to f32
// CHECK:         %[[WIDE_SUM:.*]] = arith.addf %[[WIDE_SRC]], %[[WIDE_ONE]] : f32
// CHECK:         %[[NARROW:.*]] = arith.truncf %[[WIDE_SUM]] {{.*}} : f32 to f8E5M2FNUZ
// CHECK:         return %[[NARROW]] : f8E5M2FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
func.func @expand_f8(%value: f8E5M2FNUZ) -> f8E5M2FNUZ attributes
  { hal.executable.target = #executable_target_rocm_hsaco_fb } {
  %one = arith.constant 1.0 : f8E5M2FNUZ
  %sum = arith.addf %value, %one : f8E5M2FNUZ
  return %sum : f8E5M2FNUZ
}
|
||
// ----- | ||
|
||
// Negative test: with a CPU executable target the function must be left
// unchanged. The negf stays on f8E4M3FNUZ, no extf/truncf pair is inserted,
// and the negf result is returned directly.

// CHECK-LABEL: func.func @dont_expand_cpu_target
// CHECK-NOT:     arith.extf
// CHECK:         %[[NEG:.*]] = arith.negf {{.*}} : f8E4M3FNUZ
// CHECK-NOT:     arith.truncf
// CHECK:         return %[[NEG]] : f8E4M3FNUZ
func.func @dont_expand_cpu_target(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
  { hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}> } {
  %0 = arith.negf %arg0 : f8E4M3FNUZ
  return %0 : f8E4M3FNUZ
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters