Support stash_type attribute for onnx.LayerNormalization
jinchen62 committed Nov 24, 2024
1 parent 878f992 commit 367e6d1
Showing 2 changed files with 72 additions and 23 deletions.
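For context: in the ONNX LayerNormalization spec, stash_type selects the floating-point type used to compute the normalization statistics and the type of the optional Mean and InvStdDev outputs; the default of 1 denotes float32. Below is a minimal standalone sketch of that idea, assuming float elements with double standing in for the wider stash type; it is an illustration only, not code from this repository.

```cpp
// Standalone sketch (not torch-mlir code) of the numeric idea behind
// stash_type: statistics are accumulated in a wider floating-point type than
// the stored elements. Here float stands in for fp16 and double for the
// default fp32 stash type.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  const double eps = 1e-5;

  // Mean and variance accumulated in the wider "stash" type.
  double mean = 0.0;
  for (float v : x)
    mean += v;
  mean /= x.size();
  double var = 0.0;
  for (float v : x)
    var += (v - mean) * (v - mean);
  var /= x.size();
  const double invStdDev = 1.0 / std::sqrt(var + eps);

  // Normalize, then cast back to the stored element type, mirroring the
  // AtenToDtypeOp casts introduced in this commit.
  for (float v : x)
    std::printf("%f ", static_cast<float>((v - mean) * invStdDev));
  std::printf("\nmean=%f invStdDev=%f\n", mean, invStdDev);
  return 0;
}
```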
81 changes: 59 additions & 22 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2543,7 +2543,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.s64IntegerAttr(stashType, "stash_type", 1))
return failure();

std::optional<int64_t> stashTypeIntTorch =
onnxDtypeIntToTorchDtypeInt(stashType);
if (!stashTypeIntTorch.has_value())
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given stash_type");
FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
(torch_upstream::ScalarType)stashTypeIntTorch.value());
if (failed(stashDtype))
return failure();

// Convert dtype if stash_type is different from input dtype
auto xType = cast<Torch::ValueTensorType>(x.getType());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
if (*stashDtype != xType.getOptionalDtype()) {
auto newXType =
xType.getWithSizesAndDtype(xType.getOptionalSizes(), *stashDtype);
Value dtypeValue = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(stashTypeIntTorch.value()));
x = rewriter.create<Torch::AtenToDtypeOp>(
binder.getLoc(), newXType, x, /*dtype=*/dtypeValue,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}

Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(epsilon));
@@ -2566,33 +2593,43 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
normalized);

SmallVector<int64_t> reducedShape(rank, 1);
for (int64_t i = 0; i < axis; i++)
reducedShape[i] = xShape[i];
auto reducedType =
xType.getWithSizesAndDtype(reducedShape, *stashDtype);
auto y = rewriter.create<Torch::AtenNativeLayerNormOp>(
binder.getLoc(), yType, /*meanType=*/reducedType,
/*invStdDevType=*/reducedType, x, normalized_shape,
scale, b, constEpsilon);

int64_t numResults = binder.op->getNumResults();
if (numResults == 1) {
SmallVector<int64_t> reducedShape(rank, 1);
for (int64_t i = 0; i < axis; i++)
reducedShape[i] = xShape[i];
auto reducedType = xType.getWithSizesAndDtype(
reducedShape, xType.getOptionalDtype());
Value y = rewriter
.create<Torch::AtenNativeLayerNormOp>(
binder.getLoc(), yType, /*meanType=*/reducedType,
/*invStdDevType=*/reducedType, x, normalized_shape,
scale, b, constEpsilon)
.getResult0();
rewriter.replaceOp(binder.op, y);
rewriter.replaceOp(binder.op, y.getResult0());
return success();
}
if (numResults == 3) {
if (binder.tensorResultTypeAtIndex(meanType, 1) ||
binder.tensorResultTypeAtIndex(invStdDevType, 2))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
binder.op, yType, meanType, invStdDevType, x, normalized_shape,
scale, b, constEpsilon);
return success();

Value meanOutput = y.getResult1();
Value varOutput = y.getResult2();
// Convert meanType and varType back if stash_dtype is different
if (binder.tensorResultTypeAtIndex(meanType, 1) ||
binder.tensorResultTypeAtIndex(invStdDevType, 2))
return failure();
if (*stashDtype != meanType.getOptionalDtype()) {
Value constDtype = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), meanType.getDtype());
meanOutput = rewriter.create<Torch::AtenToDtypeOp>(
binder.getLoc(), meanType, meanOutput, /*dtype=*/constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
varOutput = rewriter.create<Torch::AtenToDtypeOp>(
binder.getLoc(), invStdDevType, varOutput, /*dtype=*/constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
return rewriter.notifyMatchFailure(
binder.op, "Unimplemented: expected either 1 or 3 results");
rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput});

return success();
});
patterns.onOp("LeakyRelu", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
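Taken together, the hunks above map the ONNX stash_type code to a torch dtype, cast X to that dtype when it differs from the input element type, run aten.native_layer_norm with the mean and inverse standard deviation computed in the stash dtype, and cast those two outputs back to the declared result types when needed. For reference, the reduced statistics shape used for the Mean/InvStdDev results, pulled out as a standalone helper (the name reducedStatsShape is made up for this sketch):

```cpp
// Hypothetical standalone helper mirroring the reducedShape loop above:
// dimensions before `axis` are kept, dimensions from `axis` onward become 1,
// which is the shape ONNX LayerNormalization uses for Mean and InvStdDev.
#include <cstdint>
#include <vector>

std::vector<int64_t> reducedStatsShape(const std::vector<int64_t> &xShape,
                                       int64_t axis) {
  std::vector<int64_t> reduced(xShape.size(), 1);
  for (int64_t i = 0; i < axis; ++i)
    reduced[i] = xShape[i];
  return reduced;
}

// Example: reducedStatsShape({2, 4, 8}, /*axis=*/2) yields {2, 4, 1}.
```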
14 changes: 13 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -6635,7 +6635,7 @@ class DecomposeAtenNativeLayerNormOp
Location loc = op.getLoc();
auto context = op.getContext();

auto inputTy = cast<BaseTensorType>(op.getInput().getType());
auto inputTy = cast<ValueTensorType>(op.getInput().getType());
if (!inputTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "input tensor should have known sizes.");
@@ -6690,6 +6690,18 @@ class DecomposeAtenNativeLayerNormOp
loc, inputTy, inputRsqrtVar, op.getInput());
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
// Convert resultType if dtype is different
auto resultTensorType =
dyn_cast<ValueTensorType>(op.getResult(0).getType());
if (inputTy.getDtype() != resultTensorType.getDtype()) {
Value dtypeValue = Torch::getDtypeIntValueForType(
rewriter, loc, resultTensorType.getDtype());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
inputNormalized = rewriter.create<Torch::AtenToDtypeOp>(
loc, resultTensorType, inputNormalized,
/*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
Value out = rewriter.create<TensorStaticInfoCastOp>(
loc, op.getResult(0).getType(), inputNormalized);

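The DecomposeComplexOps change complements the importer change: once the converter promotes the input to the stash dtype, result 0 of AtenNativeLayerNormOp may still be declared with the original element type, so the decomposition now casts the normalized value to the declared result dtype before the static-info cast whenever the two differ. Reading the diff, for example, a float16 input with the default stash_type of 1 is normalized in float32, and a Torch::AtenToDtypeOp cast back to float16 is inserted for the normalized output.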
