Skip to content

Commit

Permalink
Parallelize compilation of ConstProp.
Browse files Browse the repository at this point in the history
Signed-off-by: Haruki Imai <[email protected]>
  • Loading branch information
imaihal committed Jan 14, 2025
1 parent 0183ad9 commit eb4a777
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 43 deletions.
115 changes: 77 additions & 38 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"

#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
Expand Down Expand Up @@ -187,18 +188,21 @@ ElementsAttr ElementsAttrBuilder::fromWideNums(
// demonstrates a speedup.
ElementsAttr ElementsAttrBuilder::combine(ElementsAttr lhs, ElementsAttr rhs,
ShapedType combinedType, WideNum (*combiner)(WideNum, WideNum)) {
MLIRContext *ctx = lhs.getElementType().getContext();
if (lhs.isSplat()) {
WideNum lhsNum = getElementsSplatWideNum(lhs);
return expandAndTransform(rhs, combinedType,
functionTransformer(
[lhsNum, combiner](WideNum n) { return combiner(lhsNum, n); }));
[lhsNum, combiner](WideNum n) { return combiner(lhsNum, n); },
ctx));
}

if (rhs.isSplat()) {
WideNum rhsNum = getElementsSplatWideNum(rhs);
return expandAndTransform(lhs, combinedType,
functionTransformer(
[rhsNum, combiner](WideNum n) { return combiner(n, rhsNum); }));
[rhsNum, combiner](WideNum n) { return combiner(n, rhsNum); },
ctx));
}

auto combinedShape = combinedType.getShape();
Expand Down Expand Up @@ -231,6 +235,7 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
assert(lhs.getElementType() == rhs.getElementType());
assert(lhs.getElementType() == combinedType.getElementType());

MLIRContext *ctx = lhs.getElementType().getContext();
if (cond.isSplat()) {
bool condBool = getElementsSplatWideNum(cond).u64;
return expand(condBool ? lhs : rhs, combinedType.getShape());
Expand All @@ -241,7 +246,8 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
WideNum rhsNum = getElementsSplatWideNum(rhs);
return expandAndTransform(cond, combinedType,
functionTransformer(
[lhsNum, rhsNum](WideNum n) { return n.u64 ? lhsNum : rhsNum; }));
[lhsNum, rhsNum](WideNum n) { return n.u64 ? lhsNum : rhsNum; },
ctx));
}

