Skip to content

Commit

Permalink
support if op when its condition check is not combinational
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed Oct 9, 2024
1 parent 9f77622 commit b42c4a0
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 92 deletions.
137 changes: 125 additions & 12 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ using Scheduleable =

class IfLoweringStateInterface {
public:
void setCondReg(scf::IfOp op, calyx::RegisterOp regOp) {
Operation *operation = op.getOperation();
assert(condReg.count(operation) == 0 &&
"A condition register was already set for this scf::IfOp!\n");
condReg[operation] = regOp;
}

calyx::RegisterOp getCondReg(scf::IfOp op) {
auto it = condReg.find(op.getOperation());
if (it != condReg.end())
return it->second;
return nullptr;
}

void setThenGroup(scf::IfOp op, calyx::GroupOp group) {
Operation *operation = op.getOperation();
assert(thenGroup.count(operation) == 0 &&
Expand Down Expand Up @@ -172,6 +186,7 @@ class IfLoweringStateInterface {
}

private:
DenseMap<Operation *, calyx::RegisterOp> condReg;
DenseMap<Operation *, calyx::GroupOp> thenGroup;
DenseMap<Operation *, calyx::GroupOp> elseGroup;
DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> resultRegs;
Expand Down Expand Up @@ -240,13 +255,36 @@ class ForLoopLoweringStateInterface
}
};

class PipeOpLoweringStateInterface {
public:
void setPipeResReg(Operation *op, calyx::RegisterOp reg) {
assert(isa<calyx::MultPipeLibOp>(op) || isa<calyx::DivUPipeLibOp>(op) ||
isa<calyx::DivSPipeLibOp>(op) || isa<calyx::RemUPipeLibOp>(op) ||
isa<calyx::RemSPipeLibOp>(op));
assert(resultRegs.count(op) == 0 &&
"A register was already set for this pipe operation!\n");
resultRegs[op] = reg;
}
// Get the register for a specific pipe operation
calyx::RegisterOp getPipeResReg(Operation *op) {
auto it = resultRegs.find(op);
assert(it != resultRegs.end() &&
"No register was set for this pipe operation!\n");
return it->second;
}

private:
DenseMap<Operation *, calyx::RegisterOp> resultRegs;
};

/// Handles the current state of lowering of a Calyx component. It is mainly
/// used as a key/value store for recording information during partial lowering,
/// which is required at later lowering passes.
class ComponentLoweringState : public calyx::ComponentLoweringStateInterface,
public WhileLoopLoweringStateInterface,
public ForLoopLoweringStateInterface,
public IfLoweringStateInterface,
public PipeOpLoweringStateInterface,
public calyx::SchedulerInterface<Scheduleable> {
public:
ComponentLoweringState(calyx::ComponentOp component)
Expand Down Expand Up @@ -339,7 +377,12 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
/// source operation TSrcOp.
template <typename TGroupOp, typename TCalyxLibOp, typename TSrcOp>
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op,
TypeRange srcTypes, TypeRange dstTypes) const {
TypeRange srcTypes, TypeRange dstTypes,
calyx::RegisterOp srcReg = nullptr,
calyx::RegisterOp dstReg = nullptr) const {
assert((srcReg && dstReg) || (!srcReg && !dstReg));
bool isSequential = srcReg && dstReg;

SmallVector<Type> types;
llvm::append_range(types, srcTypes);
llvm::append_range(types, dstTypes);
Expand All @@ -365,26 +408,54 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {

/// Create assignments to the inputs of the library op.
auto group = createGroupForOp<TGroupOp>(rewriter, op);

if (isSequential) {
auto groupOp = cast<calyx::GroupOp>(group);
getState<ComponentLoweringState>().addBlockScheduleable(op->getBlock(),
groupOp);
}

rewriter.setInsertionPointToEnd(group.getBodyBlock());
for (auto dstOp : enumerate(opInputPorts))
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));

for (auto dstOp : enumerate(opInputPorts)) {
if (isSequential)
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
srcReg.getOut());
else
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));
}

/// Replace the result values of the source operator with the new operator.
for (auto res : enumerate(opOutputPorts)) {
getState<ComponentLoweringState>().registerEvaluatingGroup(res.value(),
group);
op->getResult(res.index()).replaceAllUsesWith(res.value());
if (isSequential)
op->getResult(res.index()).replaceAllUsesWith(dstReg.getOut());
else
op->getResult(res.index()).replaceAllUsesWith(res.value());
}

