Skip to content

Commit

Permalink
const prop: round+saturate cast from fp to int types (#2593)
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen authored Oct 30, 2023
1 parent a82ddb4 commit 0f021bc
Show file tree
Hide file tree
Showing 11 changed files with 360 additions and 131 deletions.
9 changes: 9 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ std::string mtriple; // common for both
std::string mcpu; // common for both
std::string march; // common for both
InstrumentStages instrumentStage; // common for both
bool onnxConstPropRoundFPToInt; // common for both
int onnxConstPropExpansionBound; // common for both
std::vector<std::string> onnxConstPropDisablePatterns; // common for both
bool enableONNXHybridPass; // common for both
Expand Down Expand Up @@ -156,6 +157,14 @@ static llvm::cl::opt<InstrumentStages, true> instrumentStageOpt(
APPLY_TO_ACCELERATORS(ACCEL_INSTRUMENTSTAGE_CL_ENUM)),
llvm::cl::init(Onnx), llvm::cl::cat(OnnxMlirCommonOptions));

// Command line flag --onnx-const-prop-round-fp-to-int. The parsed value is
// stored in the global onnxConstPropRoundFPToInt (external storage through
// llvm::cl::location). It defaults to false, i.e. constant propagation of a
// floating point to integer onnx.Cast truncates towards zero unless the flag
// is set.
static llvm::cl::opt<bool, true> onnxConstPropRoundFPToIntOpt(
"onnx-const-prop-round-fp-to-int",
llvm::cl::desc("If true constant propagates onnx.Cast from a floating "
"point type to an integer type by rounding to nearest, "
"ties to even. If false truncates towards zero."),
llvm::cl::location(onnxConstPropRoundFPToInt), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<int, true> onnxConstPropExpansionBoundOpt(
"onnx-const-prop-expansion-bound",
llvm::cl::desc("ONNX dialect constant propagation maximum expansion factor."
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ extern std::string mtriple; // common for both
extern std::string mcpu; // common for both
extern std::string march; // common for both
extern InstrumentStages instrumentStage; // common for both
extern bool onnxConstPropRoundFPToInt; // common for both
extern int onnxConstPropExpansionBound; // common for both
extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
extern bool enableONNXHybridPass; // common for both
Expand Down
5 changes: 3 additions & 2 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ namespace onnx_mlir {
// Forwards the values of the compiler option globals to the passes that are
// configured through free functions rather than through pass options.
void configurePasses() {
  // Set global vector machine support.
  VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
  // Constant propagation: FP to int rounding mode, expansion bound, disabled
  // patterns, and whether constant propagation is disabled altogether.
  // The configuration function is called exactly once, with the rounding flag
  // as its first argument.
  configureConstPropONNXToONNXPass(onnxConstPropRoundFPToInt,
      onnxConstPropExpansionBound, onnxConstPropDisablePatterns,
      disableConstantProp);
  // ONNX to Krnl lowering: parallel report, parallel enablement, SIMD report,
  // SIMD enablement (the option is expressed as a "disable" flag).
  configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel,
      enableParallel, optReport == OptReport::Simd, !disableSimdOption);
}
Expand Down
259 changes: 204 additions & 55 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,80 +276,213 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
});
}

// Casts the elements of elms to newElementType by delegating to the cast
// routine that matches the kind of the target type. Both delegates are called
// with their default arguments. Any target type that is neither a float type
// nor an integer type is a programming error.
ElementsAttr ElementsAttrBuilder::castElementType(
    ElementsAttr elms, Type newElementType) {
  auto floatTy = dyn_cast<FloatType>(newElementType);
  if (floatTy) {
    // TODO: Consider saturating when floatTy has no infinity:
    //       saturate=APFloat::getInf(floatTy.getFloatSemantics()).isNaN()
    return castToFPElementType(elms, floatTy);
  }

  auto intTy = dyn_cast<IntegerType>(newElementType);
  if (intTy)
    return castToIntElementType(elms, intTy);

  llvm_unreachable("unsupported newElementType");
}

