diff --git a/include/scalehls/Dialect/HLS/Utils.h b/include/scalehls/Dialect/HLS/Utils.h index c53909ad..664c9ca2 100644 --- a/include/scalehls/Dialect/HLS/Utils.h +++ b/include/scalehls/Dialect/HLS/Utils.h @@ -424,6 +424,15 @@ using ReverseOpIteratorsMap = using OpIteratorsMap = DenseMap>; +//===----------------------------------------------------------------------===// +// Printing +//===----------------------------------------------------------------------===// + +/// Prints dimension and symbol list. +void printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, unsigned numDims, + OpAsmPrinter &printer); + } // namespace scalehls } // namespace mlir diff --git a/lib/Dialect/HLS/CMakeLists.txt b/lib/Dialect/HLS/CMakeLists.txt index a87aa668..efbaba12 100644 --- a/lib/Dialect/HLS/CMakeLists.txt +++ b/lib/Dialect/HLS/CMakeLists.txt @@ -8,4 +8,8 @@ add_mlir_dialect_library(MLIRHLS MLIRHLSEnumsIncGen MLIRHLSAttributesIncGen MLIRHLSInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRAffineAnalysis + MLIRAnalysis ) diff --git a/lib/Dialect/HLS/HLS.cpp b/lib/Dialect/HLS/HLS.cpp index 2a8f862e..6308f27a 100644 --- a/lib/Dialect/HLS/HLS.cpp +++ b/lib/Dialect/HLS/HLS.cpp @@ -53,12 +53,15 @@ struct SimplifyDispatchOrTaskOutputs : public OpRewritePattern { // Identify output values that are used. 
SmallVector usedOutputs; + SmallVector usedOutputTypes; SmallVector usedResults; for (auto result : op.getResults()) if (result.use_empty()) { hasUnusedPort = true; } else { - usedOutputs.push_back(yield.getOperand(result.getResultNumber())); + auto out = yield.getOperand(result.getResultNumber()); + usedOutputs.push_back(out); + usedOutputTypes.push_back(out.getType()); usedResults.push_back(result); } @@ -68,8 +71,8 @@ struct SimplifyDispatchOrTaskOutputs : public OpRewritePattern { rewriter.replaceOpWithNewOp(yield, usedOutputs); rewriter.setInsertionPoint(op); - auto newTask = - rewriter.create(op.getLoc(), ValueRange(usedOutputs)); + auto newTask = rewriter.create( + op.getLoc(), TypeRange(usedOutputTypes), ValueRange(usedOutputs)); rewriter.inlineRegionBefore(op.getBody(), newTask.getBody(), newTask.getBody().end()); for (auto t : llvm::zip(usedResults, newTask.getResults())) @@ -186,7 +189,7 @@ LogicalResult ToStreamOp::verify() { return success(); } -OpFoldResult ToStreamOp::fold(ArrayRef) { +OpFoldResult ToStreamOp::fold(FoldAdaptor adaptor) { if (auto toValue = getValue().getDefiningOp()) if (toValue.getStream().getType() == getType()) return toValue.getStream(); @@ -200,7 +203,7 @@ LogicalResult ToValueOp::verify() { return success(); } -OpFoldResult ToValueOp::fold(ArrayRef) { +OpFoldResult ToValueOp::fold(FoldAdaptor adaptor) { if (auto toStream = getStream().getDefiningOp()) if (toStream.getValue().getType() == getType()) return toStream.getValue(); @@ -831,7 +834,7 @@ LogicalResult BufferDevectorizeOp::verify() { getInputType()); } -OpFoldResult BufferVectorizeOp::fold(ArrayRef) { +OpFoldResult BufferVectorizeOp::fold(FoldAdaptor adaptor) { if (auto devectorize = getInput().getDefiningOp()) if (devectorize.getInputType() == getType()) return devectorize.getInput(); @@ -1017,10 +1020,16 @@ void AffineSelectOp::getCanonicalizationPatterns(RewritePatternSet &results, } /// Canonicalize an affine if op's conditional (integer set + operands). 
-OpFoldResult AffineSelectOp::fold(ArrayRef) { +OpFoldResult AffineSelectOp::fold(FoldAdaptor adaptor) { auto set = getIntegerSet(); SmallVector operands(getArgs()); - composeSetAndOperands(set, operands); + auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(), + set.getConstraints(), set.getContext()); + fullyComposeAffineMapAndOperands(&map, &operands); + // Rebuild the set from the composed map so that it stays consistent with the + // rewritten operand list before canonicalization. + set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(), + set.getEqFlags()); canonicalizeSetAndOperands(&set, &operands); return {}; } @@ -1206,7 +1215,7 @@ PartitionLayoutAttr::verify(function_ref emitError, /// given array shape. SmallVector PartitionLayoutAttr::getActualFactors(ArrayRef shape) { - SmallVector actualFactors; + SmallVector actualFactors; for (auto [size, kind, factor] : llvm::zip(shape, getKinds(), getFactors())) { if (kind == PartitionKind::BLOCK) actualFactors.push_back((size + factor - 1) / factor); diff --git a/lib/Dialect/HLS/Utils.cpp b/lib/Dialect/HLS/Utils.cpp index dccd58b4..a0498cc4 100644 --- a/lib/Dialect/HLS/Utils.cpp +++ b/lib/Dialect/HLS/Utils.cpp @@ -84,8 +84,9 @@ DispatchOp scalehls::dispatchBlock(Block *block) { OpBuilder builder(block, block->begin()); ValueRange returnValues(block->getTerminator()->getOperands()); + TypeRange returnTypes(block->getTerminator()->getOperandTypes()); auto loc = builder.getUnknownLoc(); - auto dispatch = builder.create(loc, returnValues); + auto dispatch = builder.create(loc, returnTypes, returnValues); auto &dispatchBlock = dispatch.getBody().emplaceBlock(); builder.setInsertionPointToEnd(&dispatchBlock); @@ -112,11 +113,16 @@ TaskOp scalehls::fuseOpsIntoTask(ArrayRef ops, // Collect output values. This is not sufficient and may lead to empty-used // outputs, which will be removed during canonicalization. 
llvm::SetVector outputValues; - for (auto op : ops) - for (auto result : op->getResults()) + llvm::SmallVector<Type> outputTypes; // Index-aligned with outputValues. + for (auto op : ops) { + for (auto result : op->getResults()) { if (llvm::any_of(result.getUsers(), - [&](Operation *user) { return !opsSet.count(user); })) + [&](Operation *user) { return !opsSet.count(user); })) { - outputValues.insert(result); + if (outputValues.insert(result)) + outputTypes.push_back(result.getType()); + } + } + } // Create new graph task with all inputs and outputs. auto loc = rewriter.getUnknownLoc(); @@ -124,8 +130,8 @@ TaskOp scalehls::fuseOpsIntoTask(ArrayRef ops, rewriter.setInsertionPoint(ops.front()); else rewriter.setInsertionPoint(ops.back()); - auto task = - rewriter.create(loc, ValueRange(outputValues.getArrayRef())); + auto task = rewriter.create(loc, TypeRange(outputTypes), + outputValues.getArrayRef()); auto taskBlock = rewriter.createBlock(&task.getBody()); // Move each targeted op into the new graph task. @@ -159,26 +165,36 @@ NodeOp scalehls::fuseNodeOps(ArrayRef nodes, // Collect inputs, outputs, and params of the new node. 
llvm::SetVector inputs; + llvm::SmallVector inputTypes; llvm::SmallVector inputTaps; llvm::SmallVector inputLocs; llvm::SetVector outputs; + llvm::SmallVector outputTypes; llvm::SmallVector outputLocs; llvm::SetVector params; + llvm::SmallVector paramTypes; llvm::SmallVector paramLocs; for (auto node : nodes) { - for (auto output : node.getOutputs()) - if (outputs.insert(output)) + for (auto output : node.getOutputs()) { + if (outputs.insert(output)) { + outputTypes.push_back(output.getType()); outputLocs.push_back(output.getLoc()); - for (auto param : node.getParams()) - if (params.insert(param)) + } + } + for (auto param : node.getParams()) { + if (params.insert(param)) { + paramTypes.push_back(param.getType()); paramLocs.push_back(param.getLoc()); + } + } } for (auto node : nodes) for (auto input : llvm::enumerate(node.getInputs())) { if (outputs.count(input.value())) continue; if (inputs.insert(input.value())) { + inputTypes.push_back(input.value().getType()); inputLocs.push_back(input.value().getLoc()); inputTaps.push_back(node.getInputTap(input.index())); } @@ -190,9 +206,9 @@ NodeOp scalehls::fuseNodeOps(ArrayRef nodes, rewriter.getUnknownLoc(), inputs.getArrayRef(), outputs.getArrayRef(), params.getArrayRef(), inputTaps); auto block = rewriter.createBlock(&newNode.getBody()); - block->addArguments(ValueRange(inputs.getArrayRef()), inputLocs); - block->addArguments(ValueRange(outputs.getArrayRef()), outputLocs); - block->addArguments(ValueRange(params.getArrayRef()), paramLocs); + block->addArguments(inputTypes, inputLocs); + block->addArguments(outputTypes, outputLocs); + block->addArguments(paramTypes, paramLocs); // Inline all nodes into the new node. 
for (auto node : nodes) { @@ -644,7 +660,13 @@ std::pair scalehls::ifAlwaysTrueOrFalse(mlir::AffineIfOp ifOp) { while (llvm::any_of(operands, [](Value v) { return isa_and_nonnull(v.getDefiningOp()); })) { - composeSetAndOperands(set, operands); + auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(), + set.getConstraints(), set.getContext()); + fullyComposeAffineMapAndOperands(&map, &operands); + // Convert the composed map back into a set so that the set and the + // rewritten operands stay consistent. + set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), + map.getResults(), set.getEqFlags()); } // Replace the original integer set and operands with the composed integer @@ -1192,3 +1214,12 @@ bool PtrLikeMemRefAccess::operator==(const PtrLikeMemRefAccess &rhs) const { return llvm::all_of(diff.getAffineMap().getResults(), [](AffineExpr e) { return e == 0; }); } + +void scalehls::printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, + unsigned numDims, OpAsmPrinter &printer) { + OperandRange operands(begin, end); + printer << '(' << operands.take_front(numDims) << ')'; + if (operands.size() > numDims) + printer << '[' << operands.drop_front(numDims) << ']'; +} diff --git a/lib/Transforms/Dataflow/EliminateMultiConsumer.cpp b/lib/Transforms/Dataflow/EliminateMultiConsumer.cpp index 6a9494d1..4a65f4ae 100644 --- a/lib/Transforms/Dataflow/EliminateMultiConsumer.cpp +++ b/lib/Transforms/Dataflow/EliminateMultiConsumer.cpp @@ -34,6 +34,7 @@ struct InsertForkNode : public OpRewritePattern { hasChanged = true; rewriter.setInsertionPointAfter(node); SmallVector buffers; + SmallVector bufferTypes; SmallVector bufferLocs; // Insert a buffer for each consumer. 
@@ -42,6 +43,7 @@ struct InsertForkNode : public OpRewritePattern { output.replaceUsesWithIf( buffer, [&](OpOperand &use) { return use.getOwner() == consumer; }); buffers.push_back(buffer); + bufferTypes.push_back(buffer.getType()); bufferLocs.push_back(loc); } @@ -49,7 +51,7 @@ struct InsertForkNode : public OpRewritePattern { auto fork = rewriter.create(loc, output, buffers); auto block = rewriter.createBlock(&fork.getBody()); auto outputArg = block->addArgument(output.getType(), output.getLoc()); - auto bufferArgs = block->addArguments(ValueRange(buffers), bufferLocs); + auto bufferArgs = block->addArguments(bufferTypes, bufferLocs); // Create explicit copy from the original output to the buffers. rewriter.setInsertionPointToStart(block); diff --git a/lib/Transforms/Dataflow/LegalizeDataflow.cpp b/lib/Transforms/Dataflow/LegalizeDataflow.cpp index 78e9c0a9..0a180d9e 100644 --- a/lib/Transforms/Dataflow/LegalizeDataflow.cpp +++ b/lib/Transforms/Dataflow/LegalizeDataflow.cpp @@ -199,7 +199,7 @@ struct LegalizeDataflow : public LegalizeDataflowBase { auto frozenPatterns = FrozenRewritePatternSet(std::move(patterns)); func.walk([&](ScheduleOp schedule) { - (void)applyOpPatternsAndFold(schedule, frozenPatterns); + (void)applyOpPatternsAndFold(schedule.getOperation(), frozenPatterns); if (llvm::all_of(schedule.getOps(), [](NodeOp node) { return node.getLevel(); })) diff --git a/lib/Transforms/Dataflow/LowerDataflow.cpp b/lib/Transforms/Dataflow/LowerDataflow.cpp index adfb4be4..e377bbeb 100644 --- a/lib/Transforms/Dataflow/LowerDataflow.cpp +++ b/lib/Transforms/Dataflow/LowerDataflow.cpp @@ -28,6 +28,7 @@ struct LowerDispatchToSchedule : public OpRewritePattern { }; SmallVector inputs; + SmallVector inputTypes; SmallVector inputLocs; auto liveins = Liveness(dispatch).getLiveIn(&dispatch.getBody().front()); @@ -35,6 +36,7 @@ struct LowerDispatchToSchedule : public OpRewritePattern { if (dispatch.getBody().isAncestor(livein.getParentRegion())) continue; 
inputs.push_back(livein); + inputTypes.push_back(livein.getType()); inputLocs.push_back(livein.getLoc()); } @@ -43,7 +45,8 @@ struct LowerDispatchToSchedule : public OpRewritePattern { rewriter.create(rewriter.getUnknownLoc(), inputs); auto scheduleBlock = rewriter.createBlock(&schedule.getBody()); - auto inputArgs = scheduleBlock->addArguments(ValueRange(inputs), inputLocs); + auto inputArgs = + scheduleBlock->addArguments(TypeRange(inputTypes), inputLocs); for (auto t : llvm::zip(inputs, inputArgs)) std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInDispatch); @@ -72,10 +75,13 @@ struct LowerTaskToNode : public OpRewritePattern { }; SmallVector inputs; + SmallVector inputTypes; SmallVector inputLocs; SmallVector outputs; + SmallVector outputTypes; SmallVector outputLocs; SmallVector params; + SmallVector paramTypes; SmallVector paramLocs; auto liveins = Liveness(task).getLiveIn(&task.getBody().front()); @@ -87,13 +93,16 @@ struct LowerTaskToNode : public OpRewritePattern { auto uses = llvm::make_filter_range(livein.getUses(), isInTask); if (llvm::any_of(uses, [](OpOperand &use) { return isWritten(use); })) { outputs.push_back(livein); + outputTypes.push_back(livein.getType()); outputLocs.push_back(livein.getLoc()); } else { inputs.push_back(livein); + inputTypes.push_back(livein.getType()); inputLocs.push_back(livein.getLoc()); } } else { params.push_back(livein); + paramTypes.push_back(livein.getType()); paramLocs.push_back(livein.getLoc()); } } @@ -103,16 +112,16 @@ struct LowerTaskToNode : public OpRewritePattern { outputs, params); auto nodeBlock = rewriter.createBlock(&node.getBody()); - auto inputArgs = nodeBlock->addArguments(ValueRange(inputs), inputLocs); + auto inputArgs = nodeBlock->addArguments(TypeRange(inputTypes), inputLocs); for (auto t : llvm::zip(inputs, inputArgs)) std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInTask); auto outputArgs = - node.getBody().addArguments(ValueRange(outputs), outputLocs); + 
node.getBody().addArguments(TypeRange(outputTypes), outputLocs); for (auto t : llvm::zip(outputs, outputArgs)) std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInTask); - auto paramArgs = nodeBlock->addArguments(ValueRange(params), paramLocs); + auto paramArgs = nodeBlock->addArguments(TypeRange(paramTypes), paramLocs); for (auto t : llvm::zip(params, paramArgs)) std::get<0>(t).replaceUsesWithIf(std::get<1>(t), isInTask); diff --git a/lib/Transforms/Dataflow/PlaceDataflowBuffer.cpp b/lib/Transforms/Dataflow/PlaceDataflowBuffer.cpp index d6fe9b6c..96b81f3a 100644 --- a/lib/Transforms/Dataflow/PlaceDataflowBuffer.cpp +++ b/lib/Transforms/Dataflow/PlaceDataflowBuffer.cpp @@ -111,13 +111,9 @@ struct PlaceDataflowBuffer auto func = getOperation(); auto context = func.getContext(); - llvm::outs() << "1\n"; - mlir::RewritePatternSet patterns(context); patterns.add(context, threshold, placeExternalBuffer); - (void)applyOpPatternsAndFold(func, std::move(patterns)); - - llvm::outs() << "2\n"; + (void)applyOpPatternsAndFold(func.getOperation(), std::move(patterns)); patterns.clear(); patterns.add(context); diff --git a/lib/Transforms/Directive/CreateAxiInterface.cpp b/lib/Transforms/Directive/CreateAxiInterface.cpp index 301c7812..3caadf61 100644 --- a/lib/Transforms/Directive/CreateAxiInterface.cpp +++ b/lib/Transforms/Directive/CreateAxiInterface.cpp @@ -57,7 +57,7 @@ struct CreateAxiInterface : public CreateAxiInterfaceBase { auto vectorize = cast(*buffer.user_begin()); vectorize->remove(); builder.insert(vectorize); - return vectorize.getResult(); + return cast(vectorize.getResult()); } return buffer; }; diff --git a/lib/Transforms/FuncDuplication.cpp b/lib/Transforms/FuncDuplication.cpp index ceb38dfb..7184ece3 100644 --- a/lib/Transforms/FuncDuplication.cpp +++ b/lib/Transforms/FuncDuplication.cpp @@ -25,6 +25,7 @@ struct SubViewSinkPattern : public OpRewritePattern { assert(func && "function definition not found"); SmallVector newInputs; + SmallVector newInputTypes; 
bool hasChanged = false; for (auto operand : call->getOperands()) { if (auto subview = operand.getDefiningOp()) { @@ -46,13 +47,17 @@ struct SubViewSinkPattern : public OpRewritePattern { } newInputs.append(subview.operand_begin(), subview.operand_end()); + newInputTypes.append(subview.getOperandTypes().begin(), + subview.getOperandTypes().end()); hasChanged = true; - } else + } else { newInputs.push_back(operand); + newInputTypes.push_back(operand.getType()); + } } if (hasChanged) { - func.setType(rewriter.getFunctionType(ValueRange(newInputs), + func.setType(rewriter.getFunctionType(TypeRange(newInputTypes), func.getResultTypes())); rewriter.setInsertionPoint(call); rewriter.replaceOpWithNewOp(call, func, newInputs); diff --git a/lib/Transforms/Memory/BufferVectorize.cpp b/lib/Transforms/Memory/BufferVectorize.cpp index bb00c601..fe5ab8fa 100644 --- a/lib/Transforms/Memory/BufferVectorize.cpp +++ b/lib/Transforms/Memory/BufferVectorize.cpp @@ -436,7 +436,7 @@ struct BufferVectorize : public BufferVectorizeBase { mlir::RewritePatternSet patterns(context); patterns.add(context); - (void)applyOpPatternsAndFold(func, std::move(patterns)); + (void)applyOpPatternsAndFold(func.getOperation(), std::move(patterns)); patterns.clear(); patterns.add(context); @@ -445,7 +445,8 @@ struct BufferVectorize : public BufferVectorizeBase { patterns.add(context); patterns.add(context); patterns.add(context); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsAndFoldGreedily(func.getOperation(), + std::move(patterns)); } }; } // namespace diff --git a/lib/Transforms/Memory/CollapseMemrefUnitDims.cpp b/lib/Transforms/Memory/CollapseMemrefUnitDims.cpp index a716755d..f09ed7f5 100644 --- a/lib/Transforms/Memory/CollapseMemrefUnitDims.cpp +++ b/lib/Transforms/Memory/CollapseMemrefUnitDims.cpp @@ -117,7 +117,7 @@ struct CollapseMemrefUnitDims mlir::RewritePatternSet patterns(context); patterns.add(context); - (void)applyOpPatternsAndFold(func, 
std::move(patterns)); + (void)applyOpPatternsAndFold(func.getOperation(), std::move(patterns)); } }; } // namespace diff --git a/lib/Transforms/Tensor/TosaSimplifyGraph.cpp b/lib/Transforms/Tensor/TosaSimplifyGraph.cpp index 6cebb0a5..49e6a6d5 100644 --- a/lib/Transforms/Tensor/TosaSimplifyGraph.cpp +++ b/lib/Transforms/Tensor/TosaSimplifyGraph.cpp @@ -61,7 +61,7 @@ struct RewriteElmwUnary : public OpRewritePattern { return failure(); elmw->getOpOperand(0).set(transpose.getInput1()); - elmw.getOutput().setType(transpose.getInput1().getType()); + elmw.getOutput().setType(cast(transpose.getInput1().getType())); rewriter.setInsertionPointAfter(elmw); auto cloneTranspose = cast(rewriter.clone(*transpose)); @@ -93,7 +93,8 @@ struct RewriteElmwBinary : public OpRewritePattern { elmw->getOpOperand(0).set(input1Transpose.getInput1()); elmw->getOpOperand(1).set(input2Transpose.getInput1()); - elmw.getOutput().setType(input1Transpose.getInput1().getType()); + elmw.getOutput().setType( + cast(input1Transpose.getInput1().getType())); rewriter.setInsertionPointAfter(elmw); auto cloneTranspose =