Skip to content

Commit

Permalink
Add conversion based on the target arch.
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 committed Feb 13, 2025
1 parent 6c5bb93 commit 1966456
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 73 deletions.
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AMDGPUDialect",
"@llvm-project//mlir:AMDGPUUtils",
"@llvm-project//mlir:AffineAnalysis",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineTransforms",
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ iree_cc_library(
::PassesIncGen
IREELinalgTransformDialect
LLVMSupport
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRAffineAnalysis
MLIRAffineDialect
MLIRAffineTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

#define DEBUG_TYPE "iree-convert-unsupported-float-arith"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS
Expand All @@ -35,60 +39,48 @@ struct ConvertUnsupportedFloatArithPass final
};

} // namespace
//

static LogicalResult ParseFromOption(MLIRContext *ctx,
ArrayRef<std::string> sourceTypeStrs,
std::string targetTypeStr,
SmallVectorImpl<Type> &sourceTypes,
Type &targetType) {

std::optional<FloatType> maybeTargetType =
arith::parseFloatType(ctx, targetTypeStr);
if (!maybeTargetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
"' to a known floating-point type");
return failure();
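// Populates `sourceTypes` and `targetType` with the float conversion to apply
// for the GPU target looked up from `op` via getGPUTargetAttr. Both outputs are
// left untouched when `op` has no GPU target attribute or when its arch string
// does not parse as an AMDGPU chipset. Currently only the gfx94{*} series gets
// a conversion: f8E4M3FNUZ and f8E5M2FNUZ as source types, f32 as the target.
// TODO(pashu123): Refine the patterns based on the target arch.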
static void populateSourceAndTargetType(MLIRContext *ctx, Operation *op,
SmallVectorImpl<Type> &sourceTypes,
Type &targetType) {
auto gpuAttr = getGPUTargetAttr(op);
if (!gpuAttr) {
return;
}
targetType = *maybeTargetType;
for (StringRef sourceTypeStr : sourceTypeStrs) {
std::optional<FloatType> maybeSourceType =
arith::parseFloatType(ctx, sourceTypeStr);
if (!maybeSourceType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
sourceTypeStr +
"' to a known floating-point type");
return failure();
}
sourceTypes.push_back(*maybeSourceType);
StringRef chipset = gpuAttr.getArch();
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
if (failed(maybeChipset)) {
LLVM_DEBUG(llvm::dbgs() << "Invalid chip name");
return;
}

return success();
// Add source and target conversion types for gfx94{*} series.
if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) {
sourceTypes.insert(sourceTypes.end(), {Float8E4M3FNUZType::get(ctx),
Float8E5M2FNUZType::get(ctx)});
targetType = Float32Type::get(ctx);
}
return;
}

void ConvertUnsupportedFloatArithPass::runOnOperation() {

MLIRContext *context = &getContext();
Operation *op = getOperation();
FunctionOpInterface funcOp = getOperation();
SmallVector<Type> sourceTypes;
Type targetType;
Type targetType = nullptr;

if (failed(ParseFromOption(context, sourceTypeStrs, targetTypeStr,
sourceTypes, targetType))) {
return signalPassFailure();
}
populateSourceAndTargetType(context, funcOp, sourceTypes, targetType);

if (sourceTypes.empty()) {
(void)emitOptionalWarning(
std::nullopt,
"no source types specified, float emulation will do nothing");
return signalPassFailure();
if (sourceTypes.empty() || !targetType) {
LLVM_DEBUG(llvm::dbgs() << "no source or target type specified, float "
"emulation will do nothing\n");
return;
}

if (llvm::is_contained(sourceTypes, targetType)) {
emitError(UnknownLoc::get(context),
"target type cannot be an unsupported source type");
funcOp->emitError() << " target type cannot be an unsupported source type";
return signalPassFailure();
}

Expand All @@ -100,8 +92,9 @@ void ConvertUnsupportedFloatArithPass::runOnOperation() {
ConversionTarget target(*context);
arith::populateEmulateUnsupportedFloatsLegality(target, converter);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace mlir::iree_compiler
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler {
Expand Down
14 changes: 6 additions & 8 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,12 @@ def ConvertBf16ToUInt16BuffersPass :
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
}

def ConvertUnsupportedFloatArithPass : InterfacePass<"iree-convert-unsupported-float-arith", "mlir::FunctionOpInterface"> {
let summary = "Convert arith operations on unsupported(source types) float types to the target type.";
let options = [
ListOption<"sourceTypeStrs", "source-types", "std::string",
"MLIR types without arithmetic support on a given target">,
Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
"MLIR type to convert the unsupported source types to">,
];
// Pass registered as `iree-convert-unsupported-float-arith`, anchored on ops
// implementing mlir::FunctionOpInterface. This definition declares no pass
// options; per its summary, the source and target float types are derived from
// the target architecture rather than supplied by the caller.
def ConvertUnsupportedFloatArithPass
: InterfacePass<"iree-convert-unsupported-float-arith",
"mlir::FunctionOpInterface"> {
let summary = "Convert arith operations on unsupported(source types) float "
"types to the target type. Populates the source and target "
"based on the target architecture.";
}

def ConvertToDestinationPassingStylePass :
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith{source-types=f8E4M3FNUZ,bf16 target-type=f32}))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith))" %s | FileCheck %s

// CHECK-LABEL: func.func @negf_f8_unsupported
// CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ
// CHECK: %[[EXT:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E4M3FNUZ to f32
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ {
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
// Test input for a ROCm target with arch "gfx942": the function carries the
// executable target as its `hal.executable.target` attribute. The FileCheck
// expectations for this function require the f8E4M3FNUZ negf to be extended to
// f32, negated in f32, and truncated back to f8E4M3FNUZ.
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}

// -----

func.func @bf16_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @bf16_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {{.*}} : bf16 to f32
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {{.*}} : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {{.*}} : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
func.return %y : bf16
// CHECK-LABEL: func.func @expand_f8(
// CHECK-SAME: %[[ARG0:.*]]: f8E5M2FNUZ
// CHECK: %[[EXT0:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E5M2FNUZ to f32
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f8E5M2FNUZ
// CHECK: %[[EXT1:.*]] = arith.extf %[[CST]] {{.*}} : f8E5M2FNUZ to f32
// CHECK: %[[SUM:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[SUM]] {{.*}} : f32 to f8E5M2FNUZ
// CHECK: return %[[TRUNC]] : f8E5M2FNUZ
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
// Test input for a ROCm target with arch "gfx942", this time with f8E5M2FNUZ
// operands. The FileCheck expectations for this function require both the
// argument and the f8E5M2FNUZ constant to be extended to f32, the add to be
// performed in f32, and the sum to be truncated back to f8E5M2FNUZ.
func.func @expand_f8(%x: f8E5M2FNUZ) -> f8E5M2FNUZ attributes
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
%c = arith.constant 1.0 : f8E5M2FNUZ
%y = arith.addf %x, %c : f8E5M2FNUZ
func.return %y : f8E5M2FNUZ
}

// -----

// CHECK-LABEL: func.func @dont_expand_cpu_target
// CHECK: %[[NEG:.*]] = arith.negf {{.*}} : f8E4M3FNUZ
// Negative test: the function is tagged with an "llvm-cpu" executable target
// instead of a GPU one. The FileCheck expectation for this function requires
// the negf to still operate directly on f8E4M3FNUZ, i.e. no f32 expansion.
func.func @dont_expand_cpu_target(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
{ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>}{
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}
9 changes: 1 addition & 8 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1126,14 +1126,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,

if (forROCDL) {
// Emulate arith on float types the target lacks support for, then convert to
// ROCDL.
funcPassManager.addPass([&]() {
ConvertUnsupportedFloatArithPassOptions options;
// Convert arith operation with the given `source types` to `target`
// type.
options.sourceTypeStrs = {"f8E4M3FNUZ", "f8E5M2FNUZ"};
options.targetTypeStr = "f32";
return createConvertUnsupportedFloatArithPass(options);
});
funcPassManager.addPass(createConvertUnsupportedFloatArithPass);
modulePassManager.addPass(createConvertToROCDLPass());
modulePassManager.addNestedPass<LLVM::LLVMFuncOp>(
createROCDLAnnotateKernelForTranslationPass());
Expand Down

0 comments on commit 1966456

Please sign in to comment.