namespace {
using ElementsTransformer = std::function<void(llvm::MutableArrayRef<WideNum>)>;

ElementsTransformer composeTransforms(
ElementsTransformer first, ElementsTransformer second) {
if (first == nullptr)
return second;
else
return [fst = std::move(first), snd = std::move(second)](
MutableArrayRef<WideNum> dst) {
fst(dst);
snd(dst);
};
// Converts the double `from` to the integer type TO, saturating to the range
// [min, max]:
//
// - NaN converts to zero, like llvm::APFloat::convertToInteger().
// - Numbers below min become min; numbers at or above max become max.
// - In-range numbers are truncated towards zero if TRUNCATE is true,
//   otherwise they are rounded to nearest, ties to even (llrint in the
//   default floating point rounding mode).
//
// TO must be an integer type with bitwidth <= 64. min and max may describe a
// range narrower than TO's own numeric range.
//
// TODO: consider making it configurable whether to convert NaN to
//       number farthest from zero (like X86 SSE)
//       or just highest bit set (like CUDA) or zero
//
// TODO: optimize w/X86 SSE instructions https://stackoverflow.com/a/47347224
//
template <bool TRUNCATE, typename TO>
TO convertIntFromDouble(double from, TO min, TO max) {
  if (std::isnan(from))
    return 0;
  if (from < static_cast<double>(min))
    return min;
  // static_cast<double>(max) can round to a larger number (e.g. 2^63-1
  // becomes 2^63), so return max if from is greater or equal, not just if
  // greater.
  if (from >= static_cast<double>(max))
    return max;

  if constexpr (TRUNCATE) {
    return static_cast<TO>(from);
  } else if constexpr (std::is_same_v<TO, uint64_t>) {
    // llrint recommendation: https://stackoverflow.com/a/47347224
    // rounds to nearest, ties to even, in the default rounding mode.
    using llrintType = decltype(std::llrint(from));
    static_assert(
        sizeof(llrintType) >= sizeof(TO), "insufficient llrint range");
    // llrint returns a signed 64 bit number which doesn't cover the numeric
    // range of uint64_t, so we break the range into 2 halves:
    constexpr uint64_t mid = uint64_t(1) << 63; // middle of uint64_t range
    const double midAsDouble = static_cast<double>(mid); // exact: power of 2
    if (from < midAsDouble) {
      // from is inside llrint's numeric range [-2^63, 2^63)
      return static_cast<TO>(std::llrint(from));
    }
    // Subtract mid to translate into llrint's numeric range, then add it
    // back. The difference is non-negative here, so the cast is lossless.
    return mid + static_cast<uint64_t>(std::llrint(from - midAsDouble));
  } else {
    // llrint's result type covers the numeric range of TO, namely TO is
    // int64_t or a narrower signed or unsigned type.
    using llrintType = decltype(std::llrint(from));
    static_assert(sizeof(llrintType) > sizeof(TO) ||
                      (sizeof(llrintType) == sizeof(TO) &&
                          std::numeric_limits<TO>::is_signed),
        "insufficient llrint range");
    return static_cast<TO>(std::llrint(from));
  }
}

// Stateless functor type whose static eval() converts a SrcT value to DstT
// with static_cast semantics. Usable in constant expressions.
template <typename SrcT, typename DstT>
struct Caster {
  static constexpr DstT eval(SrcT src) {
    const DstT converted = static_cast<DstT>(src);
    return converted;
  }
};
// Builds a WideNum -> WideNum function that reads its argument as a double,
// converts it to TO with convertIntFromDouble<TRUNCATE, TO> using the given
// [min, max] bounds, and widens the result back into a WideNum tagged with
// TO's BType.
template <bool TRUNCATE, typename TO>
auto convertIntFromFP(TO min, TO max) {
  return [lo = min, hi = max](WideNum wide) -> WideNum {
    const TO converted = convertIntFromDouble<TRUNCATE, TO>(
        wide.narrow<BType::DOUBLE>(), lo, hi);
    return WideNum::widen<toBType<TO>>(converted);
  };
}

template <typename SrcT, typename DstT>
using WideCaster = WideNumWrappedFunction<Caster<SrcT, DstT>>;

auto wideCaster(BType src, BType dst) -> WideNum (*)(WideNum) {
constexpr BType DBL = BType::DOUBLE, I64 = BType::INT64, U64 = BType::UINT64;
// clang-format off
if (src == DBL && dst == I64) return WideCaster<double, int64_t>::eval;
if (src == DBL && dst == U64) return WideCaster<double, uint64_t>::eval;
if (src == I64 && dst == DBL) return WideCaster<int64_t, double>::eval;
if (src == I64 && dst == U64) return WideCaster<int64_t, uint64_t>::eval;
if (src == U64 && dst == DBL) return WideCaster<uint64_t, double>::eval;
if (src == U64 && dst == I64) return WideCaster<uint64_t, int64_t>::eval;
// clang-format on
llvm_unreachable("wideCaster must be called with 2 different wide types");
// Reads n as a FROM value and returns a bool WideNum that is true iff the
// value compares unequal to zero.
template <typename FROM>
WideNum isWideNonZero(WideNum n) {
  const bool nonZero = n.narrow<toBType<FROM>>() != 0;
  return WideNum::widen<BType::BOOL>(nonZero);
}

// Reads n as a FROM value, converts it to TO with static_cast, and returns
// the result as a WideNum tagged with TO's BType.
template <typename TO, typename FROM>
WideNum wideCast(WideNum n) {
  return WideNum::widen<toBType<TO>>(
      static_cast<TO>(n.narrow<toBType<FROM>>()));
}

// Reads n as a FROM value and converts it to double with static_cast.
template <typename FROM>
double wideToDouble(WideNum n) {
  return static_cast<double>(n.narrow<toBType<FROM>>());
}

} // namespace

ElementsAttr ElementsAttrBuilder::castElementType(
ElementsAttr elms, Type newElementType) {
// Casts the elements of elms to the integer type newElementType.
//
// - Same element type: elms is returned unchanged.
// - Target is i1 (bool): zero casts to 0, everything else casts to 1.
// - Source is floating point: numbers are saturated to the target's numeric
//   range and either rounded (round==true: nearest, ties to even) or
//   truncated towards zero (round==false).
// - Source is integer: the numeric range is not truncated here; a transform
//   is only added when the signedness differs, otherwise the buffer and the
//   existing transformer are reused as they are.
//
// Any other source element type is a programming error.
ElementsAttr ElementsAttrBuilder::castToIntElementType(
    ElementsAttr elms, IntegerType newElementType, bool round) {
  Type oldElementType = elms.getElementType();
  if (newElementType == oldElementType)
    return elms;

  Transformer transformer;
  if (newElementType.isInteger(1)) {
    // Bool: +/-zero cast to 0, everything else including NaN cast to 1.
    transformer = wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
      using cpptype = decltype(wideZero);
      return functionTransformer(isWideNonZero<cpptype>);
    });
  } else if (isa<FloatType>(oldElementType)) {
    // Names for the TRUNCATE template argument of convertIntFromFP.
    constexpr bool ROUND = false, TRUNCATE = true;
    // The shifts below assume 1 <= width <= 64; width 1 was handled above.
    unsigned width = newElementType.getWidth();
    if (newElementType.isUnsigned()) {
      uint64_t min = 0;
      uint64_t max = std::numeric_limits<uint64_t>::max() >> (64 - width);
      transformer = round ? functionTransformer(
                                convertIntFromFP<ROUND, uint64_t>(min, max))
                          : functionTransformer(
                                convertIntFromFP<TRUNCATE, uint64_t>(min, max));
    } else {
      // NOTE(review): relies on >> of a negative number being an arithmetic
      // shift, which is implementation-defined before C++20.
      int64_t min = std::numeric_limits<int64_t>::min() >> (64 - width);
      int64_t max = std::numeric_limits<int64_t>::max() >> (64 - width);
      transformer = round ? functionTransformer(
                                convertIntFromFP<ROUND, int64_t>(min, max))
                          : functionTransformer(
                                convertIntFromFP<TRUNCATE, int64_t>(min, max));
    }
  } else if (isa<IntegerType>(oldElementType)) {
    // We assume that casts to other integer types don't intend to truncate the
    // numeric range and we delay any truncation until the data is read and
    // allow the untruncated numbers as inputs to any further transformations.
    //
    // TODO: Add configuration options to support other behaviors.
    //       See https://github.com/onnx/onnx-mlir/issues/2209
    if (newElementType.isUnsigned() != oldElementType.isUnsignedInteger()) {
      // DisposableElementsAttr requires transformation between integers with
      // different signs.
      // TODO: Consider relaxing the requirement and omit this transformation.
      transformer = newElementType.isUnsigned()
                        ? functionTransformer(wideCast<uint64_t, int64_t>)
                        : functionTransformer(wideCast<int64_t, uint64_t>);
    } else {
      // Same signedness: only the element type changes, so reuse the buffer
      // and the existing transformer without adding a new transform.
      ElementsProperties props = getElementsProperties(elms);
      ShapedType newType = elms.getShapedType().clone(newElementType);
      return create(newType, props.bufferBType, props.strides, props.buffer,
          props.transformer);
    }
  } else {
    llvm_unreachable("unsupported element type");
  }
  return doTransform(elms, newElementType, transformer);
}

ShapedType newType = elms.getShapedType().clone(newElementType);
BType newBType = btypeOfMlirType(newElementType);
BType oldBType = btypeOfMlirType(oldElementType);
BType newWideType = wideBTypeOfBType(newBType);
BType oldWideType = wideBTypeOfBType(oldBType);

auto transformer =
oldWideType == newWideType
? props.transformer
: composeTransforms(props.transformer,
functionTransformer(wideCaster(oldWideType, newWideType)));
return create(newType, props.bufferBType, props.strides, props.buffer,
std::move(transformer));
// Casts the elements of elms to the floating point type newElementType.
//
// - Same element type: elms is returned unchanged.
// - saturate==true: every number is converted to double and clipped to
//   [-max, max] where max is the largest finite number of newElementType;
//   NaN stays NaN.
// - saturate==false and integer source: numbers are converted to double.
// - saturate==false and floating point source: the buffer and the existing
//   transformer are reused as they are.
ElementsAttr ElementsAttrBuilder::castToFPElementType(
    ElementsAttr elms, FloatType newElementType, bool saturate) {
  Type srcElementType = elms.getElementType();
  if (newElementType == srcElementType)
    return elms;

  return wideZeroDispatchNonBool(srcElementType, [&](auto wideZero) {
    using cpptype = decltype(wideZero);

    if (!saturate) {
      if constexpr (std::is_integral_v<cpptype>) {
        Transformer intToDouble = functionTransformer([](WideNum n) {
          return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
        });
        return doTransform(elms, newElementType, intToDouble);
      } else {
        ElementsProperties props = getElementsProperties(elms);
        ShapedType newType = elms.getShapedType().clone(newElementType);
        return create(newType, props.bufferBType, props.strides, props.buffer,
            props.transformer);
      }
    }

    // Saturating cast. Smallest is -max for all ONNX fp types.
    const double max = APFloat::getLargest(newElementType.getFloatSemantics())
                           .convertToDouble();
    // Saturation is implemented by clipping, which isn't 100% faithful to the
    // onnx spec here: https://onnx.ai/onnx/technical/float8.html
    // and here: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
    // For E4M3FNUZ and E5M2FNUZ the spec requires infinite values to saturate
    // to NaN, whereas clipping saturates them to smallest/largest. Clipping
    // matches the reference implementation in onnx/reference/ops/op_cast.py.
    // See https://github.com/onnx/onnx-mlir/issues/2369
    //
    // TODO: Change implementation to match the spec, or change the spec.
    Transformer clipToFinite = functionTransformer([max](WideNum n) {
      double d = wideToDouble<cpptype>(n);
      return WideNum::widen<BType::DOUBLE>(
          // Order of operations is important to ensure NaN stays NaN:
          d <= -max ? -max : (d >= max ? max : d));
    });
    return doTransform(elms, newElementType, clipToFinite);
  });
}

// Returns elms with every element clamped to the closed range [min, max] and
// the element type unchanged. Elements strictly below min become min,
// elements strictly above max become max, and all other elements are passed
// through as they are. The comparison is done in the C++ type that
// wideZeroDispatchNonBool selects for the element type (presumably the wide
// type matching the element type; confirm against that helper), and min and
// max are read with the same type tag.
ElementsAttr ElementsAttrBuilder::clip(
ElementsAttr elms, WideNum min, WideNum max) {
return wideZeroDispatchNonBool(elms.getElementType(), [&](auto wideZero) {
using cpptype = decltype(wideZero);
// NOTE(review): two return statements follow back to back, so the second one
// (the transform(...) call) is unreachable as written. This looks like the
// old and new variants of one statement; confirm which one is intended and
// delete the other.
return doTransform(
elms, elms.getElementType(), functionTransformer([min, max](WideNum n) {
constexpr BType TAG = toBType<cpptype>;
cpptype x = n.narrow<TAG>();
if (x < min.narrow<TAG>())
return min;
if (x > max.narrow<TAG>())
return max;
return n;
}));
return transform(elms, elms.getElementType(), [min, max](WideNum n) {
constexpr BType TAG = toBType<cpptype>;
cpptype x = n.narrow<TAG>();
if (x < min.narrow<TAG>())
return min;
if (x > max.narrow<TAG>())
return max;
return n;
});
});
}

Expand Down Expand Up @@ -983,6 +1116,22 @@ ArrayBuffer<WideNum> ElementsAttrBuilder::getWideNumsAndExpandedStrides(
};
}

