From 75b208bcfd65825d3b8f2d51949e6cb35680a400 Mon Sep 17 00:00:00 2001
From: jinchen62
Date: Sun, 24 Nov 2024 02:11:57 -0800
Subject: [PATCH] Support stash_type attribute for onnx.LayerNormalization

---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp   | 81 ++++++++++++++-----
 .../Torch/Transforms/DecomposeComplexOps.cpp | 14 +++-
 2 files changed, 72 insertions(+), 23 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index b73f1100dcb7..77ff7495453d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2543,7 +2543,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.s64IntegerAttr(stashType, "stash_type", 1))
           return failure();
 
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+
+        // Convert dtype if stash_type is different from input dtype
         auto xType = cast<Torch::ValueTensorType>(x.getType());
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        if (*stashDtype != xType.getOptionalDtype()) {
+          auto newXType =
+              xType.getWithSizesAndDtype(xType.getOptionalSizes(), *stashDtype);
+          Value dtypeValue = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(),
+              rewriter.getI64IntegerAttr(stashTypeIntTorch.value()));
+          x = rewriter.create<Torch::AtenToDtypeOp>(
+              binder.getLoc(), newXType, x, /*dtype=*/dtypeValue,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
+        }
+
         Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(epsilon));
@@ -2566,33 +2593,43 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             normalized);
 
+        SmallVector<int64_t> reducedShape(rank, 1);
+        for (int64_t i = 0; i < axis; i++)
+          reducedShape[i] = xShape[i];
+        auto reducedType =
+            xType.getWithSizesAndDtype(reducedShape, *stashDtype);
+        auto y = rewriter.create<Torch::AtenNativeLayerNormOp>(
+            binder.getLoc(), yType, /*meanType=*/reducedType,
+            /*invStdDevType=*/reducedType, x, normalized_shape, scale, b,
+            constEpsilon);
+
         int64_t numResults = binder.op->getNumResults();
         if (numResults == 1) {
-          SmallVector<int64_t> reducedShape(rank, 1);
-          for (int64_t i = 0; i < axis; i++)
-            reducedShape[i] = xShape[i];
-          auto reducedType = xType.getWithSizesAndDtype(
-              reducedShape, xType.getOptionalDtype());
-          Value y = rewriter
-                        .create<Torch::AtenNativeLayerNormOp>(
-                            binder.getLoc(), yType, /*meanType=*/reducedType,
-                            /*invStdDevType=*/reducedType, x, normalized_shape,
-                            scale, b, constEpsilon)
-                        .getResult0();
-          rewriter.replaceOp(binder.op, y);
+          rewriter.replaceOp(binder.op, y.getResult0());
           return success();
         }
-        if (numResults == 3) {
-          if (binder.tensorResultTypeAtIndex(meanType, 1) ||
-              binder.tensorResultTypeAtIndex(invStdDevType, 2))
-            return failure();
-          rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
-              binder.op, yType, meanType, invStdDevType, x, normalized_shape,
-              scale, b, constEpsilon);
-          return success();
+
+        Value meanOutput = y.getResult1();
+        Value varOutput = y.getResult2();
+        // Convert meanType and varType back if stash_dtype is different
+        if (binder.tensorResultTypeAtIndex(meanType, 1) ||
+            binder.tensorResultTypeAtIndex(invStdDevType, 2))
+          return failure();
+        if (*stashDtype != meanType.getOptionalDtype()) {
+          Value constDtype = Torch::getDtypeIntValueForType(
+              rewriter, binder.getLoc(), meanType.getDtype());
+          meanOutput = rewriter.create<Torch::AtenToDtypeOp>(
+              binder.getLoc(), meanType, meanOutput, /*dtype=*/constDtype,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
+          varOutput = rewriter.create<Torch::AtenToDtypeOp>(
+              binder.getLoc(), invStdDevType, varOutput, /*dtype=*/constDtype,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
         }
-        return rewriter.notifyMatchFailure(
-            binder.op, "Unimplemented: expected either 1 or 3 results");
+        rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput});
+
+        return success();
       });
   patterns.onOp("LeakyRelu", 1,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6207e753ea4f..9ada951c699f 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -6635,7 +6635,7 @@ class DecomposeAtenNativeLayerNormOp
     Location loc = op.getLoc();
     auto context = op.getContext();
 
-    auto inputTy = cast<BaseTensorType>(op.getInput().getType());
+    auto inputTy = cast<ValueTensorType>(op.getInput().getType());
     if (!inputTy.hasSizes())
       return rewriter.notifyMatchFailure(
           op, "input tensor should have known sizes.");
@@ -6690,6 +6690,18 @@
         loc, inputTy, inputRsqrtVar, op.getInput());
     Value inputNormalized = rewriter.create<AtenMulTensorOp>(
         loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
+    // Convert resultType if dtype is different
+    auto resultTensorType =
+        dyn_cast<ValueTensorType>(op.getResult(0).getType());
+    if (inputTy.getDtype() != resultTensorType.getDtype()) {
+      Value dtypeValue = Torch::getDtypeIntValueForType(
+          rewriter, loc, resultTensorType.getDtype());
+      Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+      inputNormalized = rewriter.create<AtenToDtypeOp>(
+          loc, resultTensorType, inputNormalized,
+          /*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+          /*memory_format=*/none);
+    }
     Value out = rewriter.create<TensorStaticInfoCastOp>(
         loc, op.getResult(0).getType(), inputNormalized);
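
For context: ONNX's stash_type attribute selects the floating-point precision in which LayerNormalization computes (and, when requested, returns) the intermediate mean and inverse standard deviation; the default value 1 denotes float32. Since torch.aten.native_layer_norm has no equivalent argument, the lowering above casts the input to the stash dtype, runs the op, and casts the mean/inv-std-dev results back to their declared types. The standalone C++ sketch below illustrates that numeric contract only; it is plain C++ (no torch-mlir or ONNX APIs), with T standing in for the input/output dtype and StashT for the stash dtype, and double used in place of a wider float type.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative-only layer norm over a 1-D normalized axis: the mean/variance
// reduction runs in StashT (the stash_type precision), and the normalized
// value is cast back to T, mirroring the to-dtype round trip in the patch.
template <typename T, typename StashT>
void layerNorm(const std::vector<T> &x, const std::vector<T> &scale,
               const std::vector<T> &bias, double eps, std::vector<T> &y,
               StashT &mean, StashT &invStdDev) {
  const std::size_t n = x.size();
  StashT sum = 0;
  for (std::size_t i = 0; i < n; ++i)
    sum += static_cast<StashT>(x[i]); // upcast input, as the inserted cast does
  mean = sum / static_cast<StashT>(n);
  StashT sqSum = 0;
  for (std::size_t i = 0; i < n; ++i) {
    StashT d = static_cast<StashT>(x[i]) - mean;
    sqSum += d * d;
  }
  invStdDev = StashT(1) / std::sqrt(sqSum / static_cast<StashT>(n) +
                                    static_cast<StashT>(eps));
  y.resize(n);
  for (std::size_t i = 0; i < n; ++i) {
    StashT normalized = (static_cast<StashT>(x[i]) - mean) * invStdDev;
    // Cast back to the declared output dtype before scale/bias, matching
    // the dtype conversion added to the decomposition.
    y[i] = static_cast<T>(normalized) * scale[i] + bias[i];
  }
}

int main() {
  std::vector<float> x{1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> scale(4, 1.0f), bias(4, 0.0f), y;
  double mean, invStdDev; // stash precision: double stands in for float32
  layerNorm<float, double>(x, scale, bias, 1e-5, y, mean, invStdDev);
  std::printf("mean=%f invStdDev=%f y0=%f\n", mean, invStdDev, y[0]);
  return 0;
}

Note that the stashed mean and invStdDev stay in StashT; the patch likewise casts the second and third results back only when the declared ONNX result types differ from the stash dtype.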