Skip to content

Commit a49d00d

Browse files
Layer normalization lowering to KRNL (#2561)
* Layer normalization lowering to KRNL Signed-off-by: Alexandre Eichenberger <[email protected]> --------- Signed-off-by: Alexandre Eichenberger <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent ba785f9 commit a49d00d

17 files changed

+479
-24
lines changed

docs/SupportedONNXOps-cpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 19. Limitatio
102102
| **LRN** |6 - * | | |
103103
| **LSTM** |7 - * | | |
104104
| **LabelEncoder** |none | | | |
105-
| **LayerNormalization** |none | | | |
105+
| **LayerNormalization** |17 - * | | |
106106
| **LeakyRelu** |6 - * | | |
107107
| **Less** |7 - * | | |
108108
| **LessOrEqual** |12 - * | | |

src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
273273
populateLoweringONNXUniqueOpPattern(patterns, typeConverter, ctx);
274274
// Neural network
275275
populateLoweringONNXConvOpPattern(patterns, typeConverter, ctx, enableParallel, opsForCall);
276-
populateLoweringONNXNormalizationOpPattern(patterns, typeConverter, ctx);
276+
populateLoweringONNXNormalizationOpPattern(patterns, typeConverter, ctx, enableSIMD);
277277
populateLoweringONNXPoolingOpPattern(patterns, typeConverter, ctx);
278278
// Recurrent neural network
279279
populateLoweringONNXGRUOpPattern(patterns, typeConverter, ctx);

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1396,7 +1396,7 @@ int64_t canBeVectorized(ShapeHelperType &shapeHelper, MDBuilder &create,
13961396
simdUnroll = 4;
13971397
else
13981398
simdUnroll = 8;
1399-
uVL = create.vec.SuitableUnrollFactor(vms, memRefType,
1399+
uVL = create.vec.computeSuitableUnrollFactor(vms, memRefType,
14001400
shapeHelper.getOutputDims(), collapsedInnermostLoops, simdUnroll,
14011401
/*canPad*/ true, estimatedSimdLoopTripCount);
14021402
LLVM_DEBUG({

src/Conversion/ONNXToKrnl/Math/Reduction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,8 @@ struct ONNXReductionOpLowering : public OpConversionPattern<ONNXReductionOp> {
544544
}
545545
LLVM_DEBUG(llvm::dbgs()
546546
<< " SIMD: study with init unroll " << unroll << "\n");
547-
VL = create.vec.SuitableUnrollFactor(vms, memRefInType, inputDims,
548-
innermostLoopCollapse, unroll, /*canPad*/ false,
547+
VL = create.vec.computeSuitableUnrollFactor(vms, memRefInType,
548+
inputDims, innermostLoopCollapse, unroll, /*canPad*/ false,
549549
estimatedSimdLoopTripCount);
550550
LLVM_DEBUG(llvm::dbgs() << " SIMD: " << innermostLoopCollapse
551551
<< " loops, VL " << VL << "\n");

src/Conversion/ONNXToKrnl/NN/Normalization.cpp

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,19 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
16+
#include "src/Dialect/Krnl/DialectBuilder.hpp"
17+
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
18+
19+
#define DEBUG_TYPE "lowering-to-krnl"
1620

1721
using namespace mlir;
1822

1923
namespace onnx_mlir {
2024

25+
//===----------------------------------------------------------------------===//
26+
// Batch Norm
27+
//===----------------------------------------------------------------------===//
28+
2129
struct ONNXBatchNormalizationInferenceModeOpLowering
2230
: public OpConversionPattern<ONNXBatchNormalizationInferenceModeOp> {
2331
ONNXBatchNormalizationInferenceModeOpLowering(
@@ -136,6 +144,10 @@ struct ONNXBatchNormalizationInferenceModeOpLowering
136144
}
137145
};
138146

147+
//===----------------------------------------------------------------------===//
148+
// Instance Normalization
149+
//===----------------------------------------------------------------------===//
150+
139151
struct ONNXInstanceNormalizationOpLowering
140152
: public OpConversionPattern<ONNXInstanceNormalizationOp> {
141153
ONNXInstanceNormalizationOpLowering(
@@ -285,11 +297,103 @@ struct ONNXInstanceNormalizationOpLowering
285297
}
286298
};
287299

300+
//===----------------------------------------------------------------------===//
301+
// Layer Normalization
302+
//===----------------------------------------------------------------------===//
303+
304+
using MDBuilder = MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl,
305+
MemRefBuilder, MathBuilder, VectorBuilder, OnnxBuilder>;
306+
307+
// Generate the original ONNX operations. This is the unoptimized path.
// TODO: conversions of types are not handled.
//
// Decomposes ONNXLayerNormalizationOp into primitive ONNX ops (reduceMean,
// mul, sub, add, sqrt, reciprocal) using the variance identity
//   var(X) = E[X^2] - E[X]^2,
// then replaces lnOp with (Y, mean?, invStdDev?). Optional outputs that the
// op does not define are replaced by a null Value.
//
// rewriter: conversion rewriter used to build the replacement ops.
// loc:      source location attached to all created ops.
// lnOp:     the layer-normalization op being decomposed.
// Returns success() unconditionally; lnOp is replaced in place.
LogicalResult generateONNXLayerNormalizationOpONNXCode(
    ConversionPatternRewriter &rewriter, Location loc,
    ONNXLayerNormalizationOp lnOp) {
  MDBuilder create(rewriter, loc);
  Value X = lnOp.getX(); // Original value, not translated.
  TensorType XType = X.getType().cast<TensorType>();
  Type elementType = XType.getElementType();
  int64_t XRank = XType.getRank();
  // Normalize a possibly-negative axis attribute into [0, XRank).
  int64_t axis = getAxisInRange(lnOp.getAxis(), XRank);
  // Get epsilon as a scalar constant of the input's element type.
  FloatAttr epsilonAttr = lnOp.getEpsilonAttr();
  DenseElementsAttr epsilonDenseAttr =
      onnx_mlir::createDenseElementsAttrFromFloatAttr(
          rewriter, elementType, epsilonAttr);
  Value epsilon = create.onnx.constant(epsilonDenseAttr);

  // Create reduction axes array. Dims [0, axis) are kept as-is; dims
  // [axis, XRank) are reduced (keepDims defaults to true in reduceMean,
  // hence the explicit 1s in reductionShape).
  llvm::SmallVector<int64_t, 4> axesIntArray, reductionShape;
  for (int64_t r = 0; r < axis; ++r)
    reductionShape.emplace_back(XType.getShape()[r]);
  for (int64_t r = axis; r < XRank; ++r) {
    reductionShape.emplace_back(1);
    axesIntArray.emplace_back(r);
  }
  Value axes =
      create.onnx.constant(create.getBuilder().getI64TensorAttr(axesIntArray));
  TensorType reductionType = RankedTensorType::get(reductionShape, elementType);
  // Reduction of input: mean = E[X], var = E[X^2] - (E[X])^2.
  Value meanOfX = create.onnx.reduceMean(reductionType, X, axes);
  Value pow2OfMeanOfX = create.onnx.mul(meanOfX, meanOfX);
  Value XPow2 = create.onnx.mul(X, X);
  Value meanOfXPow2 = create.onnx.reduceMean(reductionType, XPow2, axes);
  Value var = create.onnx.sub(meanOfXPow2, pow2OfMeanOfX);
  // invStdDev = 1 / sqrt(var + epsilon); epsilon guards against div-by-zero.
  Value varWithEpsilon = create.onnx.add(var, epsilon);
  Value stdDev = create.onnx.sqrt(varWithEpsilon);
  Value invStdDev = create.onnx.reciprocal(stdDev);
  // Y = (X - mean) * invStdDev * scale [+ bias].
  Value d = create.onnx.sub(X, meanOfX);
  Value normalized = create.onnx.mul(d, invStdDev);
  Value Y = create.onnx.mul(normalized, lnOp.getScale());
  if (!isNoneValue(lnOp.getB()))
    Y = create.onnx.add(Y, lnOp.getB());
  // Assemble the three results (Y, Mean, InvStdDev); unused optional results
  // are mapped to a null Value. NOTE(review): assumes replaceOp accepts null
  // replacement values for unused results — confirm against the MLIR version
  // this builds with.
  llvm::SmallVector<Value, 3> outputs;
  outputs.emplace_back(Y);
  Value noneValue;
  if (isNoneValue(lnOp.getMean()))
    outputs.emplace_back(noneValue);
  else
    outputs.emplace_back(meanOfX);
  if (isNoneValue(lnOp.getInvStdDev()))
    outputs.emplace_back(noneValue);
  else
    outputs.emplace_back(invStdDev);
  rewriter.replaceOp(lnOp, outputs);
  return success();
}
364+
365+
// Conversion pattern that lowers ONNXLayerNormalizationOp. Currently it only
// emits the generic ONNX decomposition (see
// generateONNXLayerNormalizationOpONNXCode); no direct Krnl/SIMD code is
// generated yet.
struct ONNXLayerNormalizationOpLowering
    : public OpConversionPattern<ONNXLayerNormalizationOp> {
  ONNXLayerNormalizationOpLowering(
      TypeConverter &typeConverter, MLIRContext *ctx, bool enableSIMD)
      : OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD) {}

  // Not read anywhere in this pattern yet; NOTE(review): presumably reserved
  // for a future SIMD lowering path — confirm before removing.
  bool enableSIMD;

  LogicalResult matchAndRewrite(ONNXLayerNormalizationOp lnOp,
      ONNXLayerNormalizationOpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    // Get generic info.
    Operation *op = lnOp.getOperation();
    ValueRange operands = adaptor.getOperands();
    Location loc = ONNXLoc<ONNXLayerNormalizationOp>(op);
    // Create builder and shape helper. The shape computation is performed for
    // its assertion side effect; its results are not otherwise used here.
    MDBuilder create(rewriter, loc);
    ONNXLayerNormalizationOpShapeHelper shapeHelper(
        op, operands, &create.krnlIE);
    shapeHelper.computeShapeAndAssertOnFailure();

    // Fall back to the unoptimized ONNX-level decomposition.
    return generateONNXLayerNormalizationOpONNXCode(rewriter, loc, lnOp);
  }
};
389+
288390
void populateLoweringONNXNormalizationOpPattern(RewritePatternSet &patterns,
289-
TypeConverter &typeConverter, MLIRContext *ctx) {
391+
TypeConverter &typeConverter, MLIRContext *ctx, bool enableSIMD) {
290392
patterns.insert<ONNXBatchNormalizationInferenceModeOpLowering>(
291393
typeConverter, ctx);
292394
patterns.insert<ONNXInstanceNormalizationOpLowering>(typeConverter, ctx);
395+
patterns.insert<ONNXLayerNormalizationOpLowering>(
396+
typeConverter, ctx, enableSIMD);
293397
}
294398

295399
} // namespace onnx_mlir

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,11 @@ void populateLoweringONNXCategoryMapperOpPattern(
341341
void populateLoweringONNXConvOpPattern(mlir::RewritePatternSet &,
342342
mlir::TypeConverter &, mlir::MLIRContext *, bool enableParallel,
343343
std::string opsForCall);
344-
void populateLoweringONNXNormalizationOpPattern(
345-
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
344+
mlir::LogicalResult generateONNXLayerNormalizationOpONNXCode(
345+
mlir::ConversionPatternRewriter &rewriter, mlir::Location loc,
346+
mlir::ONNXLayerNormalizationOp lnOp);
347+
void populateLoweringONNXNormalizationOpPattern(mlir::RewritePatternSet &,
348+
mlir::TypeConverter &, mlir::MLIRContext *, bool enableSIMD);
346349
void populateLoweringONNXPoolingOpPattern(
347350
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
348351

src/Dialect/Mlir/DialectBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1633,7 +1633,7 @@ void VectorBuilder::multiReduction(SmallVectorImpl<Value> &inputVecArray,
16331633
}
16341634
}
16351635

1636-
int64_t VectorBuilder::SuitableUnrollFactor(VectorMachineSupport *vms,
1636+
int64_t VectorBuilder::computeSuitableUnrollFactor(VectorMachineSupport *vms,
16371637
MemRefType memRefType, llvm::SmallVectorImpl<IndexExpr> &memRefDims,
16381638
int64_t collapsedInnermostLoops, int64_t maxSimdUnroll, bool canPad,
16391639
int64_t &estimatedSimdLoopTripCount) const {

src/Dialect/Mlir/DialectBuilder.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ struct VectorBuilder final : DialectBuilder {
440440
// estimatedSimdLoopTripCount: provide an estimation of the SIMD loop trip
441441
// count. If runtime, return -1; if cannot simdize, return 0; if compile time
442442
// (or a multiple of a compile time value): return that literal.
443-
int64_t SuitableUnrollFactor(VectorMachineSupport *vms,
443+
int64_t computeSuitableUnrollFactor(VectorMachineSupport *vms,
444444
mlir::MemRefType memRefType, llvm::SmallVectorImpl<IndexExpr> &memRefDims,
445445
int64_t collapsedInnermostLoops, int64_t maxSimdUnroll, bool canPad,
446446
int64_t &estimatedSimdLoopTripCount) const;

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ Value OnnxBuilder::reduceMax(Type outputType, Value data, Value axes,
196196
toTensor(data), toTensor(axes), i_keepDims, i_noop_with_empty_axes);
197197
}
198198

199+
// Build an ONNXReduceMeanOp over `data` along `axes`, producing `outputType`.
// keepDims / noop_with_empty_axes are forwarded as the 0/1 integer attributes
// the op builder expects.
Value OnnxBuilder::reduceMean(Type outputType, Value data, Value axes,
    bool keepDims, bool noop_with_empty_axes) const {
  return createTypedOpAndInferShapes<ONNXReduceMeanOp>(toTensor(outputType),
      toTensor(data), toTensor(axes), static_cast<int64_t>(keepDims),
      static_cast<int64_t>(noop_with_empty_axes));
}
206+
199207
Value OnnxBuilder::reduceMin(Type outputType, Value data, Value axes,
200208
bool keepDims, bool noop_with_empty_axes) const {
201209
int64_t i_keepDims = keepDims; // 0 if false, 1 if true
@@ -212,6 +220,12 @@ Value OnnxBuilder::reduceSum(Type outputType, Value data, Value axes,
212220
toTensor(data), toTensor(axes), i_keepDims, i_noop_with_empty_axes);
213221
}
214222

223+
// Build an ONNXReciprocalOp. Reciprocal is elementwise, so the result type is
// the input type.
Value OnnxBuilder::reciprocal(Value input) const {
  return createTypedOpAndInferShapes<ONNXReciprocalOp>(
      toTensor(input.getType()), toTensor(input));
}
228+
215229
Value OnnxBuilder::reshape(Type outputType, Value input, Value shape) const {
216230
return createTypedOpAndInferShapes<ONNXReshapeOp>(
217231
toTensor(outputType), toTensor(input), toTensor(shape));
@@ -261,6 +275,10 @@ Value OnnxBuilder::slice(Type outputType, Value input, int64_t start,
261275
return slice(outputType, input, startVal, endVal, /*axis*/ zeroVal, stepVal);
262276
}
263277

278+
// Build an ONNXSqrtOp on `input`; the output type is inferred by shape
// inference rather than passed explicitly.
Value OnnxBuilder::sqrt(Value input) const {
  return createOpAndInferShapes<ONNXSqrtOp>(toTensor(input));
}
281+
264282
ValueRange OnnxBuilder::split(
265283
TypeRange outputTypes, Value input, Value split, int64_t axis) const {
266284
IntegerAttr axisAttr =

src/Dialect/ONNX/DialectBuilder.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ struct OnnxBuilder : DialectBuilder {
103103
mlir::Value axes, bool keepDims = true,
104104
bool noop_with_empty_axes = false) const;
105105

106+
// ONNXReduceMeanOp
107+
mlir::Value reduceMean(mlir::Type outputType, mlir::Value data,
108+
mlir::Value axes, bool keepDims = true,
109+
bool noop_with_empty_axes = false) const;
110+
106111
// ONNXReduceMinOp
107112
mlir::Value reduceMin(mlir::Type outputType, mlir::Value data,
108113
mlir::Value axes, bool keepDims = true,
@@ -123,6 +128,9 @@ struct OnnxBuilder : DialectBuilder {
123128
mlir::Value reshapeToNDim(
124129
mlir::Value val, int64_t N, bool collapseMostSignificant) const;
125130

131+
// ONNXReciprocalOp
132+
mlir::Value reciprocal(mlir::Value input) const;
133+
126134
// ONNXReverseSequenceOp
127135
mlir::Value reverseSequence(mlir::Type outputType, mlir::Value input,
128136
mlir::Value sequenceLens, int64_t batchAxis, int64_t timeAxis) const;
@@ -140,6 +148,9 @@ struct OnnxBuilder : DialectBuilder {
140148
mlir::Value slice(mlir::Type outputType, mlir::Value input, int64_t start,
141149
int64_t end, int64_t step = 1) const; // 1D slice
142150

151+
// ONNXSqrtOp
152+
mlir::Value sqrt(mlir::Value input) const;
153+
143154
// ONNXSplitOp
144155
mlir::ValueRange split(mlir::TypeRange outputTypes, mlir::Value input,
145156
mlir::Value split, int64_t axis) const;

0 commit comments

Comments
 (0)