[Calyx] Lower SCF parallel op to Calyx (#7830)
jiahanxie353 authored Nov 18, 2024
1 parent de8ce29 commit 026976a
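This change lowers scf.parallel to Calyx by statically partially evaluating the loop nest: every combination of induction-variable values gets its own block, with the induction variables replaced by constant indices, and the control schedule then runs those blocks under a calyx.par, one calyx.seq per block. Parallel ops that produce results (reductions) are rejected with an error. A minimal sketch of the control-level result, based on the test added in this commit (induction variables renamed for brevity):

    scf.parallel (%i, %j) = (%c0, %c0) to (%c3, %c2) step (%c2, %c1) {
      // loads/stores using %i and %j
      scf.reduce
    }

    // becomes, after partial evaluation over (0, 0), (0, 1), (2, 0), (2, 1):
    calyx.control {
      calyx.seq {
        calyx.par {
          calyx.seq { calyx.enable @bb0_0  calyx.enable @bb0_1 }  // (%i, %j) = (0, 0)
          calyx.seq { calyx.enable @bb1_0  calyx.enable @bb1_1 }  // (0, 1)
          calyx.seq { calyx.enable @bb2_0  calyx.enable @bb2_1 }  // (2, 0)
          calyx.seq { calyx.enable @bb3_0  calyx.enable @bb3_1 }  // (2, 1)
        }
      }
    }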
Showing 3 changed files with 294 additions and 11 deletions.
164 changes: 153 additions & 11 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
@@ -119,10 +119,15 @@ struct CallScheduleable {
func::CallOp callOp;
};

struct ParScheduleable {
/// Parallel operation to schedule.
scf::ParallelOp parOp;
};

/// A variant of types representing scheduleable operations.
using Scheduleable =
std::variant<calyx::GroupOp, WhileScheduleable, ForScheduleable,
IfScheduleable, CallScheduleable>;
IfScheduleable, CallScheduleable, ParScheduleable>;

class IfLoweringStateInterface {
public:
@@ -275,6 +280,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
.template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
/// SCF
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
scf::ParallelOp, scf::ReduceOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp,
@@ -338,6 +344,10 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::IfOp ifOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
scf::ReduceOp reduceOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
scf::ParallelOp parallelOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
@@ -1093,6 +1103,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::ReduceOp reduceOp) const {
// We do not handle reduce operations and simply return success for now,
// since BuildParGroups would have already emitted an error and exited early
// if an unsupported reduce operation had been encountered.
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::ParallelOp parOp) const {
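// Register the parallel op as a scheduleable for its parent block;
// BuildControl later emits a calyx.par with one calyx.seq per partially
// evaluated block.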
getState<ComponentLoweringState>().addBlockScheduleable(
parOp.getOperation()->getBlock(), ParScheduleable{parOp});
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CallOp callOp) const {
std::string instanceName = calyx::getInstanceName(callOp);
@@ -1481,6 +1506,106 @@ class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern {
}
};

class BuildParGroups : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &rewriter) const override {
WalkResult walkResult = funcOp.walk([&](scf::ParallelOp scfParOp) {
if (!scfParOp.getResults().empty()) {
scfParOp.emitError(
"Reduce operations in scf.parallel is not supported yet");
return WalkResult::interrupt();
}

if (failed(partialEval(rewriter, scfParOp)))
return WalkResult::interrupt();

return WalkResult::advance();
});

return walkResult.wasInterrupted() ? failure() : success();
}

private:
// Partially evaluate/pre-compute all blocks being executed in parallel by
// statically generating all loop index combinations.
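// For example, with lower bounds (0, 0), upper bounds (3, 2), and steps
// (2, 1) -- the configuration in the test added in this commit -- the
// generated index combinations are (0, 0), (0, 1), (2, 0), and (2, 1),
// and one block is emitted per combination.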
LogicalResult partialEval(PatternRewriter &rewriter,
scf::ParallelOp scfParOp) const {
assert(scfParOp.getLoopSteps() && "Parallel loop must have steps");
auto *body = scfParOp.getBody();
auto parOpIVs = scfParOp.getInductionVars();
auto steps = scfParOp.getStep();
auto lowerBounds = scfParOp.getLowerBound();
auto upperBounds = scfParOp.getUpperBound();
rewriter.setInsertionPointAfter(scfParOp);
scf::ParallelOp newParOp = scfParOp.cloneWithoutRegions();
auto loc = newParOp.getLoc();
rewriter.insert(newParOp);
OpBuilder insideBuilder(newParOp);
Block *currBlock = nullptr;
auto &region = newParOp.getRegion();
IRMapping operandMap;

// Extract lower bounds, upper bounds, and steps as integer index values.
SmallVector<int64_t> lbVals, ubVals, stepVals;
for (auto lb : lowerBounds) {
auto lbOp = lb.getDefiningOp<arith::ConstantIndexOp>();
assert(lbOp &&
"Lower bound must be a statically computable constant index");
lbVals.push_back(lbOp.value());
}
for (auto ub : upperBounds) {
auto ubOp = ub.getDefiningOp<arith::ConstantIndexOp>();
assert(ubOp &&
"Upper bound must be a statically computable constant index");
ubVals.push_back(ubOp.value());
}
for (auto step : steps) {
auto stepOp = step.getDefiningOp<arith::ConstantIndexOp>();
assert(stepOp && "Step must be a statically computable constant index");
stepVals.push_back(stepOp.value());
}

// Initialize indices with lower bounds
SmallVector<int64_t> indices = lbVals;

while (true) {
// Create a new block in the region for the current combination of indices
currBlock = &region.emplaceBlock();
insideBuilder.setInsertionPointToEnd(currBlock);

// Map induction variables to constant indices
for (unsigned i = 0; i < indices.size(); ++i) {
Value ivConstant =
insideBuilder.create<arith::ConstantIndexOp>(loc, indices[i]);
operandMap.map(parOpIVs[i], ivConstant);
}

for (auto it = body->begin(); it != std::prev(body->end()); ++it)
insideBuilder.clone(*it, operandMap);

// Increment indices using `step`
bool done = false;
for (int dim = indices.size() - 1; dim >= 0; --dim) {
indices[dim] += stepVals[dim];
if (indices[dim] < ubVals[dim])
break;
indices[dim] = lbVals[dim];
if (dim == 0)
// All combinations have been generated
done = true;
}
if (done)
break;
}

rewriter.replaceOp(scfParOp, newParOp);
return success();
}
};

/// Builds a control schedule by traversing the CFG of the function and
/// associating this with the previously created groups.
/// For simplicity, the generated control flow is expanded for all possible
@@ -1512,7 +1637,8 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
getState<ComponentLoweringState>().getBlockScheduleables(block);
auto loc = block->front().getLoc();

if (compBlockScheduleables.size() > 1) {
if (compBlockScheduleables.size() > 1 &&
!isa<scf::ParallelOp>(block->getParentOp())) {
auto seqOp = rewriter.create<calyx::SeqOp>(loc);
parentCtrlBlock = seqOp.getBodyBlock();
}
@@ -1537,18 +1663,30 @@

/// Only schedule the 'after' block. The 'before' block is
/// implicitly scheduled when evaluating the while condition.
LogicalResult res = buildCFGControl(path, rewriter, whileBodyOpBlock,
block, whileOp.getBodyBlock());
if (LogicalResult result =
buildCFGControl(path, rewriter, whileBodyOpBlock, block,
whileOp.getBodyBlock());
result.failed())
return result;

// Insert loop-latch at the end of the while group
rewriter.setInsertionPointToEnd(whileBodyOpBlock);
calyx::GroupOp whileLatchGroup =
getState<ComponentLoweringState>().getWhileLoopLatchGroup(whileOp);
rewriter.create<calyx::EnableOp>(whileLatchGroup.getLoc(),
whileLatchGroup.getName());

if (res.failed())
return res;
} else if (auto *parSchedPtr = std::get_if<ParScheduleable>(&group)) {
auto parOp = parSchedPtr->parOp;
auto calyxParOp = rewriter.create<calyx::ParOp>(parOp.getLoc());
for (auto &innerBlock : parOp.getRegion().getBlocks()) {
rewriter.setInsertionPointToEnd(calyxParOp.getBodyBlock());
auto seqOp = rewriter.create<calyx::SeqOp>(parOp.getLoc());
rewriter.setInsertionPointToEnd(seqOp.getBodyBlock());
if (LogicalResult res = scheduleBasicBlock(
rewriter, path, seqOp.getBodyBlock(), &innerBlock);
res.failed())
return res;
}
} else if (auto *forSchedPtr = std::get_if<ForScheduleable>(&group);
forSchedPtr) {
auto forOp = forSchedPtr->forOp;
Expand All @@ -1563,17 +1701,17 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
auto *forBodyOpBlock = forBodyOp.getBodyBlock();

// Schedule the body of the for loop.
LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
block, forOp.getBodyBlock());
if (LogicalResult res = buildCFGControl(path, rewriter, forBodyOpBlock,
block, forOp.getBodyBlock());
res.failed())
return res;

// Insert loop-latch at the end of the while group.
rewriter.setInsertionPointToEnd(forBodyOpBlock);
calyx::GroupOp forLatchGroup =
getState<ComponentLoweringState>().getForLoopLatchGroup(forOp);
rewriter.create<calyx::EnableOp>(forLatchGroup.getLoc(),
forLatchGroup.getName());
if (res.failed())
return res;
} else if (auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
ifSchedPtr) {
auto ifOp = ifSchedPtr->ifOp;
@@ -2241,6 +2379,9 @@ void SCFToCalyxPass::runOnOperation() {
/// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);

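/// This pattern partially evaluates each scf.parallel op, creating one block
/// per combination of loop indices so the blocks can later be scheduled in
/// parallel.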
addOncePattern<BuildParGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts all index typed values to an i32 integer.
addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
funcMap, *loweringState);
@@ -2270,6 +2411,7 @@ void SCFToCalyxPass::runOnOperation() {

addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts operations within basic blocks to Calyx library
/// operators. Combinational operations are assigned inside a
/// calyx::CombGroupOp, and sequential inside calyx::GroupOps.
119 changes: 119 additions & 0 deletions test/Conversion/SCFToCalyx/convert_simple.mlir
@@ -257,3 +257,122 @@ module {
return %1 : f32
}
}

// -----

// Test parallel op lowering
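// The scf.parallel below iterates over (%arg2, %arg3) in {0, 2} x {0, 1}; each
// of the four iterations is partially evaluated into its own block, and every
// block is scheduled as a calyx.seq inside a single calyx.par.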

// CHECK: calyx.wires {
// CHECK-DAG: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_slice_7.in = %c0_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_7.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_0_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_0_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_0_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb0_1 {
// CHECK-DAG: calyx.assign %std_slice_6.in = %c0_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_6.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_0_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb1_0 {
// CHECK-DAG: calyx.assign %std_slice_5.in = %c4_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_5.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_1_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_1_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_1_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb1_1 {
// CHECK-DAG: calyx.assign %std_slice_4.in = %c1_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_4.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_1_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb2_0 {
// CHECK-DAG: calyx.assign %std_slice_3.in = %c2_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_3.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_2_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_2_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_2_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb2_1 {
// CHECK-DAG: calyx.assign %std_slice_2.in = %c4_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_2.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_2_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb3_0 {
// CHECK-DAG: calyx.assign %std_slice_1.in = %c6_i32 : i32
// CHECK-DAG: calyx.assign %mem_1.addr0 = %std_slice_1.out : i3
// CHECK-DAG: calyx.assign %mem_1.content_en = %true : i1
// CHECK-DAG: calyx.assign %mem_1.write_en = %false : i1
// CHECK-DAG: calyx.assign %load_3_reg.in = %mem_1.read_data : i32
// CHECK-DAG: calyx.assign %load_3_reg.write_en = %mem_1.done : i1
// CHECK-DAG: calyx.group_done %load_3_reg.done : i1
// CHECK-DAG: }
// CHECK-DAG: calyx.group @bb3_1 {
// CHECK-DAG: calyx.assign %std_slice_0.in = %c5_i32 : i32
// CHECK-DAG: calyx.assign %mem_0.addr0 = %std_slice_0.out : i3
// CHECK-DAG: calyx.assign %mem_0.write_data = %load_3_reg.out : i32
// CHECK-DAG: calyx.assign %mem_0.write_en = %true : i1
// CHECK-DAG: calyx.assign %mem_0.content_en = %true : i1
// CHECK-DAG: calyx.group_done %mem_0.done : i1
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: calyx.control {
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.par {
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb0_0
// CHECK-DAG: calyx.enable @bb0_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb1_0
// CHECK-DAG: calyx.enable @bb1_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb2_0
// CHECK-DAG: calyx.enable @bb2_1
// CHECK-DAG: }
// CHECK-DAG: calyx.seq {
// CHECK-DAG: calyx.enable @bb3_0
// CHECK-DAG: calyx.enable @bb3_1
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: }
// CHECK-DAG: }

module {
func.func @main() {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<6xi32>
%alloc_1 = memref.alloc() : memref<6xi32>
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c3, %c2) step (%c2, %c1) {
%4 = arith.shli %arg3, %c2 : index
%5 = arith.addi %4, %arg2 : index
%6 = memref.load %alloc_1[%5] : memref<6xi32>
%7 = arith.shli %arg2, %c1 : index
%8 = arith.addi %7, %arg3 : index
memref.store %6, %alloc[%8] : memref<6xi32>
scf.reduce
}
return
}
}

22 changes: 22 additions & 0 deletions test/Conversion/SCFToCalyx/errors.mlir
@@ -54,3 +54,25 @@ module {
}
}

// -----
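
// An scf.parallel op that produces results (i.e. one carrying a reducing
// scf.reduce) is rejected by BuildParGroups.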

module {
func.func @main() -> i32 {
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%cinit = arith.constant 0 : i32
%alloc = memref.alloc() : memref<6xi32>
// expected-error @+1 {{Reduce operations in scf.parallel is not supported yet}}
%r:1 = scf.parallel (%arg2) = (%c0) to (%c3) step (%c1) init (%cinit) -> i32 {
%6 = memref.load %alloc[%arg2] : memref<6xi32>
scf.reduce(%6 : i32) {
^bb0(%lhs : i32, %rhs: i32):
%res = arith.addi %lhs, %rhs : i32
scf.reduce.return %res : i32
}
}
return %r : i32
}
}
