diff --git a/include/circt/Dialect/AIG/AIGPasses.td b/include/circt/Dialect/AIG/AIGPasses.td index 2a7285ac36b8..4950aa4a3b2d 100644 --- a/include/circt/Dialect/AIG/AIGPasses.td +++ b/include/circt/Dialect/AIG/AIGPasses.td @@ -15,4 +15,9 @@ def LowerVariadic : Pass<"aig-lower-variadic", "hw::HWModuleOp"> { let summary = "Lower variadic AndInverter operations to binary AndInverter"; } +def LowerWordToBits : Pass<"aig-lower-word-to-bits", "hw::HWModuleOp"> { + let summary = "Lower multi-bit AndInverter to single-bit ones"; + let dependentDialects = ["comb::CombDialect"]; +} + #endif // CIRCT_DIALECT_AIG_AIGPASSES_TD diff --git a/lib/Dialect/AIG/CMakeLists.txt b/lib/Dialect/AIG/CMakeLists.txt index 4a31ca574d0c..16c312ec2240 100644 --- a/lib/Dialect/AIG/CMakeLists.txt +++ b/lib/Dialect/AIG/CMakeLists.txt @@ -13,4 +13,4 @@ add_circt_dialect_library(CIRCTAIG MLIRAIGIncGen ) -add_subdirectory(Transforms) \ No newline at end of file +add_subdirectory(Transforms) diff --git a/lib/Dialect/AIG/Transforms/CMakeLists.txt b/lib/Dialect/AIG/Transforms/CMakeLists.txt index a6886bf7d888..045734e122b6 100644 --- a/lib/Dialect/AIG/Transforms/CMakeLists.txt +++ b/lib/Dialect/AIG/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_circt_dialect_library(CIRCTAIGTransforms LowerVariadic.cpp + LowerWordToBits.cpp DEPENDS CIRCTAIGPassesIncGen diff --git a/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp b/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp new file mode 100644 index 000000000000..27cee5ebe7be --- /dev/null +++ b/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp @@ -0,0 +1,103 @@ +//===- LowerWordToBits.cpp - Bit-Blasting Words to Bits ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers multi-bit AIG operations to single-bit ones. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/AIG/AIGPasses.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "aig-lower-word-to-bits" + +namespace circt { +namespace aig { +#define GEN_PASS_DEF_LOWERWORDTOBITS +#include "circt/Dialect/AIG/AIGPasses.h.inc" +} // namespace aig +} // namespace circt + +using namespace circt; +using namespace aig; + +//===----------------------------------------------------------------------===// +// Rewrite patterns +//===----------------------------------------------------------------------===// + +namespace { + +struct WordRewritePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AndInverterOp op, + PatternRewriter &rewriter) const override { + auto width = op.getType().getIntOrFloatBitWidth(); + if (width <= 1) + return failure(); + + SmallVector results; + // We iterate over the width in reverse order to match the endianness of + // `comb.concat`. + for (int64_t i = width - 1; i >= 0; --i) { + SmallVector operands; + for (auto operand : op.getOperands()) { + // Reuse bits if we can extract from `comb.concat` operands. + if (auto concat = operand.getDefiningOp()) { + // For the simplicity, we only handle the case where all the + // `comb.concat` operands are single-bit. + if (concat.getNumOperands() == width && + llvm::all_of(concat.getOperandTypes(), [](Type type) { + return type.getIntOrFloatBitWidth() == 1; + })) { + // Be careful with the endianness here. + operands.push_back(concat.getOperand(width - i - 1)); + continue; + } + } + // Otherwise, we need to extract the bit. + operands.push_back( + rewriter.create(op.getLoc(), operand, i, 1)); + } + results.push_back(rewriter.create(op.getLoc(), operands, + op.getInvertedAttr())); + } + + rewriter.replaceOpWithNewOp(op, results); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Lower Word to Bits pass +//===----------------------------------------------------------------------===// + +namespace { +struct LowerWordToBitsPass + : public impl::LowerWordToBitsBase { + void runOnOperation() override; +}; +} // namespace + +void LowerWordToBitsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + mlir::GreedyRewriteConfig config; + // Use top-down traversal to reuse bits from `comb.concat`. + config.useTopDownTraversal = true; + + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns, + config))) + return signalPassFailure(); +} diff --git a/test/Dialect/AIG/lower-word-to-bits.mlir b/test/Dialect/AIG/lower-word-to-bits.mlir new file mode 100644 index 000000000000..3b7654f2bcc9 --- /dev/null +++ b/test/Dialect/AIG/lower-word-to-bits.mlir @@ -0,0 +1,16 @@ +// RUN: circt-opt %s --aig-lower-word-to-bits | FileCheck %s +// CHECK: hw.module @Basic +hw.module @Basic(in %a: i2, in %b: i2, out f: i2) { + %0 = aig.and_inv not %a, %b : i2 + %1 = aig.and_inv not %0, not %0 : i2 + // CHECK-NEXT: %[[EXTRACT_A_1:.+]] = comb.extract %a from 1 + // CHECK-NEXT: %[[EXTRACT_B_1:.+]] = comb.extract %b from 1 + // CHECK-NEXT: %[[AND_INV_0:.+]] = aig.and_inv not %[[EXTRACT_A_1]], %[[EXTRACT_B_1]] + // CHECK-NEXT: %[[EXTRACT_A_0:.+]] = comb.extract %a from 0 + // CHECK-NEXT: %[[EXTRACT_B_0:.+]] = comb.extract %b from 0 + // CHECK-NEXT: %[[AND_INV_1:.+]] = aig.and_inv not %[[EXTRACT_A_0]], %[[EXTRACT_B_0]] + // CHECK-NEXT: %[[AND_INV_2:.+]] = aig.and_inv not %[[AND_INV_0]], not %[[AND_INV_0]] + // CHECK-NEXT: %[[AND_INV_3:.+]] = aig.and_inv not %[[AND_INV_1]], not %[[AND_INV_1]] + // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[AND_INV_2]], %[[AND_INV_3]] + hw.output %1 : i2 +}