Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Verif] LowerContractsPass #7870

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/circt/Dialect/Verif/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,12 @@ def LowerFormalToHWPass : Pass<"lower-formal-to-hw", "mlir::ModuleOp"> {
}];
}

def LowerContractsPass : Pass<"lower-contracts", "mlir::ModuleOp"> {
let summary = "Lower contracts into formal tests";
let description = [{
Converts `hw.module` ops containing a `verif.contract` into a
`verif.formal` op.
}];
}

#endif // CIRCT_DIALECT_VERIF_PASSES_TD
1 change: 1 addition & 0 deletions lib/Dialect/Verif/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_circt_dialect_library(CIRCTVerifTransforms
VerifyClockedAssertLike.cpp
PrepareForFormal.cpp
LowerFormalToHW.cpp
LowerContracts.cpp

DEPENDS
CIRCTVerifTransformsIncGen
Expand Down
109 changes: 109 additions & 0 deletions lib/Dialect/Verif/Transforms/LowerContracts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//===- LowerContracts.cpp - Formal Preparations --*- C++ -*----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
//
// Lower contracts into verif.formal tests.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Verif/VerifOps.h"
#include "circt/Dialect/Verif/VerifPasses.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace circt;

namespace circt {
namespace verif {
#define GEN_PASS_DEF_LOWERCONTRACTSPASS
#include "circt/Dialect/Verif/Passes.h.inc"
} // namespace verif
} // namespace circt

using namespace mlir;
using namespace verif;
using namespace hw;

namespace {
struct LowerContractsPass
: verif::impl::LowerContractsPassBase<LowerContractsPass> {
void runOnOperation() override;
};

template <typename FROM, typename TO>
void replaceContractOp(PatternRewriter &rewriter, Block *body) {
for (auto op : llvm::make_early_inc_range(body->getOps<FROM>())) {
auto label = op.getLabel();
StringAttr labelAttr;
if (label) {
labelAttr = rewriter.getStringAttr(label.value());
}
rewriter.replaceOpWithNewOp<TO>(op, op.getProperty(), op.getEnable(),
labelAttr);
}
}

struct HWOpRewritePattern : public OpRewritePattern<HWModuleOp> {
using OpRewritePattern<HWModuleOp>::OpRewritePattern;

LogicalResult matchAndRewrite(HWModuleOp op,
PatternRewriter &rewriter) const override {
auto formalOp = rewriter.create<verif::FormalOp>(
op.getLoc(), op.getNameAttr(), rewriter.getDictionaryAttr({}));

// Clone module body into fomal op body
rewriter.cloneRegionBefore(op.getRegion(), formalOp.getBody(),
formalOp.getBody().end());

auto *bodyBlock = &formalOp.getBody().front();

// Erase hw.output
rewriter.eraseOp(bodyBlock->getTerminator());

// Convert block args to symbolic values
rewriter.setInsertionPointToStart(bodyBlock);
for (auto arg : llvm::make_early_inc_range(bodyBlock->getArguments())) {
auto sym =
rewriter.create<verif::SymbolicValueOp>(arg.getLoc(), arg.getType());
rewriter.replaceAllUsesWith(arg, sym);
}
bodyBlock->eraseArguments(0, bodyBlock->getNumArguments());

// Inline contract ops
for (auto contractOp :
llvm::make_early_inc_range(bodyBlock->getOps<verif::ContractOp>())) {

// Convert ensure to assert, require to assume
rewriter.setInsertionPointToEnd(&contractOp.getBody().front());
Block *contractBlock = &contractOp.getBody().front();
replaceContractOp<EnsureOp, AssertOp>(rewriter, contractBlock);
replaceContractOp<RequireOp, AssumeOp>(rewriter, contractBlock);

// Inline body
rewriter.inlineBlockBefore(&contractOp.getBody().front(),
&formalOp.getBody().front(),
formalOp.getBody().front().end());

// Replace results with inputs and erase
for (auto [input, result] :
llvm::zip(contractOp.getResults(), contractOp.getInputs())) {
rewriter.replaceAllUsesWith(input, result);
}
rewriter.eraseOp(contractOp);
}

rewriter.eraseOp(op);
return success();
}
};
} // namespace

void LowerContractsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<HWOpRewritePattern>(patterns.getContext());

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
2 changes: 1 addition & 1 deletion llvm
Submodule llvm updated 8433 files
118 changes: 118 additions & 0 deletions test/Dialect/Verif/lower-contracts.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// RUN: circt-opt --lower-contracts %s | FileCheck %s


// CHECK: verif.formal @Mul9 {
// CHECK: %c9_i42 = hw.constant 9 : i42
// CHECK: %c3_i42 = hw.constant 3 : i42
// CHECK: %0 = verif.symbolic_value : i42
// CHECK: %1 = comb.shl %0, %c3_i42 : i42
// CHECK: %2 = comb.add %0, %1 : i42
// CHECK: %3 = comb.mul %0, %c9_i42 : i42
// CHECK: %4 = comb.icmp eq %2, %3 : i42
// CHECK: verif.assert %4 : i1
// CHECK: }

hw.module @Mul9(in %a: i42, out z: i42) {
%c3_i42 = hw.constant 3 : i42
%c9_i42 = hw.constant 9 : i42
%0 = comb.shl %a, %c3_i42 : i42 // 8*a
%1 = comb.add %a, %0 : i42 // a + 8*a
%2 = verif.contract %1 : i42 {
%3 = comb.mul %a, %c9_i42 : i42 // 9*a
%4 = comb.icmp eq %2, %3 : i42 // 9*a == a + 8*a
verif.ensure %4 : i1
}
hw.output %2 : i42
}

// CHECK: verif.formal @CarrySaveCompress3to2 {
// CHECK: %c1_i42 = hw.constant 1 : i42
// CHECK: %0 = verif.symbolic_value : i42
// CHECK: %1 = verif.symbolic_value : i42
// CHECK: %2 = verif.symbolic_value : i42
// CHECK: %3 = comb.xor %0, %1, %2 : i42
// CHECK: %4 = comb.and %0, %1 : i42
// CHECK: %5 = comb.or %0, %1 : i42
// CHECK: %6 = comb.and %5, %2 : i42
// CHECK: %7 = comb.or %4, %6 : i42
// CHECK: %8 = comb.shl %7, %c1_i42 : i42
// CHECK: %9 = comb.add %0, %1, %2 : i42
// CHECK: %10 = comb.add %3, %8 : i42
// CHECK: %11 = comb.icmp eq %9, %10 : i42
// CHECK: verif.assert %11 : i1
// CHECK: }

hw.module @CarrySaveCompress3to2(
in %a0: i42, in %a1: i42, in %a2: i42,
out z0: i42, out z1: i42
) {
%c1_i42 = hw.constant 1 : i42
%0 = comb.xor %a0, %a1, %a2 : i42 // sum bits of FA (a0^a1^a2)
%1 = comb.and %a0, %a1 : i42
%2 = comb.or %a0, %a1 : i42
%3 = comb.and %2, %a2 : i42
%4 = comb.or %1, %3 : i42 // carry bits of FA (a0&a1 | a2&(a0|a1))
%5 = comb.shl %4, %c1_i42 : i42 // %5 = carry << 1
%z0, %z1 = verif.contract %0, %5 : i42, i42 {
%inputSum = comb.add %a0, %a1, %a2 : i42
%outputSum = comb.add %z0, %z1 : i42
%6 = comb.icmp eq %inputSum, %outputSum : i42
verif.ensure %6 : i1
}
hw.output %z0, %z1 : i42, i42
}

// CHECK: verif.formal @ShiftLeft {
// CHECK: %c1_i8 = hw.constant 1 : i8
// CHECK: %c2_i8 = hw.constant 2 : i8
// CHECK: %c4_i8 = hw.constant 4 : i8
// CHECK: %c8_i8 = hw.constant 8 : i8
// CHECK: %0 = verif.symbolic_value : i8
// CHECK: %1 = verif.symbolic_value : i8
// CHECK: %2 = comb.extract %1 from 2 : (i8) -> i1
// CHECK: %3 = comb.extract %1 from 1 : (i8) -> i1
// CHECK: %4 = comb.extract %1 from 0 : (i8) -> i1
// CHECK: %5 = comb.shl %0, %c4_i8 : i8
// CHECK: %6 = comb.mux %2, %5, %0 : i8
// CHECK: %7 = comb.shl %6, %c2_i8 : i8
// CHECK: %8 = comb.mux %3, %7, %6 : i8
// CHECK: %9 = comb.shl %8, %c1_i8 : i8
// CHECK: %10 = comb.mux %4, %9, %8 : i8
// CHECK: %11 = comb.icmp ult %1, %c8_i8 : i8
// CHECK: %12 = comb.shl %0, %1 : i8
// CHECK: %13 = comb.icmp eq %10, %12 : i8
// CHECK: verif.assert %13 : i1
// CHECK: verif.assume %11 : i1
// CHECK: }

hw.module @ShiftLeft(in %a: i8, in %b: i8, out z: i8) {
%c4_i8 = hw.constant 4 : i8
%c2_i8 = hw.constant 2 : i8
%c1_i8 = hw.constant 1 : i8
%b2 = comb.extract %b from 2 : (i8) -> i1
%b1 = comb.extract %b from 1 : (i8) -> i1
%b0 = comb.extract %b from 0 : (i8) -> i1
%0 = comb.shl %a, %c4_i8 : i8
%1 = comb.mux %b2, %0, %a : i8
%2 = comb.shl %1, %c2_i8 : i8
%3 = comb.mux %b1, %2, %1 : i8
%4 = comb.shl %3, %c1_i8 : i8
%5 = comb.mux %b0, %4, %3 : i8

// Contract to check that the multiplexers and constant shifts above indeed
// produce the correct shift by 0 to 7 places, assuming the shift amount is
// less than 8 (we can't shift a number out).
%z = verif.contract %5 : i8 {
// Shift amount must be less than 8.
%c8_i8 = hw.constant 8 : i8
%blt8 = comb.icmp ult %b, %c8_i8 : i8
verif.require %blt8 : i1

// In that case the mux tree computes the correct left-shift.
%ashl = comb.shl %a, %b : i8
%eq = comb.icmp eq %z, %ashl : i8
verif.ensure %eq : i1
}

hw.output %z : i8
}
Loading