diff --git a/docs/Dialects/AIG/RationaleAIG.md b/docs/Dialects/AIG/RationaleAIG.md new file mode 100644 index 000000000000..04fc9544458f --- /dev/null +++ b/docs/Dialects/AIG/RationaleAIG.md @@ -0,0 +1,24 @@ +# `aig` Dialect Rationale + +This document outlines the rationale of the AIG dialect, a dialect used for representing and transforming And-Inverter Graphs. + +# Why use the AIG dialect instead of the `comb` dialect? + +And-Inverter Graphs have proven to be a scalable approach for logic synthesis, serving as the underlying data structure for ABC, one of the most performant open-source logic synthesis tools. + +While it's technically possible to represent `aig.and_inv` using a combination of `comb.and`, `comb.xor`, and `hw.constant`, the ability to represent everything with `aig.and_inv` offers significant advantages. This unified representation simplifies complex analyses such as path retiming and area analysis, as well as logic mappings. Moreover, it allows for direct application of existing AIG research results and tools, further enhancing its utility in the synthesis process. + +# Operations +## aig.and_inv + +The `aig.and_inv` operation directly represents an And-Node in the AIG. Traditionally, an And-Node in AIG has two operands. However, `aig.and_inv` extends this concept by allowing variadic operands and non-i1 integer types. Although the final stage of the pipeline requires lowering everything to i1-binary operands, it's more efficient to progressively lower the variadic multibit operations. + +Variadic operands have demonstrated their utility in low-level optimizations within the `comb` dialect. Furthermore, in synthesis, it's common practice to re-balance the logic path. Variadic operands enable the compiler to select more efficient solutions without the need to traverse binary trees multiple times. + +The ability to represent multibit operations during synthesis is crucial for scalable logic optimization. This approach enables a form of vectorization, allowing for batch processing of logic synthesis when multibit operations are constructed in a similar manner. Such vectorization can significantly improve the efficiency and performance of the synthesis process. + +## aig.cut + +The `aig.cut` operation represents a "cut" in the logic tree. This operation possesses the `IsolatedAbove` trait and contains a single block. Its input operands represent the input edges, while the returned value represents the output edges. + +This operation proves particularly useful for progressive LUT mapping. For instance, a k-input cut can be readily mapped to a k-input LUT. Consequently, the subsequent stages of the pipeline can concentrate on replacing combinational logic with k-input Cut operations, simplifying the overall process. diff --git a/include/circt/Conversion/CombToAIG.h b/include/circt/Conversion/CombToAIG.h new file mode 100644 index 000000000000..27d529a39dc2 --- /dev/null +++ b/include/circt/Conversion/CombToAIG.h @@ -0,0 +1,25 @@ +//===- CombToAIG.h - Comb to AIG dialect conversion ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_CONVERSION_COMBTOAIG_H +#define CIRCT_CONVERSION_COMBTOAIG_H + +#include "circt/Support/LLVM.h" +#include + +namespace circt { +namespace hw { +class HWModuleOp; +} + +#define GEN_PASS_DECL_CONVERTCOMBTOAIG +#include "circt/Conversion/Passes.h.inc" + +} // namespace circt + +#endif // CIRCT_CONVERSION_COMBTOARITH_H diff --git a/include/circt/Conversion/Passes.h b/include/circt/Conversion/Passes.h index 9d6b2d2aacc2..887da277f8aa 100644 --- a/include/circt/Conversion/Passes.h +++ b/include/circt/Conversion/Passes.h @@ -19,6 +19,7 @@ #include "circt/Conversion/CalyxNative.h" #include "circt/Conversion/CalyxToFSM.h" #include "circt/Conversion/CalyxToHW.h" +#include "circt/Conversion/CombToAIG.h" #include "circt/Conversion/CombToArith.h" #include "circt/Conversion/CombToSMT.h" #include "circt/Conversion/ConvertToArcs.h" diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 791c998bfd32..e1d13bb191c6 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -790,4 +790,16 @@ def LowerSimToSV: Pass<"lower-sim-to-sv", "mlir::ModuleOp"> { ]; } +//===----------------------------------------------------------------------===// +// ConvertCombToAIG +//===----------------------------------------------------------------------===// + +def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> { + let summary = "Lower comb ops to aig ops."; + let dependentDialects = [ + "circt::comb::CombDialect", + "circt::aig::AIGDialect", + ]; +} + #endif // CIRCT_CONVERSION_PASSES_TD diff --git a/include/circt/Dialect/AIG/AIG.td b/include/circt/Dialect/AIG/AIG.td new file mode 100644 index 000000000000..62cb778b5691 --- /dev/null +++ b/include/circt/Dialect/AIG/AIG.td @@ -0,0 +1,25 @@ +//===- AIG.td - AIG Definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_AIG_DIALECT_TD +#define CIRCT_AIG_DIALECT_TD + +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" + +def AIG_Dialect : Dialect { + let name = "aig"; + let cppNamespace = "::circt::aig"; + let summary = "Representation of AIGs"; + + let usePropertiesForAttributes = 0; +} + +include "circt/Dialect/AIG/AIGOps.td" + +#endif // CIRCT_AIG_DIALECT_TD diff --git a/include/circt/Dialect/AIG/AIGDialect.h b/include/circt/Dialect/AIG/AIGDialect.h new file mode 100644 index 000000000000..acd0f712ae7a --- /dev/null +++ b/include/circt/Dialect/AIG/AIGDialect.h @@ -0,0 +1,22 @@ +//===- AIGDialect.h - AIG Definitions --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the AIG CIRCT dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_AIG_AIGDIALECT_H +#define CIRCT_DIALECT_AIG_AIGDIALECT_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +#include "circt/Dialect/AIG/AIGDialect.h.inc" + +#endif // CIRCT_DIALECT_AIG_AIGDIALECT_H diff --git a/include/circt/Dialect/AIG/AIGOps.h b/include/circt/Dialect/AIG/AIGOps.h new file mode 100644 index 000000000000..ff36bd69919c --- /dev/null +++ b/include/circt/Dialect/AIG/AIGOps.h @@ -0,0 +1,31 @@ +//===- AIGOps.h - AIG Op Definitions ---------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_AIG_AIGOPS_H +#define CIRCT_DIALECT_AIG_AIGOPS_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "circt/Dialect/AIG/AIGDialect.h" + +#define GET_OP_CLASSES +#include "circt/Dialect/AIG/AIG.h.inc" + +#endif // CIRCT_DIALECT_AIG_AIGOPS_H diff --git a/include/circt/Dialect/AIG/AIGOps.td b/include/circt/Dialect/AIG/AIGOps.td new file mode 100644 index 000000000000..ca6e99da8f3d --- /dev/null +++ b/include/circt/Dialect/AIG/AIGOps.td @@ -0,0 +1,120 @@ +//===- AIGOps.td - AIG Op Definitions -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// AIG Ops are defined in tablegen. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_AIG_OPS_TD +#define CIRCT_DIALECT_AIG_OPS_TD + +include "circt/Dialect/AIG/AIG.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + + +class AIGOp traits = []> : + Op; + +def AndInverterOp : AIGOp<"and_inv", [SameOperandsAndResultType, Pure]> { + let summary = "AIG dialect AND operation"; + let description = [{ + The `aig.and_inv` operation represents an And-Inverter in the AIG dialect. + Unlike comb.and, operands can be inverted respectively. + + Example: + ```mlir + %r1 = aig.and_inv %a, %b: i3 + %r2 = aig.and_inv not %a, %b, not %c : i3 + %r3 = aig.and_inv not %a : i3 + ``` + }]; + // TODO: Restrict to HWIntegerType. + let arguments = (ins Variadic:$inputs, DenseBoolArrayAttr:$inverted); + let results = (outs AnyType:$result); + + // NOTE: Custom assembly format is needed to pretty print the `inverted` + // attribute. + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "Value":$input, CArg<"bool", "false">:$invert), [{ + SmallVector inverted {invert}; + return build($_builder, $_state, {input}, inverted); + }]>, + OpBuilder<(ins "Value":$lhs, "Value":$rhs, CArg<"bool", "false">:$invertLhs, + CArg<"bool", "false">:$invertRhs), [{ + SmallVector inverted {invertLhs, invertRhs}; + return build($_builder, $_state, {lhs, rhs}, inverted); + }]> + ]; + + let extraClassDeclaration = [{ + // Evaluate the operation with the given input values. + APInt evaluate(ArrayRef inputs); + + // Check if the input is inverted. + bool isInverted(size_t idx) { + return getInverted()[idx]; + } + }]; + let hasFolder = 1; + let hasCanonicalizeMethod = 1; +} + +def CutOp : AIGOp<"cut", [IsolatedFromAbove, SingleBlock]> { + let summary = "AIG dialect Cut operation"; + let description = [{ + The `aig.cut` operation represents a cut in the And-Inverter-Graph. + This operation is variadic and can take multiple inputs and outputs, + which corresponds to the input and output edges in AIG conceptually. + + ```mlir + %0, %1 = aig.cut %a, %b, %c, %d : (i1, i1, i1, i1) -> (i1, i1) { + ^bb0(%arg0: i1, %arg1: i1, %arg2: i1, %arg3: i1): + %0 = aig.and_inv not %arg0, %arg1 : i1 + %1 = aig.and_inv %arg1, %arg3 : i1 + aig.output %0, %1 : i1 + } + ``` + + }]; + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = [{ + $inputs attr-dict `:` functional-type($inputs, $results) $body + }]; + + let builders = [ + OpBuilder<(ins + CArg<"TypeRange", "{}">:$resultTypes, + CArg<"ValueRange", "{}">:$inputs, + CArg<"std::function", "{}">:$ctor)> + ]; + + let extraClassDeclaration = [{ + Block *getBodyBlock() { return &getBody().front(); } + }]; +} + +def OutputOp : AIGOp<"output", [Terminator, + ReturnLike, ParentOneOf<["CutOp"]>]> { + let summary = "AIG dialect Output operation"; + let description = [{ + The `aig.output` operation represents out edges of a cut. + }]; + let arguments = (ins Variadic:$outputs); + let assemblyFormat = [{ + attr-dict ($outputs^ `:` qualified(type($outputs)))? + }]; +} + +#endif // CIRCT_DIALECT_AIG_OPS_TD diff --git a/include/circt/Dialect/AIG/AIGPasses.h b/include/circt/Dialect/AIG/AIGPasses.h new file mode 100644 index 000000000000..4226bc35711b --- /dev/null +++ b/include/circt/Dialect/AIG/AIGPasses.h @@ -0,0 +1,40 @@ +//===- AIGPasses.h - AIG dialect passes -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_AIG_AIGPASSES_H +#define CIRCT_DIALECT_AIG_AIGPASSES_H + +#include "mlir/Pass/Pass.h" +#include +#include + +namespace mlir { +class Pass; +} // namespace mlir + +#include "circt/Dialect/AIG/AIGPassesEnums.h.inc" + +namespace circt { +namespace aig { + +#define GEN_PASS_DECL +#include "circt/Dialect/AIG/AIGPasses.h.inc" + +std::unique_ptr createLowerCutToLUTPass(); +std::unique_ptr createLowerVariadicPass(); +std::unique_ptr createLowerWordToBitsPass(); +std::unique_ptr +createGreedyCutDecompPass(const GreedyCutDecompOptions &options = {}); + +#define GEN_PASS_REGISTRATION +#include "circt/Dialect/AIG/AIGPasses.h.inc" + +} // namespace aig +} // namespace circt + +#endif // CIRCT_DIALECT_AIG_AIGPASSES_H diff --git a/include/circt/Dialect/AIG/AIGPasses.td b/include/circt/Dialect/AIG/AIGPasses.td new file mode 100644 index 000000000000..8c570245c17f --- /dev/null +++ b/include/circt/Dialect/AIG/AIGPasses.td @@ -0,0 +1,41 @@ +//===- ArcPasses.td - Arc dialect passes -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_ARC_ARCPASSES_TD +#define CIRCT_DIALECT_ARC_ARCPASSES_TD + +include "mlir/Pass/PassBase.td" + +def LowerCutToLUT : Pass<"aig-lower-cut-to-lut", "hw::HWModuleOp"> { + let summary = "Lower a cut to a LUT"; + let dependentDialects = ["comb::CombDialect"]; + let constructor = "circt::aig::createLowerCutToLUTPass()"; +} + +def LowerVariadic : Pass<"aig-lower-variadic", "hw::HWModuleOp"> { + let summary = "Lower variadic AndInverter operations to binary AndInverter"; + let constructor = "circt::aig::createLowerVariadicPass()"; +} + +def LowerWordToBits : Pass<"aig-lower-word-to-bits", "hw::HWModuleOp"> { + let summary = "Lower multi-bit AIG operations to single-bit ones"; + let dependentDialects = ["comb::CombDialect"]; + let constructor = "circt::aig::createLowerWordToBitsPass()"; +} + +def GreedyCutDecomp : Pass<"aig-greedy-cut-decomp", "hw::HWModuleOp"> { + let summary = "Decompose AIGs into k-feasible Cuts using a greedy algorithm"; + let dependentDialects = ["comb::CombDialect"]; + let constructor = "circt::aig::createGreedyCutDecompPass()"; + let options = [ + Option<"cutSizes", "cut-sizes", "uint32_t", "6", + "The sizes of the cuts to decompose">, + ]; +} + +#endif // CIRCT_DIALECT_ARC_ARCPASSES_TD diff --git a/include/circt/Dialect/AIG/CMakeLists.txt b/include/circt/Dialect/AIG/CMakeLists.txt new file mode 100644 index 000000000000..3bd74131f1c2 --- /dev/null +++ b/include/circt/Dialect/AIG/CMakeLists.txt @@ -0,0 +1,9 @@ +add_circt_dialect(AIG aig) +add_circt_dialect_doc(AIG aig) + +set(LLVM_TARGET_DEFINITIONS AIGPasses.td) +mlir_tablegen(AIGPasses.h.inc -gen-pass-decls) +mlir_tablegen(AIGPassesEnums.h.inc -gen-enum-decls) +mlir_tablegen(AIGPassesEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(CIRCTAIGTransformsIncGen) +add_circt_doc(AIGPasses AIGPasses -gen-pass-doc) diff --git a/include/circt/Dialect/CMakeLists.txt b/include/circt/Dialect/CMakeLists.txt index 96a7b1013375..798409471773 100644 --- a/include/circt/Dialect/CMakeLists.txt +++ b/include/circt/Dialect/CMakeLists.txt @@ -6,6 +6,7 @@ ## ##===----------------------------------------------------------------------===// +add_subdirectory(AIG) add_subdirectory(Arc) add_subdirectory(Calyx) add_subdirectory(Comb) diff --git a/include/circt/InitAllDialects.h b/include/circt/InitAllDialects.h index 763e63293918..3212a01c154d 100644 --- a/include/circt/InitAllDialects.h +++ b/include/circt/InitAllDialects.h @@ -14,6 +14,7 @@ #ifndef CIRCT_INITALLDIALECTS_H_ #define CIRCT_INITALLDIALECTS_H_ +#include "circt/Dialect/AIG/AIGDialect.h" #include "circt/Dialect/Arc/ArcDialect.h" #include "circt/Dialect/Calyx/CalyxDialect.h" #include "circt/Dialect/Comb/CombDialect.h" @@ -51,6 +52,7 @@ namespace circt { inline void registerAllDialects(mlir::DialectRegistry ®istry) { // clang-format off registry.insert< + aig::AIGDialect, arc::ArcDialect, calyx::CalyxDialect, chirrtl::CHIRRTLDialect, diff --git a/include/circt/InitAllPasses.h b/include/circt/InitAllPasses.h index 173e3adafbb6..36ae439f2970 100644 --- a/include/circt/InitAllPasses.h +++ b/include/circt/InitAllPasses.h @@ -16,6 +16,7 @@ #include "circt/Conversion/ExportVerilog.h" #include "circt/Conversion/Passes.h" +#include "circt/Dialect/AIG/AIGPasses.h" #include "circt/Dialect/Arc/ArcPasses.h" #include "circt/Dialect/Calyx/CalyxPasses.h" #include "circt/Dialect/Comb/CombPasses.h" @@ -57,6 +58,7 @@ inline void registerAllPasses() { registerBMCTransformsPasses(); // Standard Passes + aig::registerPasses(); arc::registerPasses(); calyx::registerPasses(); comb::registerPasses(); diff --git a/lib/CAPI/Conversion/CMakeLists.txt b/lib/CAPI/Conversion/CMakeLists.txt index 385ede3bef18..8f6176f027a6 100644 --- a/lib/CAPI/Conversion/CMakeLists.txt +++ b/lib/CAPI/Conversion/CMakeLists.txt @@ -7,6 +7,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion CIRCTCalyxToFSM CIRCTCalyxToHW CIRCTCalyxNative + CIRCTCombToAIG CIRCTCombToArith CIRCTCombToLLVM CIRCTCombToSMT diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 04bc5bfa9d71..bcb890e5c900 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(CalyxToHW) add_subdirectory(CombToArith) add_subdirectory(CombToLLVM) add_subdirectory(CombToSMT) +add_subdirectory(CombToAIG) add_subdirectory(ConvertToArcs) add_subdirectory(DCToHW) add_subdirectory(ExportChiselInterface) diff --git a/lib/Conversion/CombToAIG/CMakeLists.txt b/lib/Conversion/CombToAIG/CMakeLists.txt new file mode 100644 index 000000000000..38d6832bd55c --- /dev/null +++ b/lib/Conversion/CombToAIG/CMakeLists.txt @@ -0,0 +1,20 @@ +add_circt_conversion_library(CIRCTCombToAIG + CombToAIG.cpp + + ADDITIONAL_HEADER_DIRS + ${CIRCT_MAIN_INCLUDE_DIR}/circt/Conversion/CombToAIG + + DEPENDS + CIRCTConversionPassIncGen + + LINK_LIBS PUBLIC + CIRCTHW + CIRCTComb + CIRCTAIG + CIRCTSupport + CIRCTTransforms + MLIRIR + MLIRPass + MLIRSupport + MLIRTransforms +) diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp new file mode 100644 index 000000000000..6d95c831d544 --- /dev/null +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -0,0 +1,117 @@ +//===- CombToAIG.cpp - Comb to AIG Conversion Pass ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Conversion/CombToAIG.h" +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace circt { +#define GEN_PASS_DEF_CONVERTCOMBTOAIG +#include "circt/Conversion/Passes.h.inc" +} // namespace circt + +using namespace circt; +using namespace comb; + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Lower a comb::AndOp operation to aig::AndInverterOp +struct CombAndOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector nonInverts(adaptor.getInputs().size(), false); + rewriter.replaceOpWithNewOp(op, adaptor.getInputs(), + nonInverts); + return success(); + } +}; + +/// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags +struct CombOrOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Implement Or using And and invert flags: a | b = ~(~a & ~b) + SmallVector allInverts(adaptor.getInputs().size(), true); + auto andOp = rewriter.create( + op.getLoc(), adaptor.getInputs(), allInverts); + rewriter.replaceOpWithNewOp(op, andOp, + /*invert=*/true); + return success(); + } +}; + +/// Lower a comb::XorOp operation to AIG operations +struct CombXorOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(XorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b) + // (a | b) = ~(~a & ~b) + // (~a | ~b) = ~(a & b) + auto inputs = adaptor.getInputs(); + SmallVector allInverts(inputs.size(), true); + SmallVector allNotInverts(inputs.size(), false); + + // a | b = ~(~a & ~b) + auto notAAndNotB = + rewriter.create(op.getLoc(), inputs, allInverts); + auto aAndB = + rewriter.create(op.getLoc(), inputs, allNotInverts); + + rewriter.replaceOpWithNewOp(op, notAAndNotB, aAndB, + /*lhs_invert=*/true, + /*rhs_invert=*/true); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Convert Comb to AIG pass +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertCombToAIGPass + : public impl::ConvertCombToAIGBase { + void runOnOperation() override; +}; +} // namespace + +static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void ConvertCombToAIGPass::runOnOperation() { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + + RewritePatternSet patterns(&getContext()); + populateCombToAIGConversionPatterns(patterns); + + if (failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); +} diff --git a/lib/Dialect/AIG/AIGDialect.cpp b/lib/Dialect/AIG/AIGDialect.cpp new file mode 100644 index 000000000000..0acb1e1b8b2a --- /dev/null +++ b/lib/Dialect/AIG/AIGDialect.cpp @@ -0,0 +1,35 @@ +//===- AIGDialect.cpp - Implement the AIG dialect -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AIG dialect. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGDialect.h" +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace circt; +using namespace aig; + +//===----------------------------------------------------------------------===// +// Dialect specification. +//===----------------------------------------------------------------------===// + +void AIGDialect::initialize() { + // Register operations. + addOperations< +#define GET_OP_LIST +#include "circt/Dialect/AIG/AIG.cpp.inc" + >(); +} + +#include "circt/Dialect/AIG/AIGDialect.cpp.inc" diff --git a/lib/Dialect/AIG/AIGOps.cpp b/lib/Dialect/AIG/AIGOps.cpp new file mode 100644 index 000000000000..b634243096b1 --- /dev/null +++ b/lib/Dialect/AIG/AIGOps.cpp @@ -0,0 +1,141 @@ +//===- LoopScheduleOps.cpp - LoopSchedule CIRCT Operations ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implement the AIG ops. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace circt; +using namespace circt::aig; + +#define GET_OP_CLASSES +#include "circt/Dialect/AIG/AIG.cpp.inc" + +OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) { + // TODO: Implement this. + return {}; +} + +LogicalResult AndInverterOp::canonicalize(AndInverterOp op, + PatternRewriter &rewriter) { + // TODO: Extend to more than 2 inputs. + if (op.getInputs().size() == 2) { + if ((op.getInputs()[0] == op.getInputs()[1]) && + (op.isInverted(0) == op.isInverted(1))) { + auto all = rewriter.create( + op.getLoc(), APInt::getAllOnes(op.getType().getIntOrFloatBitWidth())); + rewriter.replaceOpWithNewOp(op, op.getInputs()[0], + all, true, false); + + return success(); + } + } + return failure(); +} + +mlir::ParseResult AndInverterOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + SmallVector operands; + SmallVector inverts; + auto loc = parser.getCurrentLocation(); + + while (true) { + if (succeeded(parser.parseOptionalKeyword("not"))) { + inverts.push_back(true); + } else { + inverts.push_back(false); + } + operands.push_back(OpAsmParser::UnresolvedOperand()); + + if (parser.parseOperand(operands.back())) + return failure(); + if (parser.parseOptionalComma()) + break; + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return mlir::failure(); + + if (parser.parseColon()) + return mlir::failure(); + + mlir::Type resultRawType{}; + llvm::ArrayRef resultTypes(&resultRawType, 1); + + { + mlir::Type type; + if (parser.parseCustomTypeWithFallback(type)) + return mlir::failure(); + resultRawType = type; + } + + result.addTypes(resultTypes); + result.addAttribute("inverted", + parser.getBuilder().getDenseBoolArrayAttr(inverts)); + if (parser.resolveOperands(operands, resultTypes[0], loc, result.operands)) + return mlir::failure(); + return mlir::success(); +} + +void AndInverterOp::print(mlir::OpAsmPrinter &odsPrinter) { + odsPrinter << ' '; + llvm::interleaveComma(llvm::zip(getInverted(), getInputs()), odsPrinter, + [&](auto &&pair) { + auto [invert, input] = pair; + if (invert) { + odsPrinter << "not "; + } + odsPrinter << input; + }); + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back("inverted"); + odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + odsPrinter << ' ' << ":"; + odsPrinter << ' '; + { + auto type = getResult().getType(); + if (auto validType = llvm::dyn_cast(type)) + odsPrinter.printStrippedAttrOrType(validType); + else + odsPrinter << type; + } +} + +APInt AndInverterOp::evaluate(ArrayRef inputs) { + assert(inputs.size() == getNumOperands() && + "Expected as many inputs as operands"); + assert(inputs.size() != 0 && "Expected non-empty input list"); + APInt result = APInt::getAllOnes(inputs.front().getBitWidth()); + for (auto [idx, input] : llvm::enumerate(inputs)) { + if (isInverted(idx)) + result &= ~input; + else + result &= input; + } + return result; +} + +void CutOp::build(OpBuilder &builder, OperationState &result, + TypeRange resultTypes, ValueRange inputs, + std::function ctor) { + OpBuilder::InsertionGuard guard(builder); + + auto *block = builder.createBlock(result.addRegion()); + result.addTypes(resultTypes); + result.addOperands(inputs); + for (auto input : inputs) + block->addArgument(input.getType(), input.getLoc()); + + if (ctor) + ctor(); +} diff --git a/lib/Dialect/AIG/CMakeLists.txt b/lib/Dialect/AIG/CMakeLists.txt new file mode 100644 index 000000000000..6b6d4ca0766c --- /dev/null +++ b/lib/Dialect/AIG/CMakeLists.txt @@ -0,0 +1,17 @@ +add_circt_dialect_library(CIRCTAIG + AIGOps.cpp + AIGDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/AIG + + LINK_LIBS PUBLIC + MLIRIR + CIRCTHW + + DEPENDS + CIRCTHW + MLIRAIGIncGen +) + +add_subdirectory(Transforms) \ No newline at end of file diff --git a/lib/Dialect/AIG/Transforms/CMakeLists.txt b/lib/Dialect/AIG/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..17cf02ddc6b4 --- /dev/null +++ b/lib/Dialect/AIG/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_circt_dialect_library(CIRCTAIGTransforms + GreedyCutDecomp.cpp + LowerCutToLUT.cpp + LowerVariadic.cpp + LowerWordToBits.cpp + + DEPENDS + CIRCTAIGTransformsIncGen + + LINK_LIBS PUBLIC + CIRCTAIG + CIRCTComb + CIRCTHW +) diff --git a/lib/Dialect/AIG/Transforms/GreedyCutDecomp.cpp b/lib/Dialect/AIG/Transforms/GreedyCutDecomp.cpp new file mode 100644 index 000000000000..d38cb50d1373 --- /dev/null +++ b/lib/Dialect/AIG/Transforms/GreedyCutDecomp.cpp @@ -0,0 +1,297 @@ +//===- GreedyCutDecomp.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass performs cut decomposition on AIGs based on a naive greedy +// algorithm. We first convert all `aig.and_inv` to `aig.cut` that has a single +// operation and then try to merge cut operations on inputs. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/AIG/AIGPasses.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "aig-greedy-cut-decomp" + +namespace circt { +namespace aig { +#define GEN_PASS_DEF_GREEDYCUTDECOMP +#include "circt/Dialect/AIG/AIGPasses.h.inc" +} // namespace aig +} // namespace circt + +using namespace circt; +using namespace aig; + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { + +struct AndInverterOpToCutPattern : public mlir::RewritePattern { + AndInverterOpToCutPattern(MLIRContext *context) + : mlir::RewritePattern(AndInverterOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *opT, + PatternRewriter &rewriter) const override { + auto op = cast(opT); + + if (isa(op->getParentOp())) + return failure(); + + auto cutOp = rewriter.create( + op.getLoc(), op.getResult().getType(), op.getInputs(), [&]() { + auto result = rewriter.create( + op.getLoc(), op.getResult().getType(), + rewriter.getBlock()->getArguments(), op.getInvertedAttr()); + rewriter.create(op.getLoc(), ValueRange{result}); + }); + + rewriter.replaceOp(op, cutOp); + return success(); + } +}; + +static aig::CutOp mergeCuts(Location loc, MutableArrayRef cuts, + ArrayRef inputs, Value output, + PatternRewriter &rewriter) { + if (!mlir::computeTopologicalSorting(cuts)) + return {}; + + assert(cuts.size() >= 2); + + DenseMap valueToNewValue, inputsToBlockArg; + auto cutOp = + rewriter.create(loc, output.getType(), inputs, [&]() { + for (auto [i, input] : llvm::enumerate(inputs)) + inputsToBlockArg[input] = rewriter.getBlock()->getArgument(i); + + for (auto [i, cut] : llvm::enumerate(cuts)) { + auto cutOp = cast(cut); + assert(cutOp.getNumResults() == 1); + for (auto [arg, input] : llvm::zip( + cutOp.getBodyBlock()->getArguments(), cutOp.getInputs())) { + auto it = inputsToBlockArg.find(input); + if (it != inputsToBlockArg.end()) { + arg.replaceAllUsesWith(it->second); + } else { + auto cutOp = dyn_cast(input.getDefiningOp()); + assert(cutOp && cutOp.getNumResults() == 1); + arg.replaceAllUsesWith(valueToNewValue.at(input)); + } + } + + assert(cutOp.getNumResults() == 1); + valueToNewValue[cutOp.getResult(0)] = + cutOp.getBodyBlock()->getTerminator()->getOperand(0); + } + }); + + output.replaceAllUsesWith(cutOp.getResult(0)); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToEnd(cutOp.getBodyBlock()); + rewriter.create(loc, ValueRange{valueToNewValue.at(output)}); + } + + for (auto oldCut : llvm::reverse(cuts)) { + auto oldCutOutput = + cast(oldCut).getBodyBlock()->getTerminator(); + oldCutOutput->erase(); + cutOp.getBodyBlock()->getOperations().splice( + cutOp.getBodyBlock()->begin(), + cast(oldCut).getBodyBlock()->getOperations()); + rewriter.eraseOp(oldCut); + } + + return cutOp; +} + +struct MergeCutPattern : public mlir::RewritePattern { + MergeCutPattern(MLIRContext *context, unsigned cutLimit) + : mlir::RewritePattern(aig::CutOp::getOperationName(), 1, context), + cutLimit(cutLimit) {} + const unsigned cutLimit; + + LogicalResult matchAndRewrite(Operation *opT, + PatternRewriter &rewriter) const override { + auto op = cast(opT); + if (op.getNumOperands() >= cutLimit) + return failure(); + + // FIXME: This is really really dumb, but it is just a proof of concept. + SmallVector inputCuts; + { + SetVector inputCutsSet; + inputCutsSet.insert(op); + for (auto cut : op.getOperands()) { + if (auto cutOp = cut.getDefiningOp(); + cutOp && cutOp.getNumResults() == 1) + inputCutsSet.insert(cutOp); + } + + inputCuts = std::move(inputCutsSet.takeVector()); + } + + if (inputCuts.size() <= 1) + return failure(); + + SmallVector inputs; + inputs.reserve(inputCuts.size()); + + // This is naive implementation of the cut emuration of the local inputs. + // This can be done more efficiently. + + LLVM_DEBUG(llvm::dbgs() << "Trying to merge " << op << "\n"); + for (unsigned i = (1 << (inputCuts.size() - 1)) - 1; i != 0; --i) { + auto checkSubsetMerge = [&](unsigned i) -> LogicalResult { + SetVector inValues; + llvm::SmallDenseSet cutSet; + SmallVector outValues; + + for (unsigned j = 0; j < inputCuts.size(); ++j) { + if (i & (1 << j)) { + cutSet.insert(inputCuts[j]); + for (auto in : inputCuts[j].getInputs()) + inValues.insert(in); + outValues.push_back(inputCuts[j].getResult(0)); + LLVM_DEBUG(llvm::dbgs() << "Added " << inputCuts[j] << "\n"); + } + } + + Value singleOutput; + for (auto out : outValues) { + // Users of cuts must be closed under the cut set. + bool isClosed = + llvm::all_of(out.getUsers(), [&cutSet](Operation *user) { + if (auto cutOp = dyn_cast(user)) + return cutSet.contains(cutOp); + return false; + }); + inValues.remove(out); + if (!isClosed) { + if (singleOutput) { + LLVM_DEBUG(llvm::dbgs() << "Not closed\n"); + return failure(); + } + + singleOutput = out; + } + } + + if (!singleOutput || inValues.size() > cutLimit) { + LLVM_DEBUG(llvm::dbgs() << "Limit exceeded\n"); + return failure(); + } + + SmallVector subsetCuts; + for (unsigned j = 0; j < inputCuts.size(); ++j) { + if (i & (1 << j)) + subsetCuts.push_back(inputCuts[j]); + } + + // Ok, let's merge the cuts. + auto cutOp = mergeCuts(op.getLoc(), subsetCuts, inValues.takeVector(), + singleOutput, rewriter); + if (!cutOp) { + LLVM_DEBUG(llvm::dbgs() << "Failed to merge\n"); + return failure(); + } + return success(); + }; + + // Always enable a bit for 0 (the original cut). + auto result = checkSubsetMerge((i << 1) | 1); + if (succeeded(result)) + return result; + } + return failure(); + } +}; + +struct SinkConstantPattern : public mlir::RewritePattern { + SinkConstantPattern(MLIRContext *context) + : mlir::RewritePattern(aig::CutOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *opT, + PatternRewriter &rewriter) const override { + auto op = cast(opT); + SmallVector oldInputs, oldArgs; + auto *block = op.getBodyBlock(); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(block); + BitVector eraseArgs(block->getNumArguments()); + bool changed = false; + for (auto [i, in] : llvm::enumerate(op.getInputs())) { + if (auto constOp = in.getDefiningOp()) { + eraseArgs.set(i); + auto cloned = rewriter.clone(*constOp); + block->getArgument(i).replaceAllUsesWith(cloned->getResult(0)); + changed = true; + } else { + oldInputs.push_back(in); + oldArgs.push_back(block->getArgument(i)); + } + } + if (!changed) + return failure(); + block->eraseArguments(eraseArgs); + } + + auto newCut = rewriter.create(op.getLoc(), op.getResultTypes(), + oldInputs, [&]() {}); + + for (auto [newArg, oldArg] : + llvm::zip(newCut.getBodyBlock()->getArguments(), oldArgs)) + oldArg.replaceAllUsesWith(newArg); + + newCut.getBodyBlock()->getOperations().splice( + newCut.getBodyBlock()->begin(), block->getOperations()); + rewriter.replaceOp(op, newCut); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert Comb to AIG pass +//===----------------------------------------------------------------------===// + +namespace { +struct GreedyCutDecompPass + : public impl::GreedyCutDecompBase { + using GreedyCutDecompBase::GreedyCutDecompBase; + void runOnOperation() override; + using GreedyCutDecompBase::cutSizes; +}; +} // namespace + +void GreedyCutDecompPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext(), cutSizes.getValue()); + mlir::FrozenRewritePatternSet frozen(std::move(patterns)); + mlir::GreedyRewriteConfig config; + + config.useTopDownTraversal = true; + + if (failed( + mlir::applyPatternsAndFoldGreedily(getOperation(), frozen, config))) + return signalPassFailure(); +} + +std::unique_ptr +aig::createGreedyCutDecompPass(const GreedyCutDecompOptions &options) { + return std::make_unique(options); +} diff --git a/lib/Dialect/AIG/Transforms/LowerCutToLUT.cpp b/lib/Dialect/AIG/Transforms/LowerCutToLUT.cpp new file mode 100644 index 000000000000..120da49da0a2 --- /dev/null +++ b/lib/Dialect/AIG/Transforms/LowerCutToLUT.cpp @@ -0,0 +1,150 @@ +//===- LowerCutToLUT.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers `aig.cut` to `comb.truth_table` with k inputs where +// k is the size of the cut (= operand inputs size). +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/AIG/AIGPasses.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Transforms/DialectConversion.h" + +#define DEBUG_TYPE "aig-lower-cut-to-lut" + +namespace circt { +namespace aig { +#define GEN_PASS_DEF_LOWERCUTTOLUT +#include "circt/Dialect/AIG/AIGPasses.h.inc" +} // namespace aig +} // namespace circt + +using namespace circt; +using namespace aig; + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { + +struct CutToLUTPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CutOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult +CutToLUTPattern::matchAndRewrite(CutOp cutOp, PatternRewriter &rewriter) const { + if (cutOp.getNumResults() == 0) { + rewriter.eraseOp(cutOp); + return success(); + } + + // 1. Check if the cut can be lowered to a LUT. + + // Check if the cut has a single output. + if (cutOp.getNumResults() != 1) + return cutOp.emitError("expected single output"); + + // Check if every type is i1. + for (auto operand : cutOp.getOperands()) { + if (operand.getType() != rewriter.getI1Type()) + return cutOp.emitError("expected i1 type"); + } + + for (auto result : cutOp.getResults()) { + if (result.getType() != rewriter.getI1Type()) + return cutOp.emitError("expected i1 type"); + } + + uint32_t lutWidth = cutOp.getNumOperands(); + if (lutWidth >= 32) + return cutOp.emitError("Cut width is too large to fit in a LUT"); + + // 2. Lower the cut to a LUT. We can get a truth table by evaluating the cut + // body with every possible combination of the input values. + uint32_t tableSize = 1 << lutWidth; + DenseMap mapping; + auto &body = cutOp.getBodyRegion().front(); + for (uint32_t i = 0; i < lutWidth; i++) { + APInt value(tableSize, 0); + for (uint32_t j = 0; j < tableSize; j++) { + // Make sure the order of the bits is correct. + value.setBitVal(j, (j >> i) & 1); + } + mapping[body.getArgument(i)] = std::move(value); + } + + // Evaluate the cut body. Update `mapping` along the way. + for (auto &op : body.getOperations()) { + if (auto constOp = dyn_cast(&op)) { + mapping[constOp.getResult()] = + APInt(tableSize, constOp.getValue().getZExtValue()); + } else if (auto AndInverterOp = dyn_cast(&op)) { + // TODO: Avoid this copy. + SmallVector inputs; + for (auto input : AndInverterOp.getInputs()) + inputs.push_back(mapping[input]); + mapping[AndInverterOp.getResult()] = AndInverterOp.evaluate(inputs); + } else if (auto outputOp = dyn_cast(&op)) { + assert(outputOp.getOutputs().size() == 1 && "expected single output"); + auto value = mapping.at(outputOp.getOutputs().front()); + LLVM_DEBUG(llvm::dbgs() << "value: " << value << "\n"); + SmallVector bits; + bits.reserve(tableSize); + for (uint32_t i = 0; i < tableSize; i++) + bits.push_back(value[i]); + auto truthTable = rewriter.create( + op.getLoc(), cutOp.getOperands(), rewriter.getBoolArrayAttr(bits)); + rewriter.replaceOp(cutOp, truthTable); + return success(); + } else { + return op.emitError("unsupported operation in Eval: ") << op; + } + } + + // It should not reach here. + return failure(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Convert Comb to AIG pass +//===----------------------------------------------------------------------===// + +static void populateLowerCutToLUTPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +namespace { +struct LowerCutToLUTPass : public impl::LowerCutToLUTBase { + void runOnOperation() override; +}; +} // namespace + +void LowerCutToLUTPass::runOnOperation() { + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + + RewritePatternSet patterns(&getContext()); + populateLowerCutToLUTPatterns(patterns); + + if (failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr aig::createLowerCutToLUTPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/AIG/Transforms/LowerVariadic.cpp b/lib/Dialect/AIG/Transforms/LowerVariadic.cpp new file mode 100644 index 000000000000..d8a38df0b21e --- /dev/null +++ b/lib/Dialect/AIG/Transforms/LowerVariadic.cpp @@ -0,0 +1,117 @@ +//===- LowerVariadic.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers variadic AndInverter operations to binary AndInverter +// operations. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/AIG/AIGPasses.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Transforms/DialectConversion.h" + +#define DEBUG_TYPE "aig-lower-variadic" + +namespace circt { +namespace aig { +#define GEN_PASS_DEF_LOWERVARIADIC +#include "circt/Dialect/AIG/AIGPasses.h.inc" +} // namespace aig +} // namespace circt + +using namespace circt; +using namespace aig; + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { +static Value lowerFullyAssociativeOp(AndInverterOp op, OperandRange operands, + ArrayRef inverts, + ConversionPatternRewriter &rewriter) { + Value lhs, rhs; + switch (operands.size()) { + case 0: + assert(0 && "cannot be called with empty operand range"); + break; + case 1: + if (inverts[0]) + return rewriter.create(op.getLoc(), operands[0], true); + else + return operands[0]; + case 2: + lhs = operands[0]; + rhs = operands[1]; + return rewriter.create(op.getLoc(), lhs, rhs, inverts[0], + inverts[1]); + default: + auto firstHalf = operands.size() / 2; + lhs = lowerFullyAssociativeOp(op, operands.take_front(firstHalf), + inverts.take_front(firstHalf), rewriter); + rhs = lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), + inverts.drop_front(firstHalf), rewriter); + return rewriter.create(op.getLoc(), lhs, rhs); + } +} + +struct VariadicOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AndInverterOp::Adaptor; + LogicalResult + matchAndRewrite(AndInverterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getInputs().size() < 2) + return failure(); + + // TODO: This is naive implementatino that creates a balanced binary trees. + // WE can improve by analyzing the dataflow and creating a tree that + // improves the critical path or area. + auto result = lowerFullyAssociativeOp(op, op.getOperands(), + op.getInverted(), rewriter); + rewriter.replaceOp(op, result); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Convert Comb to AIG pass +//===----------------------------------------------------------------------===// + +static void populateLowerVariadicPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +namespace { +struct LowerVariadicPass : public impl::LowerVariadicBase { + void runOnOperation() override; +}; +} // namespace + +void LowerVariadicPass::runOnOperation() { + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalOp( + [](AndInverterOp op) { return op.getInputs().size() <= 2; }); + + RewritePatternSet patterns(&getContext()); + populateLowerVariadicPatterns(patterns); + + if (failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr aig::createLowerVariadicPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp b/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp new file mode 100644 index 000000000000..cc4aa6c20324 --- /dev/null +++ b/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp @@ -0,0 +1,100 @@ +//===- LowerWordToBits.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers multi-bit AIG operations to single-bit ones. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/AIG/AIGPasses.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Transforms/DialectConversion.h" + +#define DEBUG_TYPE "aig-lower-word-to-bits" + +namespace circt { +namespace aig { +#define GEN_PASS_DEF_LOWERWORDTOBITS +#include "circt/Dialect/AIG/AIGPasses.h.inc" +} // namespace aig +} // namespace circt + +using namespace circt; +using namespace aig; + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { + +struct WordConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AndInverterOp::Adaptor; + LogicalResult + matchAndRewrite(AndInverterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getType().getIntOrFloatBitWidth() <= 1) + return failure(); + + SmallVector result; + for (size_t i = 0; i < op.getType().getIntOrFloatBitWidth(); i++) { + SmallVector operands; + for (auto operand : op.getOperands()) { + auto operandBits = + rewriter.create(op.getLoc(), operand, i, 1); + operands.push_back(operandBits); + } + auto andInverter = rewriter.create(op.getLoc(), operands, + op.getInvertedAttr()); + result.push_back(andInverter); + } + + auto concat = rewriter.create(op.getLoc(), result); + rewriter.replaceOp(op, concat); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Convert Comb to AIG pass +//===----------------------------------------------------------------------===// + +static void populateLowerWordToBitsPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +namespace { +struct LowerWordToBitsPass + : public impl::LowerWordToBitsBase { + void runOnOperation() override; +}; +} // namespace + +void LowerWordToBitsPass::runOnOperation() { + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](AndInverterOp op) { + return op.getType().getIntOrFloatBitWidth() <= 1; + }); + + RewritePatternSet patterns(&getContext()); + populateLowerWordToBitsPatterns(patterns); + + if (failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr aig::createLowerWordToBitsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 82ff4d0d835e..eb30485abf79 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -9,6 +9,7 @@ ## ##===----------------------------------------------------------------------===// +add_subdirectory(AIG) add_subdirectory(Arc) add_subdirectory(Calyx) add_subdirectory(Comb) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0661355d1b98..3bc76409af63 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -26,6 +26,7 @@ set(CIRCT_TEST_DEPENDS circt-dis circt-lec circt-opt + circt-synth circt-test circt-translate circt-reduce diff --git a/test/Conversion/CombToAIG/comb-to-aig.mlir b/test/Conversion/CombToAIG/comb-to-aig.mlir new file mode 100644 index 000000000000..de740b40e192 --- /dev/null +++ b/test/Conversion/CombToAIG/comb-to-aig.mlir @@ -0,0 +1,9 @@ +// RUN: circt-opt %s --convert-comb-to-aig | FileCheck %s + +// CHECK-LABEL: @test +hw.module @test(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32, in %arg4: i1, out out: i32) { + // CHECK-NEXT: %0 = aig.and_inv not %arg0, not %arg1, not %arg2, not %arg3 : i32 + // CHECK-NEXT: %1 = aig.and_inv not %0 : i32 + %0 = comb.or %arg0, %arg1, %arg2, %arg3 : i32 + hw.output %0 : i32 +} \ No newline at end of file diff --git a/test/Dialect/AIG/greedy-decomp.mlir b/test/Dialect/AIG/greedy-decomp.mlir new file mode 100644 index 000000000000..9827b0394dea --- /dev/null +++ b/test/Dialect/AIG/greedy-decomp.mlir @@ -0,0 +1,26 @@ +// RUN: circt-opt %s --aig-greedy-cut-decomp | FileCheck %s + +// CHECK-LABEL: @variadic +hw.module @variadic(in %a : i1, in %b : i1, in %c : i1, in %d : i1, in %e : i1, in %f : i1, in %g : i1, out and6 : i1) { + %0 = aig.and_inv %b, %c : i1 + %1 = aig.and_inv %a, %0 : i1 + %2 = aig.and_inv %d, %e : i1 + %3 = aig.and_inv %f, %g : i1 + %4 = aig.and_inv %2, %3 : i1 + %5 = aig.and_inv %1, %4 : i1 + hw.output %5 : i1 +} +// CHECK-NEXT: %0 = aig.cut %a, %b, %c : (i1, i1, i1) -> i1 { +// CHECK-NEXT: ^bb0(%arg0: i1, %arg1: i1, %arg2: i1): +// CHECK-NEXT: %2 = aig.and_inv %arg1, %arg2 : i1 +// CHECK-NEXT: %3 = aig.and_inv %arg0, %2 : i1 +// CHECK-NEXT: aig.output %3 : i1 +// CHECK-NEXT: } +// CHECK-NEXT: %1 = aig.cut %0, %d, %e, %f, %g : (i1, i1, i1, i1, i1) -> i1 { +// CHECK-NEXT: ^bb0(%arg0: i1, %arg1: i1, %arg2: i1, %arg3: i1, %arg4: i1): +// CHECK-NEXT: %2 = aig.and_inv %arg1, %arg2 : i1 +// CHECK-NEXT: %3 = aig.and_inv %arg3, %arg4 : i1 +// CHECK-NEXT: %4 = aig.and_inv %2, %3 : i1 +// CHECK-NEXT: %5 = aig.and_inv %arg0, %4 : i1 +// CHECK-NEXT: aig.output %5 : i1 +// CHECK-NEXT: } diff --git a/test/Dialect/AIG/lower-cut-to-lut.mlir b/test/Dialect/AIG/lower-cut-to-lut.mlir new file mode 100644 index 000000000000..71e5514784c5 --- /dev/null +++ b/test/Dialect/AIG/lower-cut-to-lut.mlir @@ -0,0 +1,15 @@ +// RUN: circt-opt %s --aig-lower-cut-to-lut | FileCheck %s +// CHECK: hw.module @Cut +hw.module @Cut(in %a: i1, in %b: i1, in %c: i1, in %d: i1, out e: i1) { + // CHECK-NEXT: %0 = comb.truth_table %a, %b, %c, %d + // CHECK-SAME: -> [false, false, false, false, false, true, true, true, + // CHECK-SAME: false, false, false, false, false, false, false, false] + %0 = aig.cut %a, %b, %c, %d : (i1, i1, i1, i1) -> (i1) { + ^bb0(%arg0: i1, %arg1: i1, %arg2: i1, %arg3: i1): + %1 = aig.and_inv not %arg0, not %arg1 : i1 + %2 = aig.and_inv %arg2, not %arg3 : i1 + %3 = aig.and_inv not %1, %2 : i1 + aig.output %3 : i1 + } + hw.output %0 : i1 +} diff --git a/test/Dialect/AIG/lower-variadic.mlir b/test/Dialect/AIG/lower-variadic.mlir new file mode 100644 index 000000000000..df1dca9279eb --- /dev/null +++ b/test/Dialect/AIG/lower-variadic.mlir @@ -0,0 +1,11 @@ +// RUN: circt-opt %s --aig-lower-variadic | FileCheck %s +// CHECK: hw.module @Basic +hw.module @Basic(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, out f: i2) { + // CHECK: %0 = aig.and_inv not %a, %b : i2 + // CHECK-NEXT: %1 = aig.and_inv not %d, %e : i2 + // CHECK-NEXT: %2 = aig.and_inv %c, %1 : i2 + // CHECK-NEXT: %3 = aig.and_inv %0, %2 : i2 + // CHECK-NEXT: hw.output %3 : i2 + %0 = aig.and_inv not %a, %b, %c, not %d, %e : i2 + hw.output %0 : i2 +} diff --git a/test/Dialect/AIG/lower-word-to-bits.mlir b/test/Dialect/AIG/lower-word-to-bits.mlir new file mode 100644 index 000000000000..9202a8c6eb98 --- /dev/null +++ b/test/Dialect/AIG/lower-word-to-bits.mlir @@ -0,0 +1,14 @@ +// RUN: circt-opt %s --aig-lower-word-to-bits | FileCheck %s +// CHECK: hw.module @Basic +hw.module @Basic(in %a: i2, in %b: i2, out f: i2) { + // CHECK-NEXT: %0 = comb.extract %a from 0 : (i2) -> i1 + // CHECK-NEXT: %1 = comb.extract %b from 0 : (i2) -> i1 + // CHECK-NEXT: %2 = aig.and_inv not %0, %1 : i1 + // CHECK-NEXT: %3 = comb.extract %a from 1 : (i2) -> i1 + // CHECK-NEXT: %4 = comb.extract %b from 1 : (i2) -> i1 + // CHECK-NEXT: %5 = aig.and_inv not %3, %4 : i1 + // CHECK-NEXT: %6 = comb.concat %2, %5 : i1, i1 + // CHECK-NEXT: hw.output %6 : i2 + %0 = aig.and_inv not %a, %b : i2 + hw.output %0 : i2 +} diff --git a/test/Dialect/AIG/round-trip.mlir b/test/Dialect/AIG/round-trip.mlir new file mode 100644 index 000000000000..78f9e6bd8574 --- /dev/null +++ b/test/Dialect/AIG/round-trip.mlir @@ -0,0 +1,21 @@ +// RUN: circt-opt %s | circt-opt | FileCheck %s + +// CHECK-LABEL: @And +// CHECK-NEXT: aig.and_inv %b, %b : i4 +// CHECK-NEXT: aig.and_inv %b, not %b : i4 +// CHECK-NEXT: aig.and_inv not %a, not %a : i1 +hw.module @And(in %a: i1, in %b: i4) { + %0 = aig.and_inv %b, %b : i4 + %1 = aig.and_inv %b, not %b : i4 + %2 = aig.and_inv not %a, not %a : i1 +} + +hw.module @Cut(in %a: i1, in %b: i1, out c: i1, out d: i1) { + %0, %1 = aig.cut %a, %b : (i1, i1) -> (i1, i1) { + ^bb0(%arg0: i1, %arg1: i1): + %c = aig.and_inv %arg0, not %arg1 : i1 + %d = aig.and_inv %arg0, %arg1 : i1 + aig.output %c, %d : i1, i1 + } + hw.output %0, %1 : i1, i1 +} diff --git a/test/circt-synth/basic.mlir b/test/circt-synth/basic.mlir new file mode 100644 index 000000000000..82c96020f9e3 --- /dev/null +++ b/test/circt-synth/basic.mlir @@ -0,0 +1,46 @@ +// RUN: circt-synth %s | FileCheck %s + +// CHECK-LABEL: @and +hw.module @and(in %a: i1, in %b: i1, out and: i1) { + %0 = comb.and %a, %b : i1 + // CHECK-NEXT: %0 = comb.truth_table %a, %b -> [false, false, false, true] + hw.output %0 : i1 +} + +// CHECK-LABEL: @or +hw.module @or(in %a: i1, in %b: i1, out or: i1) { + %0 = comb.or %a, %b : i1 + // CHECK-NEXT: %0 = comb.truth_table %a, %b -> [false, true, true, true] + hw.output %0 : i1 +} + +// CHECK-LABEL: @xor +hw.module @xor(in %a: i1, in %b: i1, out xor: i1) { + %0 = comb.xor %a, %b : i1 + // CHECK-NEXT: %0 = comb.truth_table %a, %b -> [false, true, true, false] + hw.output %0 : i1 +} + +// CHECK-LABEL: @multibit +hw.module @multibit(in %a: i2, in %b: i2, out and: i2) { + %0 = comb.and %a, %b : i2 + // CHCK-NEXT: %0 = comb.extract %a from 0 : (i2) -> i1 + // CHCK-NEXT: %1 = comb.extract %b from 0 : (i2) -> i1 + // CHCK-NEXT: %2 = comb.truth_table %0, %1 -> [false, false, false, true] + // CHCK-NEXT: %3 = comb.extract %a from 1 : (i2) -> i1 + // CHCK-NEXT: %4 = comb.extract %b from 1 : (i2) -> i1 + // CHCK-NEXT: %5 = comb.truth_table %3, %4 -> [false, false, false, true] + // CHCK-NEXT: %6 = comb.concat %2, %5 : i1, i1 + hw.output %0 : i2 +} + +// CHECK-LABEL: @variadic +hw.module @variadic(in %a: i1, in %b: i1, in %c: i1, + in %d: i1, in %e: i1, in %f: i1, out and6: i1) { + %0 = comb.and %a, %b, %c, %d, %e, %f : i1 + // CHECK-NEXT: %0 = comb.truth_table %a, %b, %c, %d, %e, %f -> [ + // CHECK-COUNT-63: false + // CHECK-SAME: true + // CHECK-SAME: ] + hw.output %0 : i1 +} \ No newline at end of file diff --git a/test/circt-synth/commandline.mlir b/test/circt-synth/commandline.mlir new file mode 100644 index 000000000000..9673e7c89a2a --- /dev/null +++ b/test/circt-synth/commandline.mlir @@ -0,0 +1,3 @@ +// RUN: circt-synth --help | FileCheck %s + +// CHECK: OVERVIEW: Logic synthesis tool diff --git a/test/circt-synth/lut-size.mlir b/test/circt-synth/lut-size.mlir new file mode 100644 index 000000000000..222d8d947089 --- /dev/null +++ b/test/circt-synth/lut-size.mlir @@ -0,0 +1,24 @@ +// RUN: circt-synth %s --lut-size=2 | FileCheck %s --check-prefix=LUT-2 +// RUN: circt-synth %s --lut-size=4 | FileCheck %s --check-prefix=LUT-4 +// RUN: circt-synth %s --lut-size=6 | FileCheck %s --check-prefix=LUT-6 + +// LUT-2-LABEL: @variadic +// LUT-4-LABEL: @variadic +// LUT-6-LABEL: @variadic +hw.module @variadic(in %a: i1, in %b: i1, in %c: i1, + in %d: i1, in %e: i1, in %f: i1, out and6: i1) { + // LUT-2-NEXT: %0 = comb.truth_table %b, %c -> [false, false, false, true] + // LUT-2-NEXT: %1 = comb.truth_table %a, %0 -> [false, false, false, true] + // LUT-2-NEXT: %2 = comb.truth_table %e, %f -> [false, false, false, true] + // LUT-2-NEXT: %3 = comb.truth_table %d, %2 -> [false, false, false, true] + // LUT-2-NEXT: %4 = comb.truth_table %1, %3 -> [false, false, false, true] + // LUT-2-NEXT: hw.output %4 + // LUT-4-NEXT: %0 = comb.truth_table %a, %b, %c -> [false, false, false, false, false, false, false, true] + // LUT-4-NEXT: %1 = comb.truth_table %0, %d, %e, %f -> [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true] + // LUT-4-NEXT: hw.output %1 + // LUT-6-NEXT: %0 = comb.truth_table + // LUT-6-NEXT: hw.output %0 + %0 = comb.and %a, %b, %c, %d, %e, %f : i1 + + hw.output %0 : i1 +} diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 479f97da8a1c..86eb72502280 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -60,8 +60,8 @@ tools = [ 'arcilator', 'circt-as', 'circt-capi-ir-test', 'circt-capi-om-test', 'circt-capi-firrtl-test', 'circt-capi-firtool-test', 'circt-dis', - 'circt-lec', 'circt-reduce', 'circt-test', 'circt-translate', 'firtool', - 'hlstool', 'om-linker', 'ibistool' + 'circt-lec', 'circt-reduce', 'circt-test', 'circt-translate', 'circt-synth', + 'firtool', 'hlstool', 'om-linker', 'ibistool' ] if "CIRCT_OPT_CHECK_IR_ROUNDTRIP" in os.environ: diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 738513fc547c..a91b167fa521 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(circt-lsp-server) add_subdirectory(circt-opt) add_subdirectory(circt-reduce) add_subdirectory(circt-rtl-sim) +add_subdirectory(circt-synth) add_subdirectory(circt-test) add_subdirectory(circt-translate) add_subdirectory(firtool) diff --git a/tools/circt-synth/CMakeLists.txt b/tools/circt-synth/CMakeLists.txt new file mode 100644 index 000000000000..69af1dbdc0e4 --- /dev/null +++ b/tools/circt-synth/CMakeLists.txt @@ -0,0 +1,20 @@ +set(LLVM_LINK_COMPONENTS + Support +) + +add_circt_tool(circt-synth circt-synth.cpp) +target_link_libraries(circt-synth + PRIVATE + CIRCTCombToAIG + CIRCTAIGTransforms + CIRCTAIG + CIRCTComb + CIRCTHW + CIRCTSupport + MLIRIR + MLIRParser + LLVMSupport +) + +llvm_update_compile_flags(circt-synth) +mlir_check_all_link_libraries(circt-synth) diff --git a/tools/circt-synth/circt-synth.cpp b/tools/circt-synth/circt-synth.cpp new file mode 100644 index 000000000000..3e5e3449340b --- /dev/null +++ b/tools/circt-synth/circt-synth.cpp @@ -0,0 +1,218 @@ +//===- circt-synth.cpp - The circt-synth driver -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// This file initiliazes the 'circt-synth' tool, which performs logic +/// synthesis. Currently, it only performs backend-agnostic FPGA synthesis, +/// mapping core dialects into FPGA-specific primitives, such as LUTs. +/// +//===----------------------------------------------------------------------===// + +#include "circt/Conversion/CombToAIG.h" +#include "circt/Dialect/AIG/AIGDialect.h" +#include "circt/Dialect/AIG/AIGPasses.h" +#include "circt/Dialect/Comb/CombDialect.h" +#include "circt/Dialect/HW/HWDialect.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Support/Passes.h" +#include "circt/Support/Version.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" + +namespace cl = llvm::cl; + +using namespace mlir; +using namespace circt; + +//===----------------------------------------------------------------------===// +// Command-line options declaration +//===----------------------------------------------------------------------===// + +static cl::OptionCategory mainCategory("circt-synth Options"); + +static cl::opt inputFilename(cl::Positional, cl::Required, + cl::desc("Specify an input file"), + cl::value_desc("filename"), + cl::cat(mainCategory)); + +static cl::opt outputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + cl::init("-"), + cl::cat(mainCategory)); + +static cl::opt + verifyPasses("verify-each", + cl::desc("Run the verifier after each transformation pass"), + cl::init(true), cl::cat(mainCategory)); + +static cl::opt + verbosePassExecutions("verbose-pass-executions", + cl::desc("Log executions of toplevel module passes"), + cl::init(false), cl::cat(mainCategory)); +// Options to control early-out from pipeline. +enum Until { UntilAIG, UntilAIGOpt, UntilAIGLowering, UntilEnd }; + +static auto runUntilValues = llvm::cl::values( + clEnumValN(UntilAIG, "aig", "Conversion of modules to AIG"), + clEnumValN(UntilAIGOpt, "aig-opt", "AIG optimizations"), + clEnumValN(UntilAIGLowering, "aig-lowering", + "AIG lowering to lut-mappable forms"), + clEnumValN(UntilEnd, "all", "After AIG to Comb lowering")); + +static llvm::cl::opt runUntilBefore( + "until-before", llvm::cl::desc("Stop pipeline before a specified point"), + runUntilValues, llvm::cl::init(UntilEnd), llvm::cl::cat(mainCategory)); +static llvm::cl::opt runUntilAfter( + "until-after", llvm::cl::desc("Stop pipeline after a specified point"), + runUntilValues, llvm::cl::init(UntilEnd), llvm::cl::cat(mainCategory)); + +// LUT-k parameter. This needs to be unifined to a more fine-grained target +// architecture information. +static cl::opt lutSize("lut-size", + cl::desc("Size of LUT to use for mapping"), + cl::init(6), cl::cat(mainCategory)); +//===----------------------------------------------------------------------===// +// Main Tool Logic +//===----------------------------------------------------------------------===// + +static bool untilReached(Until until) { + return until >= runUntilBefore || until > runUntilAfter; +} + +//===----------------------------------------------------------------------===// +// Tool implementation +//===-----------------------------------------------------------------------===// + +static void populateSynthesisPipeline(PassManager &pm) { + auto &mpm = pm.nest(); + mpm.addPass(createConvertCombToAIG()); + mpm.addPass(createCSEPass()); + if (untilReached(UntilAIG)) + return; + + mpm.addPass(createSimpleCanonicalizerPass()); + mpm.addPass(createCSEPass()); + + // TODO: Add balancing, rewriting, FRAIG conversion, etc. + if (untilReached(UntilAIGOpt)) + return; + + mpm.addPass(aig::createLowerVariadicPass()); + mpm.addPass(aig::createLowerWordToBitsPass()); + mpm.addPass(createCSEPass()); + mpm.addPass(createSimpleCanonicalizerPass()); + + // TODO: Add LUT mapping, etc. + if (untilReached(UntilAIGLowering)) + return; + + mpm.addPass(createCSEPass()); + aig::GreedyCutDecompOptions options; + options.cutSizes = lutSize; + mpm.addPass(aig::createGreedyCutDecompPass(options)); + mpm.addPass(aig::createLowerCutToLUTPass()); +} + +/// This functions initializes the various components of the tool and +/// orchestrates the work to be done. +static LogicalResult executeSynthesis(MLIRContext &context) { + // Create the timing manager we use to sample execution times. + DefaultTimingManager tm; + applyDefaultTimingManagerCLOptions(tm); + auto ts = tm.getRootScope(); + + OwningOpRef module; + { + auto parserTimer = ts.nest("Parse MLIR input"); + // Parse the provided input files. + module = parseSourceFile(inputFilename, &context); + } + if (!module) + return failure(); + // Create the output directory or output file depending on our mode. + std::optional> outputFile; + std::string errorMessage; + // Create an output file. + outputFile.emplace(openOutputFile(outputFilename, &errorMessage)); + if (!outputFile.value()) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + PassManager pm(&context); + pm.enableVerifier(verifyPasses); + pm.enableTiming(ts); + if (failed(applyPassManagerCLOptions(pm))) + return failure(); + + if (verbosePassExecutions) + pm.addInstrumentation( + std::make_unique>( + "circt-synth")); + populateSynthesisPipeline(pm); + if (failed(pm.run(module.get()))) + return failure(); + + auto timer = ts.nest("Print MLIR output"); + OpPrintingFlags printingFlags; + module->print(outputFile.value()->os(), printingFlags); + outputFile.value()->keep(); + return success(); +} + +/// The entry point for the `circt-synth` tool: +/// configures and parses the command-line options, +/// registers all dialects within a MLIR context, +/// and calls the `executeSynthesis` function to do the actual work. +int main(int argc, char **argv) { + llvm::InitLLVM y(argc, argv); + + // Hide default LLVM options, other than for this tool. + // MLIR options are added below. + cl::HideUnrelatedOptions(mainCategory); + + // Register any pass manager command line options. + registerMLIRContextCLOptions(); + registerPassManagerCLOptions(); + registerDefaultTimingManagerCLOptions(); + registerAsmPrinterCLOptions(); + cl::AddExtraVersionPrinter( + [](llvm::raw_ostream &os) { os << circt::getCirctVersion() << '\n'; }); + + // Parse the command-line options provided by the user. + cl::ParseCommandLineOptions(argc, argv, "Logic synthesis tool\n\n"); + + // Set the bug report message to indicate users should file issues on + // llvm/circt and not llvm/llvm-project. + llvm::setBugReportMsg(circt::circtBugReportMsg); + + // Register the supported CIRCT dialects and create a context to work with. + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + + // Setup of diagnostic handling. + llvm::SourceMgr sourceMgr; + SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + // Avoid printing a superfluous note on diagnostic emission. + context.printOpOnDiagnostic(false); + + // Perform the synthesis; using `exit` to avoid the slow + // teardown of the MLIR context. + exit(failed(executeSynthesis(context))); +}