Skip to content

Commit

Permalink
Remove stash_type check for onnx.LayerNormalization lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 committed Nov 22, 2024
1 parent 1b8d7e0 commit 1656027
Showing 1 changed file with 0 additions and 20 deletions.
20 changes: 0 additions & 20 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2543,27 +2543,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.s64IntegerAttr(stashType, "stash_type", 1))
return failure();

// Since the support for `stash_type` arg does not exist in
// the torch op so we just check for the stash_type to be same
// as the input dtype since that won't require us to do any
// input type conversion and hence can be supported.
auto xType = cast<Torch::ValueTensorType>(x.getType());
std::optional<int64_t> stashTypeIntTorch =
onnxDtypeIntToTorchDtypeInt(stashType);
if (!stashTypeIntTorch.has_value())
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given stash_type");

FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
(torch_upstream::ScalarType)stashTypeIntTorch.value());
if (failed(stashDtype))
return failure();
if (*stashDtype != xType.getOptionalDtype())
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: stash_type should be same "
"as the input dtype");

Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(epsilon));
Expand Down

0 comments on commit 1656027

Please sign in to comment.