Skip to content

Commit

Permalink
[Codegen] Add support to emulate unsupported float type
Browse files Browse the repository at this point in the history
Converts arith operations on unsupported float types to f32.
  • Loading branch information
pashu123 committed Feb 12, 2025
1 parent 582dacd commit e8cc8a8
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 87 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ iree_compiler_cc_library(
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"ConvertUnsupportedFloatArithPass.cpp",
"ConvolutionToIGEMM.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ iree_cc_library(
"ConvertAccGEMMToGEMMPass.cpp"
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"BubbleUpOrdinalOps.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvertUnsupportedFloatArithPass.cpp"
"ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
Expand Down
83 changes: 0 additions & 83 deletions compiler/src/iree/compiler/Codegen/Common/ConvertF8ArithToF32.cpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- BubbleUpOrdinalOpPass.cpp -----------------------------------------===//
//
// The workgroup count computation when using slices needs the ordinal
// annotation ops to be bubbled up as much as possible. This pass implements
// patterns to bubble these operations up.
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

struct ConvertUnsupportedFloatArithPass final
: public impl::ConvertUnsupportedFloatArithPassBase<
ConvertUnsupportedFloatArithPass> {
void runOnOperation() override;
using Base::Base;
};

} // namespace
//

static LogicalResult ParseFromOption(MLIRContext *ctx,
ArrayRef<std::string> sourceTypeStrs,
std::string targetTypeStr,
SmallVectorImpl<Type> &sourceTypes,
Type &targetType) {

std::optional<FloatType> maybeTargetType =
arith::parseFloatType(ctx, targetTypeStr);
if (!maybeTargetType) {
emitError(UnknownLoc::get(ctx), "could not map target type '" +
targetTypeStr +
"' to a known floating-point type");
return failure();
}
targetType = *maybeTargetType;
for (StringRef sourceTypeStr : sourceTypeStrs) {
std::optional<FloatType> maybeSourceType =
arith::parseFloatType(ctx, sourceTypeStr);
if (!maybeSourceType) {
emitError(UnknownLoc::get(ctx), "could not map source type '" +
sourceTypeStr +
"' to a known floating-point type");
return failure();
}
sourceTypes.push_back(*maybeSourceType);
}

return success();
}

void ConvertUnsupportedFloatArithPass::runOnOperation() {

MLIRContext *context = &getContext();
Operation *op = getOperation();
SmallVector<Type> sourceTypes;
Type targetType;

if (failed(ParseFromOption(context, sourceTypeStrs, targetTypeStr,
sourceTypes, targetType))) {
return signalPassFailure();
}

if (sourceTypes.empty()) {
(void)emitOptionalWarning(
std::nullopt,
"no source types specified, float emulation will do nothing");
return signalPassFailure();
}

if (llvm::is_contained(sourceTypes, targetType)) {
emitError(UnknownLoc::get(context),
"target type cannot be an unsupported source type");
return signalPassFailure();
}

TypeConverter converter;
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
targetType);
RewritePatternSet patterns(context);
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
ConversionTarget target(*context);
arith::populateEmulateUnsupportedFloatsLegality(target, converter);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}

} // namespace mlir::iree_compiler
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler {
Expand Down
11 changes: 8 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,14 @@ def ConvertBf16ToUInt16BuffersPass :
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
}

def ConvertF8ArithToF32Pass :
Pass<"iree-codegen-convert-f8-to-f32-buffers", ""> {
let summary = "Convert f8 buffer ops and conversions to simulated behavior with f32.";
def ConvertUnsupportedFloatArithPass : InterfacePass<"iree-convert-unsupported-float-arith", "mlir::FunctionOpInterface"> {
let summary = "Convert arith operations on unsupported(source types) float types to the target type.";
let options = [
ListOption<"sourceTypeStrs", "source-types", "std::string",
"MLIR types without arithmetic support on a given target">,
Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
"MLIR type to convert the unsupported source types to">,
];
}

def ConvertToDestinationPassingStylePass :
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
"convert_unsupported_float_arith.mlir",
"convolution_to_igemm.mlir",
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
"convert_unsupported_float_arith.mlir"
"convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith{source-types=f8E4M3FNUZ,bf16 target-type=f32}))" %s | FileCheck %s

// CHECK-LABEL: func.func @negf_f8_unsupported
// CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ
// CHECK: %[[EXT:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E4M3FNUZ to f32
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ {
%0 = arith.negf %arg0 : f8E4M3FNUZ
return %0 : f8E4M3FNUZ
}

// -----

func.func @bf16_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @bf16_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {{.*}} : bf16 to f32
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {{.*}} : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {{.*}} : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
func.return %y : bf16
}
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,14 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,

if (forROCDL) {
// convert to ROCDL.
funcPassManager.addPass([&]() {
ConvertUnsupportedFloatArithPassOptions options;
// Convert arith operation with the given `source types` to `target`
// type.
options.sourceTypeStrs = {"f8E4M3FNUZ", "f8E5M2FNUZ"};
options.targetTypeStr = "f32";
return createConvertUnsupportedFloatArithPass(options);
});
modulePassManager.addPass(createConvertToROCDLPass());
modulePassManager.addNestedPass<LLVM::LLVMFuncOp>(
createROCDLAnnotateKernelForTranslationPass());
Expand Down

0 comments on commit e8cc8a8

Please sign in to comment.