diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
index 63d2d3f0e9..def9e25eb5 100644
--- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
+++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
@@ -11,6 +11,7 @@
 #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"
 
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Threading.h"
 #include "llvm/ADT/STLExtras.h"
 
 #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -187,18 +188,21 @@ ElementsAttr ElementsAttrBuilder::fromWideNums(
 // demonstrates a speedup.
 ElementsAttr ElementsAttrBuilder::combine(ElementsAttr lhs, ElementsAttr rhs,
     ShapedType combinedType, WideNum (*combiner)(WideNum, WideNum)) {
+  MLIRContext *ctx = lhs.getElementType().getContext();
   if (lhs.isSplat()) {
     WideNum lhsNum = getElementsSplatWideNum(lhs);
     return expandAndTransform(rhs, combinedType,
         functionTransformer(
-            [lhsNum, combiner](WideNum n) { return combiner(lhsNum, n); }));
+            [lhsNum, combiner](WideNum n) { return combiner(lhsNum, n); },
+            ctx));
   }
 
   if (rhs.isSplat()) {
     WideNum rhsNum = getElementsSplatWideNum(rhs);
     return expandAndTransform(lhs, combinedType,
         functionTransformer(
-            [rhsNum, combiner](WideNum n) { return combiner(n, rhsNum); }));
+            [rhsNum, combiner](WideNum n) { return combiner(n, rhsNum); },
+            ctx));
   }
 
   auto combinedShape = combinedType.getShape();
@@ -231,6 +235,7 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
   assert(lhs.getElementType() == rhs.getElementType());
   assert(lhs.getElementType() == combinedType.getElementType());
+  MLIRContext *ctx = lhs.getElementType().getContext();
   if (cond.isSplat()) {
     bool condBool = getElementsSplatWideNum(cond).u64;
     return expand(condBool ? lhs : rhs, combinedType.getShape());
   }
@@ -241,7 +246,8 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
     WideNum rhsNum = getElementsSplatWideNum(rhs);
     return expandAndTransform(cond, combinedType,
         functionTransformer(
-            [lhsNum, rhsNum](WideNum n) { return n.u64 ? lhsNum : rhsNum; }));
+            [lhsNum, rhsNum](WideNum n) { return n.u64 ? lhsNum : rhsNum; },
+            ctx));
   }
 
   auto combinedShape = combinedType.getShape();
@@ -373,6 +379,7 @@ double wideToDouble(WideNum n) {
 ElementsAttr ElementsAttrBuilder::castToIntElementType(
     ElementsAttr elms, IntegerType newElementType, bool round) {
   Type oldElementType = elms.getElementType();
+  MLIRContext *ctx = oldElementType.getContext();
   if (newElementType == oldElementType)
     return elms;
 
@@ -381,7 +388,7 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
     // Bool: +/-zero cast to 0, everything else including NaN cast to 1.
     transformer = wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
       using cpptype = decltype(wideZero);
-      return functionTransformer(isWideNonZero<cpptype>);
+      return functionTransformer(isWideNonZero<cpptype>, ctx);
     });
   } else if (isa<FloatType>(oldElementType)) {
     constexpr bool ROUND = false, TRUNCATE = true;
@@ -389,17 +396,19 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
     if (newElementType.isUnsigned()) {
       uint64_t min = 0;
       uint64_t max = std::numeric_limits<uint64_t>::max() >> (64 - width);
-      transformer = round ? functionTransformer(
-                                convertIntFromFP<ROUND, uint64_t>(min, max))
-                          : functionTransformer(
-                                convertIntFromFP<TRUNCATE, uint64_t>(min, max));
+      transformer =
+          round ? functionTransformer(
+                      convertIntFromFP<ROUND, uint64_t>(min, max), ctx)
+                : functionTransformer(
+                      convertIntFromFP<TRUNCATE, uint64_t>(min, max), ctx);
     } else {
       int64_t min = std::numeric_limits<int64_t>::min() >> (64 - width);
       int64_t max = std::numeric_limits<int64_t>::max() >> (64 - width);
-      transformer = round ? functionTransformer(
-                                convertIntFromFP<ROUND, int64_t>(min, max))
-                          : functionTransformer(
-                                convertIntFromFP<TRUNCATE, int64_t>(min, max));
+      transformer =
+          round ? functionTransformer(
+                      convertIntFromFP<ROUND, int64_t>(min, max), ctx)
+                : functionTransformer(
+                      convertIntFromFP<TRUNCATE, int64_t>(min, max), ctx);
     }
   } else if (isa<IntegerType>(oldElementType)) {
     // We assume that casts to other integer types don't intend to truncate the
@@ -413,8 +422,8 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
     // different signs.
     // TODO: Consider relaxing the requirement and omit this transformation.
     transformer = newElementType.isUnsigned()
-                      ? functionTransformer(wideCast<uint64_t, int64_t>)
-                      : functionTransformer(wideCast<int64_t, uint64_t>);
+                      ? functionTransformer(wideCast<uint64_t, int64_t>, ctx)
+                      : functionTransformer(wideCast<int64_t, uint64_t>, ctx);
   } else {
     ElementsProperties props = getElementsProperties(elms);
     ShapedType newType = elms.getShapedType().clone(newElementType);
@@ -433,6 +442,7 @@ ElementsAttr ElementsAttrBuilder::castToFPElementType(
   if (newElementType == oldElementType)
     return elms;
 
+  MLIRContext *ctx = oldElementType.getContext();
   return wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
     using cpptype = decltype(wideZero);
     Transformer transformer;
@@ -450,16 +460,20 @@ ElementsAttr ElementsAttrBuilder::castToFPElementType(
       // See https://github.com/onnx/onnx-mlir/issues/2369
       //
      // TODO: Change implementation to match the spec, or change the spec.
-      transformer = functionTransformer([max](WideNum n) {
-        double d = wideToDouble<cpptype>(n);
-        return WideNum::widen<BType::DOUBLE>(
-            // Order of operations is important to ensure NaN stays NaN:
-            d <= -max ? -max : (d >= max ? max : d));
-      });
+      transformer = functionTransformer(
+          [max](WideNum n) {
+            double d = wideToDouble<cpptype>(n);
+            return WideNum::widen<BType::DOUBLE>(
+                // Order of operations is important to ensure NaN stays NaN:
+                d <= -max ? -max : (d >= max ? max : d));
+          },
+          ctx);
     } else if constexpr (std::is_integral_v<cpptype>) {
-      transformer = functionTransformer([](WideNum n) {
-        return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
-      });
+      transformer = functionTransformer(
+          [](WideNum n) {
+            return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
+          },
+          ctx);
     } else {
       ElementsProperties props = getElementsProperties(elms);
       ShapedType newType = elms.getShapedType().clone(newElementType);
@@ -849,6 +863,8 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
   if (axes.empty())
     return elms;
 
+  Type elementType = elms.getElementType();
+  MLIRContext *ctx = elementType.getContext();
   SmallVector<unsigned, 4> sortedAxes(axes);
   std::sort(sortedAxes.begin(), sortedAxes.end());
   assert(
@@ -885,22 +901,45 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
   ShapedType reducedType = type.clone(reducedShape);
   return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
-    // Traverse and populate each element d in dstNums.
-    for (auto &idxoffs : StridesRange<1>(reducedShape, {reducedStrides})) {
-      WideNum &d = dstNums[idxoffs.flattenedIndex];
-      int64_t srcPos = idxoffs[0];
-      // Traverse all the elements that reduce together into d.
-      // srcNums elements may be repeated if there are zeros in axesStrides.
-      StridesRange<1> axesRange(axesShape, {axesStrides});
-      auto axesIter = axesRange.begin();
-      auto axesEnd = axesRange.end();
-      assert(axesIter->at(0) == 0 && "initial src offset must be zero");
-      d = srcNums.get()[srcPos];
-      while (++axesIter != axesEnd) {
-        int64_t srcOffset = axesIter->at(0);
-        d = reducer(d, srcNums.get()[srcPos + srcOffset]);
+    StridesRange<1> sRange(reducedShape, {reducedStrides});
+    SmallVector<std::pair<size_t, int64_t>, 4> batch;
+    for (auto &idxoffs : sRange)
+      batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));
+
+    std::mutex mtx;
+    size_t beginOffset = 0;
+    auto fetchBatch = [&](size_t threadNumber) {
+      const std::lock_guard<std::mutex> lock(mtx);
+      size_t batchSize = batch.size() / ctx->getNumThreads();
+      size_t batchSizeMod = batch.size() % ctx->getNumThreads();
+      if (threadNumber < batchSizeMod)
+        batchSize += 1;
+      auto batchBegin = batch.begin() + beginOffset;
+      auto batchEnd = batchBegin + batchSize;
+      beginOffset += batchSize;
+      return llvm::make_range(batchBegin, batchEnd);
+    };
+
+    auto work = [&](size_t threadNumber) {
+      auto batch = fetchBatch(threadNumber);
+      // Traverse and populate each element d in dstNums.
+      for (auto b : batch) {
+        WideNum &d = dstNums[b.first];
+        int64_t srcPos = b.second;
+        // Traverse all the elements that reduce together into d.
+        // srcNums elements may be repeated if there are zeros in axesStrides.
+        StridesRange<1> axesRange(axesShape, {axesStrides});
+        auto axesIter = axesRange.begin();
+        auto axesEnd = axesRange.end();
+        assert(axesIter->at(0) == 0 && "initial src offset must be zero");
+        d = srcNums.get()[srcPos];
+        while (++axesIter != axesEnd) {
+          int64_t srcOffset = axesIter->at(0);
+          d = reducer(d, srcNums.get()[srcPos + srcOffset]);
+        }
       }
-    }
+    };
+    parallelFor(ctx, 0, ctx->getNumThreads(), work);
   });
 }
diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
index f7276b6ebb..55ae72a8ce 100644
--- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
+++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
@@ -11,6 +11,8 @@
 #ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H
 #define ONNX_MLIR_ELEM_ATTR_BUILDER_H
 
+#include "mlir/IR/Threading.h"
+
 #include "src/Dialect/ONNX/ElementsAttr/BType.hpp"
 #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
 #include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp"
@@ -91,8 +93,9 @@ class ElementsAttrBuilder {
   template <typename Function = WideNum (*)(WideNum)>
   mlir::ElementsAttr transform(mlir::ElementsAttr elms,
       mlir::Type transformedElementType, Function fun) {
+    mlir::MLIRContext *ctx = elms.getElementType().getContext();
     return doTransform(
-        elms, transformedElementType, functionTransformer(std::move(fun)));
+        elms, transformedElementType, functionTransformer(std::move(fun), ctx));
   }
 
   // Returns an ElementsAttr that is the result of applying a binary function
@@ -244,10 +247,31 @@ class ElementsAttrBuilder {
   // Constructs a transformer that changes every element to the result of
   // applying the given function to the element.
   template <typename Function = WideNum (*)(WideNum)>
-  static inline Transformer functionTransformer(Function fun) {
-    return [fun = std::move(fun)](llvm::MutableArrayRef<WideNum> data) -> void {
-      for (WideNum &n : data)
-        n = fun(n);
+  static inline Transformer functionTransformer(
+      Function fun, mlir::MLIRContext *ctx) {
+    return [fun = std::move(fun), ctx](
+               llvm::MutableArrayRef<WideNum> data) -> void {
+      std::mutex mtx;
+      size_t beginOffset = 0;
+      auto fetchBatch = [&](size_t threadNumber) {
+        const std::lock_guard<std::mutex> lock(mtx);
+        size_t batchSize = data.size() / ctx->getNumThreads();
+        size_t batchSizeMod = data.size() % ctx->getNumThreads();
+        if (threadNumber < batchSizeMod)
+          batchSize += 1;
+        auto batchBegin = data.begin() + beginOffset;
+        auto batchEnd = batchBegin + batchSize;
+        beginOffset += batchSize;
+        return llvm::make_range(batchBegin, batchEnd);
+      };
+
+      auto work = [&](size_t threadNumber) {
+        auto batch = fetchBatch(threadNumber);
+
+        for (WideNum &n : batch)
+          n = fun(n);
+      };
+      parallelFor(ctx, 0, ctx->getNumThreads(), work);
     };
   }
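
Note (illustrative sketch, not part of the patch): both fetchBatch lambdas in the patch split the work into one contiguous chunk per thread, size/numThreads elements each, with the first size % numThreads threads taking one extra element. The standalone program below demonstrates that partitioning with plain std::thread; numThreads stands in for ctx->getNumThreads() and doubling an int stands in for fun(n) or the reducer, so it has no MLIR dependency.

// Standalone sketch of the chunk partitioning used by the fetchBatch lambdas.
#include <cstddef>
#include <iostream>
#include <mutex>
#include <numeric>
#include <thread>
#include <utility>
#include <vector>

int main() {
  std::vector<int> data(10);
  std::iota(data.begin(), data.end(), 0); // 0, 1, ..., 9

  const size_t numThreads = 4; // stand-in for ctx->getNumThreads()
  std::mutex mtx;
  size_t beginOffset = 0;

  // Hands each caller the next contiguous chunk; the lock ensures beginOffset
  // is read and advanced atomically, so chunks never overlap.
  auto fetchBatch = [&](size_t threadNumber) {
    const std::lock_guard<std::mutex> lock(mtx);
    size_t batchSize = data.size() / numThreads;
    if (threadNumber < data.size() % numThreads)
      batchSize += 1;
    auto batchBegin = data.begin() + beginOffset;
    beginOffset += batchSize;
    return std::make_pair(batchBegin, batchBegin + batchSize);
  };

  auto work = [&](size_t threadNumber) {
    auto [batchBegin, batchEnd] = fetchBatch(threadNumber);
    for (auto it = batchBegin; it != batchEnd; ++it)
      *it *= 2; // placeholder for n = fun(n)
  };

  std::vector<std::thread> threads;
  for (size_t t = 0; t < numThreads; ++t)
    threads.emplace_back(work, t);
  for (std::thread &th : threads)
    th.join();

  for (int v : data)
    std::cout << v << ' '; // prints: 0 2 4 6 8 10 12 14 16 18
  std::cout << '\n';
}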