namespace {
// In-place transformation of a buffer of wide numbers.
using ElementsTransformer = std::function<void(llvm::MutableArrayRef<WideNum>)>;

// Returns a transformer that applies first and then second to the same
// buffer. An empty first is treated as the identity, in which case second is
// returned as is.
ElementsTransformer composeTransforms(
    ElementsTransformer first, ElementsTransformer second) {
  if (!first)
    return second;

  return [head = std::move(first), tail = std::move(second)](
             MutableArrayRef<WideNum> dst) {
    head(dst);
    tail(dst);
  };
}
} // namespace

ElementsAttr ElementsAttrBuilder::doTransform(
ElementsAttr elms, Type transformedElementType, Transformer transformer) {
ShapedType transformedType =
Expand Down
23 changes: 22 additions & 1 deletion src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,33 @@ class ElementsAttrBuilder {
mlir::ElementsAttr where(mlir::ElementsAttr cond, mlir::ElementsAttr lhs,
mlir::ElementsAttr rhs, mlir::ShapedType combinedType);

// Returns an ElementsAttr with the elements cast to the given newElementType,
// which must be a floating point or integer type. Uses the default arguments
// of castToIntElementType (round=true) and castToFPElementType
// (saturate=false).
//
// Reuses elms' underlying data without a data copy.
mlir::ElementsAttr castElementType(
mlir::ElementsAttr elms, mlir::Type newElementType);

// Returns an ElementsAttr with the elements cast to the integer type
// newElementType.
//
// If elms has a floating point element type then out of range numbers
// saturate to the min/max of newElementType, NaN becomes zero, and in-range
// numbers are rounded to nearest integer, ties to even, if round==true,
// otherwise they are truncated towards zero.
//
// If newElementType is bool (i1) then zero is cast to false and everything
// else, including NaN, is cast to true; round is ignored.
//
// Reuses elms' underlying data without a data copy.
mlir::ElementsAttr castToIntElementType(mlir::ElementsAttr elms,
mlir::IntegerType newElementType, bool round = true);

// Returns an ElementsAttr with the elements cast to the floating point type
// newElementType.
//
// If saturate==true then out of range numbers, including +/-infinity, are
// clipped to the finite range [-max, max] of newElementType, where max is
// its largest finite number; NaN stays NaN. If saturate==false no clipping
// is done (out of range numbers presumably become +/-infinity when
// newElementType has infinities; confirm against where the data is read).
//
// Reuses elms' underlying data without a data copy.
mlir::ElementsAttr castToFPElementType(mlir::ElementsAttr elms,
mlir::FloatType newElementType, bool saturate = false);

// Returns an ElementsAttr with the values clipped to the range [min, max].
//
// Reuses elms' underlying data without a data copy.
Expand Down
2 changes: 1 addition & 1 deletion src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::unique_ptr<mlir::Pass> createConvOptONNXToONNXPass(
std::unique_ptr<mlir::Pass> createShapeInferencePass();

// To configure ConstPropONNXToONNXPass at program start.
void configureConstPropONNXToONNXPass(int expansionBound,
void configureConstPropONNXToONNXPass(bool roundFPToInt, int expansionBound,
llvm::ArrayRef<std::string> disabledPatterns, bool constantPropIsDisabled);

std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
Expand Down
Loading

0 comments on commit 0f021bc

Please sign in to comment.