auto combinedShape = combinedType.getShape();
Expand Down Expand Up @@ -373,6 +379,7 @@ double wideToDouble(WideNum n) {
ElementsAttr ElementsAttrBuilder::castToIntElementType(
ElementsAttr elms, IntegerType newElementType, bool round) {
Type oldElementType = elms.getElementType();
MLIRContext *ctx = oldElementType.getContext();
if (newElementType == oldElementType)
return elms;

Expand All @@ -381,25 +388,27 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
// Bool: +/-zero cast to 0, everything else including NaN cast to 1.
transformer = wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
using cpptype = decltype(wideZero);
return functionTransformer(isWideNonZero<cpptype>);
return functionTransformer(isWideNonZero<cpptype>, ctx);
});
} else if (isa<FloatType>(oldElementType)) {
constexpr bool ROUND = false, TRUNCATE = true;
unsigned width = newElementType.getWidth();
if (newElementType.isUnsigned()) {
uint64_t min = 0;
uint64_t max = std::numeric_limits<uint64_t>::max() >> (64 - width);
transformer = round ? functionTransformer(
convertIntFromFP<ROUND, uint64_t>(min, max))
: functionTransformer(
convertIntFromFP<TRUNCATE, uint64_t>(min, max));
transformer =
round ? functionTransformer(
convertIntFromFP<ROUND, uint64_t>(min, max), ctx)
: functionTransformer(
convertIntFromFP<TRUNCATE, uint64_t>(min, max), ctx);
} else {
int64_t min = std::numeric_limits<int64_t>::min() >> (64 - width);
int64_t max = std::numeric_limits<int64_t>::max() >> (64 - width);
transformer = round ? functionTransformer(
convertIntFromFP<ROUND, int64_t>(min, max))
: functionTransformer(
convertIntFromFP<TRUNCATE, int64_t>(min, max));
transformer =
round ? functionTransformer(
convertIntFromFP<ROUND, int64_t>(min, max), ctx)
: functionTransformer(
convertIntFromFP<TRUNCATE, int64_t>(min, max), ctx);
}
} else if (isa<IntegerType>(oldElementType)) {
// We assume that casts to other integer types don't intend to truncate the
Expand All @@ -413,8 +422,8 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
// different signs.
// TODO: Consider relaxing the requirement and omit this transformation.
transformer = newElementType.isUnsigned()
? functionTransformer(wideCast<uint64_t, int64_t>)
: functionTransformer(wideCast<int64_t, uint64_t>);
? functionTransformer(wideCast<uint64_t, int64_t>, ctx)
: functionTransformer(wideCast<int64_t, uint64_t>, ctx);
} else {
ElementsProperties props = getElementsProperties(elms);
ShapedType newType = elms.getShapedType().clone(newElementType);
Expand All @@ -433,6 +442,7 @@ ElementsAttr ElementsAttrBuilder::castToFPElementType(
if (newElementType == oldElementType)
return elms;

MLIRContext *ctx = oldElementType.getContext();
return wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
using cpptype = decltype(wideZero);
Transformer transformer;
Expand All @@ -450,16 +460,20 @@ ElementsAttr ElementsAttrBuilder::castToFPElementType(
// See https://github.com/onnx/onnx-mlir/issues/2369
//
// TODO: Change implementation to match the spec, or change the spec.
transformer = functionTransformer([max](WideNum n) {
double d = wideToDouble<cpptype>(n);
return WideNum::widen<BType::DOUBLE>(
// Order of operations is important to ensure NaN stays NaN:
d <= -max ? -max : (d >= max ? max : d));
});
transformer = functionTransformer(
[max](WideNum n) {
double d = wideToDouble<cpptype>(n);
return WideNum::widen<BType::DOUBLE>(
// Order of operations is important to ensure NaN stays NaN:
d <= -max ? -max : (d >= max ? max : d));
},
ctx);
} else if constexpr (std::is_integral_v<cpptype>) {
transformer = functionTransformer([](WideNum n) {
return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
});
transformer = functionTransformer(
[](WideNum n) {
return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
},
ctx);
} else {
ElementsProperties props = getElementsProperties(elms);
ShapedType newType = elms.getShapedType().clone(newElementType);
Expand Down Expand Up @@ -849,6 +863,8 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
if (axes.empty())
return elms;

Type elementType = elms.getElementType();
MLIRContext *ctx = elementType.getContext();
SmallVector<unsigned, 4> sortedAxes(axes);
std::sort(sortedAxes.begin(), sortedAxes.end());
assert(
Expand Down Expand Up @@ -885,22 +901,45 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,

ShapedType reducedType = type.clone(reducedShape);
return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
// Traverse and populate each element d in dstNums.
for (auto &idxoffs : StridesRange<1>(reducedShape, {reducedStrides})) {
WideNum &d = dstNums[idxoffs.flattenedIndex];
int64_t srcPos = idxoffs[0];
// Traverse all the elements that reduce together into d.
// srcNums elements may be repeated if there are zeros in axesStrides.
StridesRange<1> axesRange(axesShape, {axesStrides});
auto axesIter = axesRange.begin();
auto axesEnd = axesRange.end();
assert(axesIter->at(0) == 0 && "initial src offset must be zero");
d = srcNums.get()[srcPos];
while (++axesIter != axesEnd) {
int64_t srcOffset = axesIter->at(0);
d = reducer(d, srcNums.get()[srcPos + srcOffset]);
StridesRange<1> sRange(reducedShape, {reducedStrides});
SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
for (auto &idxoffs : sRange)
batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));

std::mutex mtx;
size_t beginOffset = 0;
auto fetchBatch = [&](size_t threadNumber) {
const std::lock_guard<std::mutex> lock(mtx);
size_t batchSize = batch.size() / ctx->getNumThreads();
size_t batchSizeMod = batch.size() % ctx->getNumThreads();
if (threadNumber < batchSizeMod)
batchSize += 1;
auto batchBegin = batch.begin() + beginOffset;
auto batchEnd = batchBegin + batchSize;
beginOffset += batchSize;
return llvm::make_range(batchBegin, batchEnd);
};

auto work = [&](size_t threadNumber) {
auto batch = fetchBatch(threadNumber);
// Traverse and populate each element d in dstNums.
for (auto b : batch) {
WideNum &d = dstNums[b.first];
int64_t srcPos = b.second;
// Traverse all the elements that reduce together into d.
// srcNums elements may be repeated if there are zeros in axesStrides.
StridesRange<1> axesRange(axesShape, {axesStrides});
auto axesIter = axesRange.begin();
auto axesEnd = axesRange.end();
assert(axesIter->at(0) == 0 && "initial src offset must be zero");
d = srcNums.get()[srcPos];
while (++axesIter != axesEnd) {
int64_t srcOffset = axesIter->at(0);
d = reducer(d, srcNums.get()[srcPos + srcOffset]);
}
}
}
};
parallelFor(ctx, 0, ctx->getNumThreads(), work);
});
}

