diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index 618c8db375de..5118d5c25d28 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -125,6 +125,20 @@ using Scheduleable = class IfLoweringStateInterface { public: + void setCondReg(scf::IfOp op, calyx::RegisterOp regOp) { + Operation *operation = op.getOperation(); + assert(condReg.count(operation) == 0 && + "A condition register was already set for this scf::IfOp!\n"); + condReg[operation] = regOp; + } + + calyx::RegisterOp getCondReg(scf::IfOp op) { + auto it = condReg.find(op.getOperation()); + if (it != condReg.end()) + return it->second; + return nullptr; + } + void setThenGroup(scf::IfOp op, calyx::GroupOp group) { Operation *operation = op.getOperation(); assert(thenGroup.count(operation) == 0 && @@ -172,6 +186,7 @@ class IfLoweringStateInterface { } private: + DenseMap condReg; DenseMap thenGroup; DenseMap elseGroup; DenseMap> resultRegs; @@ -240,6 +255,28 @@ class ForLoopLoweringStateInterface } }; +class PipeOpLoweringStateInterface { +public: + void setPipeResReg(Operation *op, calyx::RegisterOp reg) { + assert(isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op)); + assert(resultRegs.count(op) == 0 && + "A register was already set for this pipe operation!\n"); + resultRegs[op] = reg; + } + // Get the register for a specific pipe operation + calyx::RegisterOp getPipeResReg(Operation *op) { + auto it = resultRegs.find(op); + assert(it != resultRegs.end() && + "No register was set for this pipe operation!\n"); + return it->second; + } + +private: + DenseMap resultRegs; +}; + /// Handles the current state of lowering of a Calyx component. It is mainly /// used as a key/value store for recording information during partial lowering, /// which is required at later lowering passes. @@ -247,6 +284,7 @@ class ComponentLoweringState : public calyx::ComponentLoweringStateInterface, public WhileLoopLoweringStateInterface, public ForLoopLoweringStateInterface, public IfLoweringStateInterface, + public PipeOpLoweringStateInterface, public calyx::SchedulerInterface { public: ComponentLoweringState(calyx::ComponentOp component) @@ -339,7 +377,12 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { /// source operation TSrcOp. template LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, - TypeRange srcTypes, TypeRange dstTypes) const { + TypeRange srcTypes, TypeRange dstTypes, + calyx::RegisterOp srcReg = nullptr, + calyx::RegisterOp dstReg = nullptr) const { + assert((srcReg && dstReg) || (!srcReg && !dstReg)); + bool isSequential = srcReg && dstReg; + SmallVector types; llvm::append_range(types, srcTypes); llvm::append_range(types, dstTypes); @@ -365,26 +408,54 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { /// Create assignments to the inputs of the library op. auto group = createGroupForOp(rewriter, op); + + if (isSequential) { + auto groupOp = cast(group); + getState().addBlockScheduleable(op->getBlock(), + groupOp); + } + rewriter.setInsertionPointToEnd(group.getBodyBlock()); - for (auto dstOp : enumerate(opInputPorts)) - rewriter.create(op.getLoc(), dstOp.value(), - op->getOperand(dstOp.index())); + + for (auto dstOp : enumerate(opInputPorts)) { + if (isSequential) + rewriter.create(op.getLoc(), dstOp.value(), + srcReg.getOut()); + else + rewriter.create(op.getLoc(), dstOp.value(), + op->getOperand(dstOp.index())); + } /// Replace the result values of the source operator with the new operator. for (auto res : enumerate(opOutputPorts)) { getState().registerEvaluatingGroup(res.value(), group); - op->getResult(res.index()).replaceAllUsesWith(res.value()); + if (isSequential) + op->getResult(res.index()).replaceAllUsesWith(dstReg.getOut()); + else + op->getResult(res.index()).replaceAllUsesWith(res.value()); + } + + if (isSequential) { + auto groupOp = cast(group); + buildAssignmentsForRegisterWrite( + rewriter, groupOp, + getState().getComponentOp(), dstReg, + calyxOp.getOut()); } + return success(); } /// buildLibraryOp which provides in- and output types based on the operands /// and results of the op argument. template - LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op) const { + LogicalResult buildLibraryOp(PatternRewriter &rewriter, TSrcOp op, + calyx::RegisterOp srcReg = nullptr, + calyx::RegisterOp dstReg = nullptr) const { return buildLibraryOp( - rewriter, op, op.getOperandTypes(), op->getResultTypes()); + rewriter, op, op.getOperandTypes(), op->getResultTypes(), srcReg, + dstReg); } /// Creates a group named by the basic block which the input op resides in. @@ -411,6 +482,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { auto reg = createRegister( op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(), getState().getUniqueName(opName)); + // Operation pipelines are not combinational, so a GroupOp is required. auto group = createGroupForOp(rewriter, op); OpBuilder builder(group->getRegion(0)); @@ -441,6 +513,8 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { getState().registerEvaluatingGroup( opPipe.getRight(), group); + getState().setPipeResReg(out.getDefiningOp(), reg); + return success(); } @@ -939,9 +1013,43 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter, CmpIOp op) const { + auto isPipeLibOp = [](Value val) -> bool { + if (Operation *defOp = val.getDefiningOp()) { + return isa(defOp); + } + return false; + }; + switch (op.getPredicate()) { - case CmpIPredicate::eq: + case CmpIPredicate::eq: { + StringRef opName = op.getOperationName().split(".").second; + Type width = op.getResult().getType(); + auto condReg = createRegister( + op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(), + getState().getUniqueName(opName)); + + for (auto *user : op->getUsers()) { + if (auto ifOp = dyn_cast(user)) + getState().setCondReg(ifOp, condReg); + } + + bool isSequential = isPipeLibOp(op.getLhs()) || isPipeLibOp(op.getRhs()); + if (isSequential) { + calyx::RegisterOp pipeResReg; + if (isPipeLibOp(op.getLhs())) + pipeResReg = getState().getPipeResReg( + op.getLhs().getDefiningOp()); + else + pipeResReg = getState().getPipeResReg( + op.getRhs().getDefiningOp()); + + return buildLibraryOp( + rewriter, op, pipeResReg, condReg); + } return buildLibraryOp(rewriter, op); + } case CmpIPredicate::ne: return buildLibraryOp(rewriter, op); case CmpIPredicate::uge: @@ -1535,11 +1643,16 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { Location loc = ifOp->getLoc(); auto cond = ifOp.getCondition(); - auto condGroup = getState() - .getEvaluatingGroup(cond); - auto symbolAttr = FlatSymbolRefAttr::get( - StringAttr::get(getContext(), condGroup.getSymName())); + FlatSymbolRefAttr symbolAttr = nullptr; + auto condReg = getState().getCondReg(ifOp); + if (!condReg) { + auto condGroup = getState() + .getEvaluatingGroup(cond); + + symbolAttr = FlatSymbolRefAttr::get( + StringAttr::get(getContext(), condGroup.getSymName())); + } bool initElse = !ifOp.getElseRegion().empty(); auto ifCtrlOp = rewriter.create(