|
| 1 | +//===- LowerWordToBits.cpp - Bit-Blasting Words to Bits ---------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This pass lowers multi-bit AIG operations to single-bit ones. |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "circt/Dialect/AIG/AIGOps.h" |
| 14 | +#include "circt/Dialect/AIG/AIGPasses.h" |
| 15 | +#include "circt/Dialect/Comb/CombOps.h" |
| 16 | +#include "circt/Dialect/HW/HWOps.h" |
| 17 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 18 | + |
| 19 | +#define DEBUG_TYPE "aig-lower-word-to-bits" |
| 20 | + |
| 21 | +namespace circt { |
| 22 | +namespace aig { |
| 23 | +#define GEN_PASS_DEF_LOWERWORDTOBITS |
| 24 | +#include "circt/Dialect/AIG/AIGPasses.h.inc" |
| 25 | +} // namespace aig |
| 26 | +} // namespace circt |
| 27 | + |
| 28 | +using namespace circt; |
| 29 | +using namespace aig; |
| 30 | + |
| 31 | +//===----------------------------------------------------------------------===// |
| 32 | +// Rewrite patterns |
| 33 | +//===----------------------------------------------------------------------===// |
| 34 | + |
| 35 | +namespace { |
| 36 | + |
| 37 | +struct WordRewritePattern : public OpRewritePattern<AndInverterOp> { |
| 38 | + using OpRewritePattern<AndInverterOp>::OpRewritePattern; |
| 39 | + |
| 40 | + LogicalResult matchAndRewrite(AndInverterOp op, |
| 41 | + PatternRewriter &rewriter) const override { |
| 42 | + auto width = op.getType().getIntOrFloatBitWidth(); |
| 43 | + if (width <= 1) |
| 44 | + return failure(); |
| 45 | + |
| 46 | + SmallVector<Value> results; |
| 47 | + // We iterate over the width in reverse order to match the endianness of |
| 48 | + // `comb.concat`. |
| 49 | + for (int64_t i = width - 1; i >= 0; --i) { |
| 50 | + SmallVector<Value> operands; |
| 51 | + for (auto operand : op.getOperands()) { |
| 52 | + // Reuse bits if we can extract from `comb.concat` operands. |
| 53 | + if (auto concat = operand.getDefiningOp<comb::ConcatOp>()) { |
| 54 | + // For the simplicity, we only handle the case where all the |
| 55 | + // `comb.concat` operands are single-bit. |
| 56 | + if (concat.getNumOperands() == width && |
| 57 | + llvm::all_of(concat.getOperandTypes(), [](Type type) { |
| 58 | + return type.getIntOrFloatBitWidth() == 1; |
| 59 | + })) { |
| 60 | + // Be careful with the endianness here. |
| 61 | + operands.push_back(concat.getOperand(width - i - 1)); |
| 62 | + continue; |
| 63 | + } |
| 64 | + } |
| 65 | + // Otherwise, we need to extract the bit. |
| 66 | + operands.push_back( |
| 67 | + rewriter.create<comb::ExtractOp>(op.getLoc(), operand, i, 1)); |
| 68 | + } |
| 69 | + results.push_back(rewriter.create<AndInverterOp>(op.getLoc(), operands, |
| 70 | + op.getInvertedAttr())); |
| 71 | + } |
| 72 | + |
| 73 | + rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results); |
| 74 | + return success(); |
| 75 | + } |
| 76 | +}; |
| 77 | + |
| 78 | +} // namespace |
| 79 | + |
| 80 | +//===----------------------------------------------------------------------===// |
| 81 | +// Lower Word to Bits pass |
| 82 | +//===----------------------------------------------------------------------===// |
| 83 | + |
| 84 | +namespace { |
| 85 | +struct LowerWordToBitsPass |
| 86 | + : public impl::LowerWordToBitsBase<LowerWordToBitsPass> { |
| 87 | + void runOnOperation() override; |
| 88 | +}; |
| 89 | +} // namespace |
| 90 | + |
| 91 | +void LowerWordToBitsPass::runOnOperation() { |
| 92 | + RewritePatternSet patterns(&getContext()); |
| 93 | + patterns.add<WordRewritePattern>(&getContext()); |
| 94 | + |
| 95 | + mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| 96 | + mlir::GreedyRewriteConfig config; |
| 97 | + // Use top-down traversal to reuse bits from `comb.concat`. |
| 98 | + config.useTopDownTraversal = true; |
| 99 | + |
| 100 | + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns, |
| 101 | + config))) |
| 102 | + return signalPassFailure(); |
| 103 | +} |
0 commit comments