-
Notifications
You must be signed in to change notification settings - Fork 663
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen] Add support to emulate unsupported float type
Converts arith operations on unsupported float types to f32.
- Loading branch information
Showing
10 changed files
with
154 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 0 additions & 83 deletions
83
compiler/src/iree/compiler/Codegen/Common/ConvertF8ArithToF32.cpp
This file was deleted.
Oops, something went wrong.
105 changes: 105 additions & 0 deletions
105
compiler/src/iree/compiler/Codegen/Common/ConvertUnsupportedFloatArithPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
//===- BubbleUpOrdinalOpPass.cpp -----------------------------------------===// | ||
// | ||
// The workgroup count computation when using slices needs the ordinal | ||
// annotation ops to be bubbled up as much as possible. This pass implements | ||
// patterns to bubble these operations up. | ||
// | ||
//===---------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Common/Transforms.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct ConvertUnsupportedFloatArithPass final | ||
: public impl::ConvertUnsupportedFloatArithPassBase< | ||
ConvertUnsupportedFloatArithPass> { | ||
void runOnOperation() override; | ||
using Base::Base; | ||
}; | ||
|
||
} // namespace | ||
// | ||
|
||
static LogicalResult ParseFromOption(MLIRContext *ctx, | ||
ArrayRef<std::string> sourceTypeStrs, | ||
std::string targetTypeStr, | ||
SmallVectorImpl<Type> &sourceTypes, | ||
Type &targetType) { | ||
|
||
std::optional<FloatType> maybeTargetType = | ||
arith::parseFloatType(ctx, targetTypeStr); | ||
if (!maybeTargetType) { | ||
emitError(UnknownLoc::get(ctx), "could not map target type '" + | ||
targetTypeStr + | ||
"' to a known floating-point type"); | ||
return failure(); | ||
} | ||
targetType = *maybeTargetType; | ||
for (StringRef sourceTypeStr : sourceTypeStrs) { | ||
std::optional<FloatType> maybeSourceType = | ||
arith::parseFloatType(ctx, sourceTypeStr); | ||
if (!maybeSourceType) { | ||
emitError(UnknownLoc::get(ctx), "could not map source type '" + | ||
sourceTypeStr + | ||
"' to a known floating-point type"); | ||
return failure(); | ||
} | ||
sourceTypes.push_back(*maybeSourceType); | ||
} | ||
|
||
return success(); | ||
} | ||
|
||
void ConvertUnsupportedFloatArithPass::runOnOperation() { | ||
|
||
MLIRContext *context = &getContext(); | ||
Operation *op = getOperation(); | ||
SmallVector<Type> sourceTypes; | ||
Type targetType; | ||
|
||
if (failed(ParseFromOption(context, sourceTypeStrs, targetTypeStr, | ||
sourceTypes, targetType))) { | ||
return signalPassFailure(); | ||
} | ||
|
||
if (sourceTypes.empty()) { | ||
(void)emitOptionalWarning( | ||
std::nullopt, | ||
"no source types specified, float emulation will do nothing"); | ||
return signalPassFailure(); | ||
} | ||
|
||
if (llvm::is_contained(sourceTypes, targetType)) { | ||
emitError(UnknownLoc::get(context), | ||
"target type cannot be an unsupported source type"); | ||
return signalPassFailure(); | ||
} | ||
|
||
TypeConverter converter; | ||
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, | ||
targetType); | ||
RewritePatternSet patterns(context); | ||
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); | ||
ConversionTarget target(*context); | ||
arith::populateEmulateUnsupportedFloatsLegality(target, converter); | ||
|
||
if (failed(applyPartialConversion(op, target, std::move(patterns)))) | ||
signalPassFailure(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
compiler/src/iree/compiler/Codegen/Common/test/convert_unsupported_float_arith.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith{source-types=f8E4M3FNUZ,bf16 target-type=f32}))" %s | FileCheck %s | ||
|
||
// CHECK-LABEL: func.func @negf_f8_unsupported | ||
// CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ | ||
// CHECK: %[[EXT:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E4M3FNUZ to f32 | ||
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32 | ||
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ | ||
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ | ||
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ { | ||
%0 = arith.negf %arg0 : f8E4M3FNUZ | ||
return %0 : f8E4M3FNUZ | ||
} | ||
|
||
// ----- | ||
|
||
func.func @bf16_expansion(%x: bf16) -> bf16 { | ||
// CHECK-LABEL: @bf16_expansion | ||
// CHECK-SAME: [[X:%.+]]: bf16 | ||
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16 | ||
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] {{.*}} : bf16 to f32 | ||
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] {{.*}} : bf16 to f32 | ||
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32 | ||
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] {{.*}} : f32 to bf16 | ||
// CHECK: return [[Y]] | ||
%c = arith.constant 1.0 : bf16 | ||
%y = arith.addf %x, %c : bf16 | ||
func.return %y : bf16 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters