-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Symbolic shape expressions in TCP dialect (#78)
Support for TorchDynamo captured symbolic shape expressions was added to Torch-MLIR recently (llvm/torch-mlir#3372). This PR continues the work to lower these to TCP. - [x] TCP op definitions, custom printer/parser/verifier - [x] Dialect and conversion lit tests - [x] Custom op python lit tests - [x] Cleanup pass to remove bind ops
- Loading branch information
1 parent
0bfd510
commit e6eaf43
Showing
15 changed files
with
508 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include <memory> | ||
|
||
namespace mlir::tcp { | ||
|
||
std::unique_ptr<mlir::OperationPass<func::FuncOp>> | ||
createDropSymbolicShapeOpsPass(); | ||
|
||
} // namespace mlir::tcp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h" | ||
|
||
#include "mlir-tcp/Dialect/IR/TcpDialect.h" | ||
#include "mlir-tcp/Dialect/IR/TcpOps.h" | ||
#include "mlir-tcp/Dialect/Transforms/Passes.h" | ||
|
||
#include "./PassDetail.h" | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace mlir::tcp { | ||
|
||
namespace { | ||
|
||
class RemoveBindSymbolicShapeOps | ||
: public OpRewritePattern<tcp::BindSymbolicShapeOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(tcp::BindSymbolicShapeOp op, | ||
PatternRewriter &rewriter) const override { | ||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
}; | ||
|
||
class DropSymbolicShapeOpsPass | ||
: public DropSymbolicShapeOpsBase<DropSymbolicShapeOpsPass> { | ||
void runOnOperation() override { | ||
Operation *op = getOperation(); | ||
MLIRContext *context = op->getContext(); | ||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<RemoveBindSymbolicShapeOps>(context); | ||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> createDropSymbolicShapeOpsPass() { | ||
return std::make_unique<DropSymbolicShapeOpsPass>(); | ||
} | ||
|
||
} // namespace mlir::tcp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.