Skip to content

Commit

Permalink
[FIRRTL] FoldRegMems: insert new ops into same block as memory (#7909)
Browse files Browse the repository at this point in the history
Before this PR, FoldRegMems would construct new ops in the "body of the parent
FModuleOp". We need to place these ops in the same block as the memory.

This PR fixes a bug where, when a memory under a layerblock was canonicalized
to a register, the register would be placed at the original location of the
memory (under the layerblock), but its readers would be placed outside the
layerblock, resulting in a dominance checking error.
  • Loading branch information
rwy7 authored Nov 27, 2024
1 parent 8283dcb commit 4ca5ab5
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 56 deletions.
40 changes: 23 additions & 17 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2868,7 +2868,9 @@ struct FoldRegMems : public mlir::RewritePattern {
if (hasDontTouch(mem) || info.depth != 1)
return failure();

auto memModule = mem->getParentOfType<FModuleOp>();
auto ty = mem.getDataType();
auto loc = mem.getLoc();
auto *block = mem->getBlock();

// Find the clock of the register-to-be, all write ports should share it.
Value clock;
Expand Down Expand Up @@ -2922,14 +2924,24 @@ struct FoldRegMems : public mlir::RewritePattern {
return failure();
clock = portClock;
}

// Create a new register to store the data.
auto ty = mem.getDataType();
rewriter.setInsertionPointAfterValue(clock);
auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
.getResult();
// Create a new wire where the memory used to be. This wire will dominate
// all readers of the memory. Reads should be made through this wire.
rewriter.setInsertionPointAfter(mem);
auto memWire = rewriter.create<WireOp>(loc, ty).getResult();

// The memory is replaced by a register, which we place at the end of the
// block, so that any value driven to the original memory will dominate the
// new register (including the clock). All other ops will be placed
// after the register.
rewriter.setInsertionPointToEnd(block);
auto memReg =
rewriter.create<RegOp>(loc, ty, clock, mem.getName()).getResult();

// Connect the output of the register to the wire.
rewriter.create<MatchingConnectOp>(loc, memWire, memReg);

// Helper to insert a given number of pipeline stages through registers.
// The pipelines are placed at the end of the block.
auto pipeline = [&](Value value, Value clock, const Twine &name,
unsigned latency) {
for (unsigned i = 0; i < latency; ++i) {
Expand All @@ -2938,7 +2950,6 @@ struct FoldRegMems : public mlir::RewritePattern {
llvm::raw_string_ostream os(regName);
os << mem.getName() << "_" << name << "_" << i;
}

auto reg = rewriter
.create<RegOp>(mem.getLoc(), value.getType(), clock,
rewriter.getStringAttr(regName))
Expand All @@ -2962,7 +2973,6 @@ struct FoldRegMems : public mlir::RewritePattern {
auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
Value value = getPortFieldValue(port, field);
assert(value);
rewriter.setInsertionPointAfterValue(value);
return pipeline(value, portClock, name + "_" + field, stages);
};

Expand All @@ -2974,8 +2984,7 @@ struct FoldRegMems : public mlir::RewritePattern {
// address must be 0 for single-address memories and the enable signal
// is ignored, always reading out the register. Under these constraints,
// the read port can be replaced with the value from the register.
rewriter.setInsertionPointAfterValue(reg);
replacePortField(rewriter, port, "data", reg);
replacePortField(rewriter, port, "data", memWire);
break;
}
case MemOp::PortKind::Write: {
Expand All @@ -2987,16 +2996,14 @@ struct FoldRegMems : public mlir::RewritePattern {
}
case MemOp::PortKind::ReadWrite: {
// Always read the register into the read end.
rewriter.setInsertionPointAfterValue(reg);
replacePortField(rewriter, port, "rdata", reg);
replacePortField(rewriter, port, "rdata", memWire);

// Create a write enable and pipeline stages.
auto wdata = portPipeline("wdata", writeStages);
auto wmask = portPipeline("wmask", writeStages);

Value en = getPortFieldValue(port, "en");
Value wmode = getPortFieldValue(port, "wmode");
rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
auto wenPipelined =
Expand All @@ -3008,8 +3015,7 @@ struct FoldRegMems : public mlir::RewritePattern {
}

// Regardless of `writeUnderWrite`, always implement PortOrder.
rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
Value next = reg;
Value next = memReg;
for (auto &[data, en, mask] : writes) {
Value masked;

Expand All @@ -3035,7 +3041,7 @@ struct FoldRegMems : public mlir::RewritePattern {

next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
}
rewriter.create<MatchingConnectOp>(reg.getLoc(), reg, next);
rewriter.create<MatchingConnectOp>(memReg.getLoc(), memReg, next);

// Delete the fields and their associated connects.
for (Operation *conn : connects)
Expand Down
150 changes: 111 additions & 39 deletions test/Dialect/FIRRTL/simplify-mems.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,11 @@ firrtl.circuit "OneAddressMasked" {
} :
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<2>>
// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32>
// CHECK: [[MemoryWire:%.+]] = firrtl.wire : !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %result_read, [[MemoryWire]] : !firrtl.uint<32>
// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect [[MemoryWire]], %Memory : !firrtl.uint<32>

%read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
Expand Down Expand Up @@ -408,31 +410,6 @@ firrtl.circuit "OneAddressNoMask" {
in %in_rwen: !firrtl.uint<1>,
out %result_read: !firrtl.uint<32>,
out %result_rw: !firrtl.uint<32>) {

// Pipeline the inputs.
// TODO: It would be good to de-duplicate these either in the pass or in a canonicalizer.

// CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1>
// CHECK: %Memory_write_en_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_1, %Memory_write_en_0 : !firrtl.uint<1>
// CHECK: %Memory_write_en_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_2, %Memory_write_en_1 : !firrtl.uint<1>

// CHECK: %Memory_write_data_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_write_data_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_1, %Memory_write_data_0 : !firrtl.uint<32>
// CHECK: %Memory_write_data_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_2, %Memory_write_data_1 : !firrtl.uint<32>

// CHECK: %Memory_rw_wdata_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_1, %Memory_rw_wdata_0 : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_2, %Memory_rw_wdata_1 : !firrtl.uint<32>

%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>

%Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined
Expand All @@ -447,9 +424,56 @@ firrtl.circuit "OneAddressNoMask" {
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>

// A wire, holding the value of the memory, goes to the front of the block.
// CHECK: [[MemoryWire:%.+]] = firrtl.wire : !firrtl.uint<32>

// The original uses of the memory are replaced with uses of the wire.
// CHECK: firrtl.matchingconnect %result_read, [[MemoryWire]] : !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %result_rw, [[MemoryWire]] : !firrtl.uint<32>

// The memory is replaced by a register at the end of the block
// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32>
// The register's data is written to the MemoryWire
// CHECK: firrtl.matchingconnect [[MemoryWire]], %Memory : !firrtl.uint<32>

// Following the register, we pipeline the inputs.
// TODO: It would be good to de-duplicate these either in the pass or in a canonicalizer.

// CHECK: %Memory_rw_wdata_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_1, %Memory_rw_wdata_0 : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_2, %Memory_rw_wdata_1 : !firrtl.uint<32>

// CHECK: [[WRITING:%.+]] = firrtl.and %in_rwen, %wmode_rw : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %Memory_rw_wen_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_0, [[WRITING]] : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_1, %Memory_rw_wen_0 : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_2, %Memory_rw_wen_1 : !firrtl.uint<1>

// CHECK: %Memory_write_data_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_write_data_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_1, %Memory_write_data_0 : !firrtl.uint<32>
// CHECK: %Memory_write_data_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_2, %Memory_write_data_1 : !firrtl.uint<32>

// CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1>
// CHECK: %Memory_write_en_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_1, %Memory_write_en_0 : !firrtl.uint<1>
// CHECK: %Memory_write_en_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_2, %Memory_write_en_1 : !firrtl.uint<1>

// Finally, the pipelined inputs are driven to the register.
// CHECK: [[WRITE_RW:%.+]] = firrtl.mux(%Memory_rw_wen_2, %Memory_rw_wdata_2, %Memory) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: [[WRITE_W:%.+]] = firrtl.mux(%Memory_write_en_2, %Memory_write_data_2, [[WRITE_RW]]) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory, [[WRITE_W]] : !firrtl.uint<32>

%read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%read_en = firrtl.subfield %Memory_read[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
Expand All @@ -459,7 +483,6 @@ firrtl.circuit "OneAddressNoMask" {
%read_data = firrtl.subfield %Memory_read[data] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %result_read, %read_data : !firrtl.uint<32>, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_rw, %Memory : !firrtl.uint<32>
%rw_addr = firrtl.subfield %Memory_rw[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%rw_en = firrtl.subfield %Memory_rw[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
Expand All @@ -475,16 +498,7 @@ firrtl.circuit "OneAddressNoMask" {
%rw_wmask = firrtl.subfield %Memory_rw[wmask] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: [[WRITING:%.+]] = firrtl.and %in_rwen, %wmode_rw : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %Memory_rw_wen_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_0, [[WRITING]] : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_1, %Memory_rw_wen_0 : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_2, %Memory_rw_wen_1 : !firrtl.uint<1>
// CHECK: [[WRITE_RW:%.+]] = firrtl.mux(%Memory_rw_wen_2, %Memory_rw_wdata_2, %Memory) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: [[WRITE_W:%.+]] = firrtl.mux(%Memory_write_en_2, %Memory_write_data_2, [[WRITE_RW]]) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory, [[WRITE_W]] : !firrtl.uint<32>

%write_addr = firrtl.subfield %Memory_write[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%write_en = firrtl.subfield %Memory_write[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
Expand All @@ -497,3 +511,61 @@ firrtl.circuit "OneAddressNoMask" {
firrtl.connect %write_mask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
}
}

// -----

// This test ensures that the FoldRegMems canonicalization correctly
// folds memories under layerblocks.
firrtl.circuit "Rewrite1ElementMemoryToRegisterUnderLayerblock" {
firrtl.layer @A bind {}

firrtl.module public @Rewrite1ElementMemoryToRegisterUnderLayerblock(
in %clock: !firrtl.clock,
in %addr: !firrtl.uint<1>,
in %in_data: !firrtl.uint<32>,
in %wmode_rw: !firrtl.uint<1>,
in %in_wen: !firrtl.uint<1>,
in %in_rwen: !firrtl.uint<1>) {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>

// CHECK firrtl.layerblock @A
firrtl.layerblock @A {
// CHECK: %result_read = firrtl.wire : !firrtl.uint<32>
// CHECK: %result_rw = firrtl.wire : !firrtl.uint<32>
%result_read = firrtl.wire : !firrtl.uint<32>
%result_rw = firrtl.wire : !firrtl.uint<32>

// CHECK: [[MemoryWire:%.+]] = firrtl.wire : !firrtl.uint<32>
%Memory_rw = firrtl.mem Undefined
{
depth = 1 : i64,
name = "Memory",
portNames = ["rw"],
readLatency = 2 : i32,
writeLatency = 2 : i32
} : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>

%rw_addr = firrtl.subfield %Memory_rw[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%rw_en = firrtl.subfield %Memory_rw[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_en, %in_rwen : !firrtl.uint<1>, !firrtl.uint<1>
%rw_clk = firrtl.subfield %Memory_rw[clk] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_clk, %clock : !firrtl.clock, !firrtl.clock
%rw_rdata = firrtl.subfield %Memory_rw[rdata] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>

%rw_wmode = firrtl.subfield %Memory_rw[wmode] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmode, %wmode_rw : !firrtl.uint<1>, !firrtl.uint<1>
%rw_wdata = firrtl.subfield %Memory_rw[wdata] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wdata, %in_data : !firrtl.uint<32>, !firrtl.uint<32>
%rw_wmask = firrtl.subfield %Memory_rw[wmask] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: firrtl.matchingconnect %result_rw, [[MemoryWire]] : !firrtl.uint<32>
firrtl.connect %result_rw, %rw_rdata : !firrtl.uint<32>, !firrtl.uint<32>

// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect [[MemoryWire]], %Memory
// CHECK: firrtl.matchingconnect %Memory, {{%.+}} : !firrtl.uint<32>
}
}
}

0 comments on commit 4ca5ab5

Please sign in to comment.