if (isSequential) {
auto groupOp = cast<calyx::GroupOp>(group);
buildAssignmentsForRegisterWrite(
rewriter, groupOp,
getState<ComponentLoweringState>().getComponentOp(), dstReg,
calyxOp.getOut());
}

return success();
}

/// buildLibraryOp which provides in- and output types based on the operands
/// and results of the op argument.
template <typename TGroupOp, typename TCalyxLibOp, typename TSrcOp>
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const {
LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op,
calyx::RegisterOp srcReg = nullptr,
calyx::RegisterOp dstReg = nullptr) const {
return buildLibraryOp<TGroupOp, TCalyxLibOp, TSrcOp>(
rewriter, op, op.getOperandTypes(), op->getResultTypes());
rewriter, op, op.getOperandTypes(), op->getResultTypes(), srcReg,
dstReg);
}

/// Creates a group named by the basic block which the input op resides in.
Expand All @@ -411,6 +482,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
auto reg = createRegister(
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
getState<ComponentLoweringState>().getUniqueName(opName));

// Operation pipelines are not combinational, so a GroupOp is required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
OpBuilder builder(group->getRegion(0));
Expand Down Expand Up @@ -441,6 +513,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
getState<ComponentLoweringState>().registerEvaluatingGroup(
opPipe.getRight(), group);

getState<ComponentLoweringState>().setPipeResReg(out.getDefiningOp(), reg);

return success();
}

Expand Down Expand Up @@ -939,9 +1013,43 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CmpIOp op) const {
auto isPipeLibOp = [](Value val) -> bool {
if (Operation *defOp = val.getDefiningOp()) {
return isa<calyx::MultPipeLibOp, calyx::DivUPipeLibOp,
calyx::DivSPipeLibOp, calyx::RemUPipeLibOp,
calyx::RemSPipeLibOp>(defOp);
}
return false;
};

switch (op.getPredicate()) {
case CmpIPredicate::eq:
case CmpIPredicate::eq: {
StringRef opName = op.getOperationName().split(".").second;
Type width = op.getResult().getType();
auto condReg = createRegister(
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
getState<ComponentLoweringState>().getUniqueName(opName));

for (auto *user : op->getUsers()) {
if (auto ifOp = dyn_cast<scf::IfOp>(user))
getState<ComponentLoweringState>().setCondReg(ifOp, condReg);
}

bool isSequential = isPipeLibOp(op.getLhs()) || isPipeLibOp(op.getRhs());
if (isSequential) {
calyx::RegisterOp pipeResReg;
if (isPipeLibOp(op.getLhs()))
pipeResReg = getState<ComponentLoweringState>().getPipeResReg(
op.getLhs().getDefiningOp());
else
pipeResReg = getState<ComponentLoweringState>().getPipeResReg(
op.getRhs().getDefiningOp());

return buildLibraryOp<calyx::GroupOp, calyx::EqLibOp>(
rewriter, op, pipeResReg, condReg);
}
return buildLibraryOp<calyx::CombGroupOp, calyx::EqLibOp>(rewriter, op);
}
case CmpIPredicate::ne:
return buildLibraryOp<calyx::CombGroupOp, calyx::NeqLibOp>(rewriter, op);
case CmpIPredicate::uge:
Expand Down Expand Up @@ -1535,11 +1643,16 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
Location loc = ifOp->getLoc();

auto cond = ifOp.getCondition();
auto condGroup = getState<ComponentLoweringState>()
.getEvaluatingGroup<calyx::CombGroupOp>(cond);

auto symbolAttr = FlatSymbolRefAttr::get(
StringAttr::get(getContext(), condGroup.getSymName()));
FlatSymbolRefAttr symbolAttr = nullptr;
auto condReg = getState<ComponentLoweringState>().getCondReg(ifOp);
if (!condReg) {
auto condGroup = getState<ComponentLoweringState>()
.getEvaluatingGroup<calyx::CombGroupOp>(cond);

symbolAttr = FlatSymbolRefAttr::get(
StringAttr::get(getContext(), condGroup.getSymName()));
}

bool initElse = !ifOp.getElseRegion().empty();
auto ifCtrlOp = rewriter.create<calyx::IfOp>(
Expand Down
Loading

0 comments on commit b42c4a0

Please sign in to comment.