diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp index f1364654dde6..87f03e7d3f17 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp @@ -15,11 +15,15 @@ #include "iree/compiler/Codegen/Common/Passes.h" #include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#define DEBUG_TYPE "iree-convert-unsupported-float-arith" + namespace mlir::iree_compiler { #define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS @@ -35,60 +39,57 @@ struct ConvertUnsupportedFloatArithPass final }; } // namespace -// -static LogicalResult ParseFromOption(MLIRContext *ctx, - ArrayRef sourceTypeStrs, - std::string targetTypeStr, - SmallVectorImpl &sourceTypes, - Type &targetType) { - - std::optional maybeTargetType = - arith::parseFloatType(ctx, targetTypeStr); - if (!maybeTargetType) { - emitError(UnknownLoc::get(ctx), "could not map target type '" + - targetTypeStr + - "' to a known floating-point type"); +// Populates source and target conversion types based on the target +// architecture. +// TODO(pashu123): Refine the patterns based on the target arch. +static LogicalResult +populateSourceAndTargetType(MLIRContext *ctx, Operation *op, + SmallVectorImpl &sourceTypes, + Type &targetType) { + + auto gpuAttr = getGPUTargetAttr(op); + if (!gpuAttr) + return failure(); + + StringRef chipset = gpuAttr.getArch(); + FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); + if (failed(maybeChipset)) { + op->emitError() << "Invalid chipset name: " << chipset; return failure(); } - targetType = *maybeTargetType; - for (StringRef sourceTypeStr : sourceTypeStrs) { - std::optional maybeSourceType = - arith::parseFloatType(ctx, sourceTypeStr); - if (!maybeSourceType) { - emitError(UnknownLoc::get(ctx), "could not map source type '" + - sourceTypeStr + - "' to a known floating-point type"); - return failure(); - } - sourceTypes.push_back(*maybeSourceType); + // Add source and target conversion types for gfx94{*} series. + if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) { + sourceTypes.insert(sourceTypes.end(), {Float8E4M3FNUZType::get(ctx), + Float8E5M2FNUZType::get(ctx)}); + targetType = Float32Type::get(ctx); + return success(); } - - return success(); + return failure(); } void ConvertUnsupportedFloatArithPass::runOnOperation() { - MLIRContext *context = &getContext(); - Operation *op = getOperation(); + FunctionOpInterface funcOp = getOperation(); SmallVector sourceTypes; Type targetType; - if (failed(ParseFromOption(context, sourceTypeStrs, targetTypeStr, - sourceTypes, targetType))) { - return signalPassFailure(); + if (failed(populateSourceAndTargetType(context, funcOp, sourceTypes, + targetType))) { + LLVM_DEBUG(llvm::dbgs() + << "failed while populating source and target types\n"); + return; } if (sourceTypes.empty()) { - (void)emitOptionalWarning( - std::nullopt, - "no source types specified, float emulation will do nothing"); - return signalPassFailure(); + LLVM_DEBUG( + llvm::dbgs() + << "no source types specified, float emulation will do nothing\n"); + return; } if (llvm::is_contained(sourceTypes, targetType)) { - emitError(UnknownLoc::get(context), - "target type cannot be an unsupported source type"); + funcOp->emitError() << " target type cannot be an unsupported source type"; return signalPassFailure(); } @@ -100,8 +101,9 @@ void ConvertUnsupportedFloatArithPass::runOnOperation() { ConversionTarget target(*context); arith::populateEmulateUnsupportedFloatsLegality(target, converter); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index ec1bcce453d9..b1f99cc34f96 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -22,7 +22,6 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" namespace mlir::iree_compiler { diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 24513ed24393..eb4152f28f73 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -80,14 +80,12 @@ def ConvertBf16ToUInt16BuffersPass : let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16."; } -def ConvertUnsupportedFloatArithPass : InterfacePass<"iree-convert-unsupported-float-arith", "mlir::FunctionOpInterface"> { - let summary = "Convert arith operations on unsupported(source types) float types to the target type."; - let options = [ - ListOption<"sourceTypeStrs", "source-types", "std::string", - "MLIR types without arithmetic support on a given target">, - Option<"targetTypeStr", "target-type", "std::string", "\"f32\"", - "MLIR type to convert the unsupported source types to">, - ]; +def ConvertUnsupportedFloatArithPass + : InterfacePass<"iree-convert-unsupported-float-arith", + "mlir::FunctionOpInterface"> { + let summary = "Convert arith operations on unsupported(source types) float " + "types to the target type. Populates the source and target " + "based on the target architecture."; } def ConvertToDestinationPassingStylePass : diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_unsupported_float_arith.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_unsupported_float_arith.mlir index 87b553113516..699c6ddcd17c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/convert_unsupported_float_arith.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_unsupported_float_arith.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith{source-types=f8E4M3FNUZ,bf16 target-type=f32}))" %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith))" %s | FileCheck %s // CHECK-LABEL: func.func @negf_f8_unsupported // CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ @@ -6,23 +6,37 @@ // CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32 // CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ // CHECK: return %[[TRUNC]] : f8E4M3FNUZ -func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ { +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target, , , , , , , , , , , , , , , ], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}> +func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes +{ hal.executable.target = #executable_target_rocm_hsaco_fb }{ %0 = arith.negf %arg0 : f8E4M3FNUZ return %0 : f8E4M3FNUZ } // ----- -func.func @bf16_expansion(%x: bf16) -> bf16 { -// CHECK-LABEL: @bf16_expansion -// CHECK-SAME: [[X:%.+]]: bf16 -// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16 -// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {{.*}} : bf16 to f32 -// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {{.*}} : bf16 to f32 -// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32 -// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {{.*}} : f32 to bf16 -// CHECK: return [[Y]] - %c = arith.constant 1.0 : bf16 - %y = arith.addf %x, %c : bf16 - func.return %y : bf16 +// CHECK-LABEL: func.func @expand_f8( +// CHECK-SAME: %[[ARG0:.*]]: f8E5M2FNUZ +// CHECK: %[[EXT0:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E5M2FNUZ to f32 +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f8E5M2FNUZ +// CHECK: %[[EXT1:.*]] = arith.extf %[[CST]] {{.*}} : f8E5M2FNUZ to f32 +// CHECK: %[[SUM:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32 +// CHECK: %[[TRUNC:.*]] = arith.truncf %[[SUM]] {{.*}} : f32 to f8E5M2FNUZ +// CHECK: return %[[TRUNC]] : f8E5M2FNUZ +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target, , , , , , , , , , , , , , , ], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}> +func.func @expand_f8(%x: f8E5M2FNUZ) -> f8E5M2FNUZ attributes +{ hal.executable.target = #executable_target_rocm_hsaco_fb }{ + %c = arith.constant 1.0 : f8E5M2FNUZ + %y = arith.addf %x, %c : f8E5M2FNUZ + func.return %y : f8E5M2FNUZ +} + +// ----- + +// CHECK-LABEL: func.func @dont_expand_cpu_target +// CHECK: %[[NEG:.*]] = arith.negf {{.*}} : f8E4M3FNUZ +func.func @dont_expand_cpu_target(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes +{ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>}{ + %0 = arith.negf %arg0 : f8E4M3FNUZ + return %0 : f8E4M3FNUZ } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 75eb214c75c1..7e5831d20fd2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1126,14 +1126,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, if (forROCDL) { // convert to ROCDL. - funcPassManager.addPass([&]() { - ConvertUnsupportedFloatArithPassOptions options; - // Convert arith operation with the given `source types` to `target` - // type. - options.sourceTypeStrs = {"f8E4M3FNUZ", "f8E5M2FNUZ"}; - options.targetTypeStr = "f32"; - return createConvertUnsupportedFloatArithPass(options); - }); + funcPassManager.addPass(createConvertUnsupportedFloatArithPass); modulePassManager.addPass(createConvertToROCDLPass()); modulePassManager.addNestedPass( createROCDLAnnotateKernelForTranslationPass());