diff --git a/include/circt/Dialect/Verif/Passes.td b/include/circt/Dialect/Verif/Passes.td index 6db636565855..28eef3679c7d 100644 --- a/include/circt/Dialect/Verif/Passes.td +++ b/include/circt/Dialect/Verif/Passes.td @@ -42,4 +42,12 @@ def LowerFormalToHWPass : Pass<"lower-formal-to-hw", "mlir::ModuleOp"> { }]; } +def LowerContractsPass : Pass<"lower-contracts", "mlir::ModuleOp"> { + let summary = "Lower contracts into formal tests"; + let description = [{ + Converts `hw.module` ops containing a `verif.contract` into a + `verif.formal` op. + }]; +} + #endif // CIRCT_DIALECT_VERIF_PASSES_TD diff --git a/lib/Dialect/Verif/Transforms/CMakeLists.txt b/lib/Dialect/Verif/Transforms/CMakeLists.txt index 3dc8b5d38a42..4048dbce6b3d 100644 --- a/lib/Dialect/Verif/Transforms/CMakeLists.txt +++ b/lib/Dialect/Verif/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_circt_dialect_library(CIRCTVerifTransforms VerifyClockedAssertLike.cpp PrepareForFormal.cpp LowerFormalToHW.cpp + LowerContracts.cpp DEPENDS CIRCTVerifTransformsIncGen diff --git a/lib/Dialect/Verif/Transforms/LowerContracts.cpp b/lib/Dialect/Verif/Transforms/LowerContracts.cpp new file mode 100644 index 000000000000..cd6e501bfc72 --- /dev/null +++ b/lib/Dialect/Verif/Transforms/LowerContracts.cpp @@ -0,0 +1,109 @@ +//===- LowerContracts.cpp - Formal Preparations --*- C++ -*----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// +// +// Lower contracts into verif.formal tests. +// +//===----------------------------------------------------------------------===// +#include "circt/Dialect/Verif/VerifOps.h" +#include "circt/Dialect/Verif/VerifPasses.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace circt; + +namespace circt { +namespace verif { +#define GEN_PASS_DEF_LOWERCONTRACTSPASS +#include "circt/Dialect/Verif/Passes.h.inc" +} // namespace verif +} // namespace circt + +using namespace mlir; +using namespace verif; +using namespace hw; + +namespace { +struct LowerContractsPass + : verif::impl::LowerContractsPassBase { + void runOnOperation() override; +}; + +template +void replaceContractOp(PatternRewriter &rewriter, Block *body) { + for (auto op : llvm::make_early_inc_range(body->getOps())) { + auto label = op.getLabel(); + StringAttr labelAttr; + if (label) { + labelAttr = rewriter.getStringAttr(label.value()); + } + rewriter.replaceOpWithNewOp(op, op.getProperty(), op.getEnable(), + labelAttr); + } +} + +struct HWOpRewritePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(HWModuleOp op, + PatternRewriter &rewriter) const override { + auto formalOp = rewriter.create( + op.getLoc(), op.getNameAttr(), rewriter.getDictionaryAttr({})); + + // Clone module body into fomal op body + rewriter.cloneRegionBefore(op.getRegion(), formalOp.getBody(), + formalOp.getBody().end()); + + auto *bodyBlock = &formalOp.getBody().front(); + + // Erase hw.output + rewriter.eraseOp(bodyBlock->getTerminator()); + + // Convert block args to symbolic values + rewriter.setInsertionPointToStart(bodyBlock); + for (auto arg : llvm::make_early_inc_range(bodyBlock->getArguments())) { + auto sym = + rewriter.create(arg.getLoc(), arg.getType()); + rewriter.replaceAllUsesWith(arg, sym); + } + bodyBlock->eraseArguments(0, bodyBlock->getNumArguments()); + + // Inline contract ops + for (auto contractOp : + llvm::make_early_inc_range(bodyBlock->getOps())) { + + // Convert ensure to assert, require to assume + rewriter.setInsertionPointToEnd(&contractOp.getBody().front()); + Block *contractBlock = &contractOp.getBody().front(); + replaceContractOp(rewriter, contractBlock); + replaceContractOp(rewriter, contractBlock); + + // Inline body + rewriter.inlineBlockBefore(&contractOp.getBody().front(), + &formalOp.getBody().front(), + formalOp.getBody().front().end()); + + // Replace results with inputs and erase + for (auto [input, result] : + llvm::zip(contractOp.getResults(), contractOp.getInputs())) { + rewriter.replaceAllUsesWith(input, result); + } + rewriter.eraseOp(contractOp); + } + + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +void LowerContractsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext()); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); +} diff --git a/llvm b/llvm index 3cc852ece438..8193832fb988 160000 --- a/llvm +++ b/llvm @@ -1 +1 @@ -Subproject commit 3cc852ece438a63e7b09d1c84a81d21598454e1a +Subproject commit 8193832fb988e3df1e8e726634783805dca8d9b6 diff --git a/test/Dialect/Verif/lower-contracts.mlir b/test/Dialect/Verif/lower-contracts.mlir new file mode 100644 index 000000000000..4efecfa3a0e0 --- /dev/null +++ b/test/Dialect/Verif/lower-contracts.mlir @@ -0,0 +1,118 @@ +// RUN: circt-opt --lower-contracts %s | FileCheck %s + + +// CHECK: verif.formal @Mul9 { +// CHECK: %c9_i42 = hw.constant 9 : i42 +// CHECK: %c3_i42 = hw.constant 3 : i42 +// CHECK: %0 = verif.symbolic_value : i42 +// CHECK: %1 = comb.shl %0, %c3_i42 : i42 +// CHECK: %2 = comb.add %0, %1 : i42 +// CHECK: %3 = comb.mul %0, %c9_i42 : i42 +// CHECK: %4 = comb.icmp eq %2, %3 : i42 +// CHECK: verif.assert %4 : i1 +// CHECK: } + +hw.module @Mul9(in %a: i42, out z: i42) { + %c3_i42 = hw.constant 3 : i42 + %c9_i42 = hw.constant 9 : i42 + %0 = comb.shl %a, %c3_i42 : i42 // 8*a + %1 = comb.add %a, %0 : i42 // a + 8*a + %2 = verif.contract %1 : i42 { + %3 = comb.mul %a, %c9_i42 : i42 // 9*a + %4 = comb.icmp eq %2, %3 : i42 // 9*a == a + 8*a + verif.ensure %4 : i1 + } + hw.output %2 : i42 +} + +// CHECK: verif.formal @CarrySaveCompress3to2 { +// CHECK: %c1_i42 = hw.constant 1 : i42 +// CHECK: %0 = verif.symbolic_value : i42 +// CHECK: %1 = verif.symbolic_value : i42 +// CHECK: %2 = verif.symbolic_value : i42 +// CHECK: %3 = comb.xor %0, %1, %2 : i42 +// CHECK: %4 = comb.and %0, %1 : i42 +// CHECK: %5 = comb.or %0, %1 : i42 +// CHECK: %6 = comb.and %5, %2 : i42 +// CHECK: %7 = comb.or %4, %6 : i42 +// CHECK: %8 = comb.shl %7, %c1_i42 : i42 +// CHECK: %9 = comb.add %0, %1, %2 : i42 +// CHECK: %10 = comb.add %3, %8 : i42 +// CHECK: %11 = comb.icmp eq %9, %10 : i42 +// CHECK: verif.assert %11 : i1 +// CHECK: } + +hw.module @CarrySaveCompress3to2( + in %a0: i42, in %a1: i42, in %a2: i42, + out z0: i42, out z1: i42 +) { + %c1_i42 = hw.constant 1 : i42 + %0 = comb.xor %a0, %a1, %a2 : i42 // sum bits of FA (a0^a1^a2) + %1 = comb.and %a0, %a1 : i42 + %2 = comb.or %a0, %a1 : i42 + %3 = comb.and %2, %a2 : i42 + %4 = comb.or %1, %3 : i42 // carry bits of FA (a0&a1 | a2&(a0|a1)) + %5 = comb.shl %4, %c1_i42 : i42 // %5 = carry << 1 + %z0, %z1 = verif.contract %0, %5 : i42, i42 { + %inputSum = comb.add %a0, %a1, %a2 : i42 + %outputSum = comb.add %z0, %z1 : i42 + %6 = comb.icmp eq %inputSum, %outputSum : i42 + verif.ensure %6 : i1 + } + hw.output %z0, %z1 : i42, i42 +} + +// CHECK: verif.formal @ShiftLeft { +// CHECK: %c1_i8 = hw.constant 1 : i8 +// CHECK: %c2_i8 = hw.constant 2 : i8 +// CHECK: %c4_i8 = hw.constant 4 : i8 +// CHECK: %c8_i8 = hw.constant 8 : i8 +// CHECK: %0 = verif.symbolic_value : i8 +// CHECK: %1 = verif.symbolic_value : i8 +// CHECK: %2 = comb.extract %1 from 2 : (i8) -> i1 +// CHECK: %3 = comb.extract %1 from 1 : (i8) -> i1 +// CHECK: %4 = comb.extract %1 from 0 : (i8) -> i1 +// CHECK: %5 = comb.shl %0, %c4_i8 : i8 +// CHECK: %6 = comb.mux %2, %5, %0 : i8 +// CHECK: %7 = comb.shl %6, %c2_i8 : i8 +// CHECK: %8 = comb.mux %3, %7, %6 : i8 +// CHECK: %9 = comb.shl %8, %c1_i8 : i8 +// CHECK: %10 = comb.mux %4, %9, %8 : i8 +// CHECK: %11 = comb.icmp ult %1, %c8_i8 : i8 +// CHECK: %12 = comb.shl %0, %1 : i8 +// CHECK: %13 = comb.icmp eq %10, %12 : i8 +// CHECK: verif.assert %13 : i1 +// CHECK: verif.assume %11 : i1 +// CHECK: } + +hw.module @ShiftLeft(in %a: i8, in %b: i8, out z: i8) { + %c4_i8 = hw.constant 4 : i8 + %c2_i8 = hw.constant 2 : i8 + %c1_i8 = hw.constant 1 : i8 + %b2 = comb.extract %b from 2 : (i8) -> i1 + %b1 = comb.extract %b from 1 : (i8) -> i1 + %b0 = comb.extract %b from 0 : (i8) -> i1 + %0 = comb.shl %a, %c4_i8 : i8 + %1 = comb.mux %b2, %0, %a : i8 + %2 = comb.shl %1, %c2_i8 : i8 + %3 = comb.mux %b1, %2, %1 : i8 + %4 = comb.shl %3, %c1_i8 : i8 + %5 = comb.mux %b0, %4, %3 : i8 + + // Contract to check that the multiplexers and constant shifts above indeed + // produce the correct shift by 0 to 7 places, assuming the shift amount is + // less than 8 (we can't shift a number out). + %z = verif.contract %5 : i8 { + // Shift amount must be less than 8. + %c8_i8 = hw.constant 8 : i8 + %blt8 = comb.icmp ult %b, %c8_i8 : i8 + verif.require %blt8 : i1 + + // In that case the mux tree computes the correct left-shift. + %ashl = comb.shl %a, %b : i8 + %eq = comb.icmp eq %z, %ashl : i8 + verif.ensure %eq : i1 + } + + hw.output %z : i8 +}