Expand Down
37 changes: 32 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H
#define ONNX_MLIR_ELEM_ATTR_BUILDER_H

#include "mlir/IR/Threading.h"

#include "src/Dialect/ONNX/ElementsAttr/BType.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp"
Expand Down Expand Up @@ -91,8 +93,9 @@ class ElementsAttrBuilder {
// Returns the ElementsAttr that doTransform produces for elms and
// transformedElementType, using a transformer that replaces every element
// with fun(element).
//
// The MLIRContext owning elms' element type is passed to functionTransformer
// together with fun.
template <typename Function = WideNum (*)(WideNum)>
mlir::ElementsAttr transform(mlir::ElementsAttr elms,
    mlir::Type transformedElementType, Function fun) {
  mlir::MLIRContext *ctx = elms.getElementType().getContext();
  return doTransform(
      elms, transformedElementType, functionTransformer(std::move(fun), ctx));
}

// Returns an ElementsAttr that is the result of applying a binary function
Expand Down Expand Up @@ -244,10 +247,34 @@ class ElementsAttrBuilder {
// Constructs a transformer that changes every element to the result of
// applying the given function to the element.
//
// When ctx is non-null and reports more than one thread, the elements are
// split into ctx->getNumThreads() contiguous chunks that are handed to
// mlir::parallelFor, so fun may be invoked from several threads at once and
// must be safe to call that way. Otherwise (ctx is null, or a single thread is
// reported) the elements are processed in order on the calling thread.
template <typename Function = WideNum (*)(WideNum)>
static inline Transformer functionTransformer(
    Function fun, mlir::MLIRContext *ctx = nullptr) {
  return [fun = std::move(fun), ctx](
             llvm::MutableArrayRef<WideNum> data) -> void {
    // Read the thread count once so every chunk is sized from the same value.
    const size_t numThreads = ctx ? ctx->getNumThreads() : 1;
    if (numThreads <= 1) {
      for (WideNum &n : data)
        n = fun(n);
      return;
    }

    // Split data into numThreads contiguous chunks; the first `extra` chunks
    // hold one more element so that the chunks cover data exactly.
    const size_t chunkSize = data.size() / numThreads;
    const size_t extra = data.size() % numThreads;
    mlir::parallelFor(ctx, 0, numThreads, [&](size_t chunk) {
      // Chunk bounds depend only on the chunk index, so no lock or shared
      // running offset is needed to hand out the work.
      const size_t begin = chunk * chunkSize + (chunk < extra ? chunk : extra);
      const size_t length = chunkSize + (chunk < extra ? 1 : 0);
      for (WideNum &n : data.slice(begin, length))
        n = fun(n);
    });
  };
}

Expand Down

0 comments on commit eb4a777

Please sign in to comment.