From c6580f2130ecfcac935d66bc18f2e6c16b85f5d5 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Mon, 19 Feb 2024 23:30:33 -0600 Subject: [PATCH] Use Split/MergeIteration and Expand/CollapseShape op in the transformation --- lib/Dialect/HLS/IR/HLSOps.cpp | 79 +++++++++++-------- .../HLS/TransformOps/HLSTransformOps.cpp | 77 +++++++++++++----- 2 files changed, 103 insertions(+), 53 deletions(-) diff --git a/lib/Dialect/HLS/IR/HLSOps.cpp b/lib/Dialect/HLS/IR/HLSOps.cpp index 317d5161..bd02a4af 100644 --- a/lib/Dialect/HLS/IR/HLSOps.cpp +++ b/lib/Dialect/HLS/IR/HLSOps.cpp @@ -343,8 +343,12 @@ verifyIterationReassociation(ArrayRef reassociation, } LogicalResult StreamSplitIterationOp::verify() { - if (getInputType().isCastableWith(getOutputType())) - return emitOpError("input and output are not castable"); + if (!getInputType().isCastableWith(getOutputType())) { + auto diag = emitOpError("input and output are not castable"); + diag << "input shape: " << getInputType().getShape() + << ", output shape: " << getOutputType().getShape(); + return diag; + } return verifyIterationReassociation(getReassociationIndices(), getInputType(), getOutputType(), *this); } @@ -367,10 +371,14 @@ OpFoldResult StreamSplitIterationOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// LogicalResult StreamMergeIterationOp::verify() { - if (getInputType().isCastableWith(getOutputType())) - return emitOpError("input and output are not castable"); - return verifyIterationReassociation(getReassociationIndices(), getInputType(), - getOutputType(), *this); + if (!getInputType().isCastableWith(getOutputType())) { + auto diag = emitOpError("input and output are not castable"); + diag << "input shape: " << getInputType().getShape() + << ", output shape: " << getOutputType().getShape(); + return diag; + } + return verifyIterationReassociation(getReassociationIndices(), + getOutputType(), getInputType(), *this); } OpFoldResult 
StreamMergeIterationOp::fold(FoldAdaptor adaptor) { @@ -385,27 +393,24 @@ static LogicalResult verifyShapeReassociation(ArrayRef reassociation, StreamType lowType, StreamType highType, Operation *op) { - // if (lowType.getIterTripCounts() != lowType.getIterTripCounts() || - // lowType.getIterSteps() != lowType.getIterSteps()) - // return op->emitOpError("input and output iteration trip counts or steps " - // "doesn't match"); - - // unsigned highIndex = 0; - // for (auto lowExpr : lowType.getIterMap().getResults()) { - // if (auto lowDimExpr = dyn_cast(lowExpr)) { - // auto indices = reassociation[lowDimExpr.getPosition()]; - // for (auto index : indices) { - // auto highDimExpr = - // dyn_cast(highType.getIterMap().getResult(highIndex)); - // if (!highDimExpr || highDimExpr.getPosition() != index) - // return op->emitOpError( - // "reassociation doesn't align with input/output iteration - // maps"); - // highIndex++; - // } - // } else - // highIndex++; - // } + if (lowType.getIterTripCounts() != highType.getIterTripCounts() || + lowType.getIterSteps() != highType.getIterSteps()) + return op->emitOpError("input and output iteration trip counts or steps " + "doesn't match"); + + auto lowShape = lowType.getShape(); + auto highShape = highType.getShape(); + if (reassociation.size() != lowShape.size()) + return op->emitOpError("reassociation size doesn't align with input type"); + + for (auto [indices, lowDimSize] : llvm::zip(reassociation, lowShape)) { + int64_t highDimSizeProduct = 1; + for (auto index : indices) + highDimSizeProduct *= highShape[index]; + if (lowDimSize != highDimSizeProduct) + return op->emitOpError("reassociation doesn't align with input/output " + "tensor shape"); + } return success(); } @@ -427,8 +432,8 @@ OpFoldResult StreamExpandShapeOp::fold(FoldAdaptor adaptor) { LogicalResult StreamCollapseShapeOp::verify() { if (getInputType().getDataType() != getOutputType().getDataType()) return emitOpError("input and output data type doesn't match"); - 
return verifyShapeReassociation(getReassociationIndices(), getInputType(), - getOutputType(), *this); + return verifyShapeReassociation(getReassociationIndices(), getOutputType(), + getInputType(), *this); } OpFoldResult StreamCollapseShapeOp::fold(FoldAdaptor adaptor) { @@ -442,8 +447,12 @@ OpFoldResult StreamCollapseShapeOp::fold(FoldAdaptor adaptor) { LogicalResult StreamBufferOp::verify() { auto inputType = getInput().getType(); auto outputType = getOutput().getType(); - if (!inputType.isCastableWith(outputType)) - return emitOpError("input and output are not castable"); + if (!inputType.isCastableWith(outputType)) { + auto diag = emitOpError("input and output are not castable"); + diag << "input shape: " << inputType.getShape() + << ", output shape: " << outputType.getShape(); + return diag; + } if (getLoopIndex() > inputType.getIterTripCounts().size()) return emitOpError("buffer loop index is out of loop range"); @@ -474,8 +483,12 @@ OpFoldResult StreamBufferOp::fold(FoldAdaptor adaptor) { LogicalResult StreamCastOp::verify() { auto inputType = getInput().getType(); auto outputType = getOutput().getType(); - if (!inputType.isCastableWith(outputType)) - return emitOpError("input and output are not castable"); + if (!inputType.isCastableWith(outputType)) { + auto diag = emitOpError("input and output are not castable"); + diag << "input shape: " << inputType.getShape() + << ", output shape: " << outputType.getShape(); + return diag; + } return success(); } diff --git a/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp b/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp index 5fd7fd7c..cf259f64 100644 --- a/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp +++ b/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp @@ -257,21 +257,28 @@ LogicalResult transform::HLSConvertExpandShapeToStreamOp::verify() { return success(); } +/// Given the low-ranked and high-ranked tensor types, and the resulting +/// low-ranked and high-ranked element shapes, this function constructs 
three +/// stream types: +/// 1. The low-element-ranked and low-iteration-ranked stream type. +/// 2. The high-element-ranked and low-iteration-ranked stream type. +/// 3. The high-element-ranked and high-iteration-ranked stream type. template -static std::optional> -getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType, - RankedTensorType highType, - ArrayRef lowElementShape, - ArrayRef highElementShape) { +static std::optional< + std::tuple> +getReshapeStreamTypes(OpTy reshapeOp, RankedTensorType lowType, + RankedTensorType highType, + ArrayRef lowElementShape, + ArrayRef highElementShape) { // The low and high types must have static shapes. if (!lowType.hasStaticShape() || !highType.hasStaticShape()) return std::nullopt; - SmallVector lowIterTripCounts; - SmallVector lowIterSteps; + SmallVector lowIterTripCounts, lowIterSteps; SmallVector lowIterExprs; - SmallVector highIterTripCounts; - SmallVector highIterSteps; + SmallVector lowToHighIterTripCounts, lowToHighIterSteps; + SmallVector lowToHighIterExprs; + SmallVector highIterTripCounts, highIterSteps; SmallVector highIterExprs; // Collect the iteration shape and affine map of the streaming channel. 
@@ -286,7 +293,10 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType, lowIterSteps.push_back(lowElementDimSize); lowIterExprs.push_back(getAffineDimExpr(lowDim, reshapeOp.getContext())); - for (auto highDim : highDims) { + auto localIterExprs = SmallVector( + highDims.size(), getAffineDimExpr(lowDim, reshapeOp.getContext())); + + for (auto [index, highDim] : llvm::enumerate(highDims)) { auto highDimSize = highType.getDimSize(highDim); auto highElementDimSize = highElementShape[highDim]; if (highDimSize % highElementDimSize != 0) @@ -296,7 +306,17 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType, highIterSteps.push_back(highElementDimSize); highIterExprs.push_back( getAffineDimExpr(highDim, reshapeOp.getContext())); + + for (auto &localIterExpr : + llvm::drop_end(localIterExprs, highDims.size() - index)) + localIterExpr = localIterExpr.floorDiv(highDimSize); + if (index != 0) + localIterExprs[index] = localIterExprs[index] % highDimSize; } + + lowToHighIterTripCounts.push_back(lowDimSize / lowElementDimSize); + lowToHighIterSteps.push_back(lowElementDimSize); + lowToHighIterExprs.append(localIterExprs); } // Construct the low stream type. @@ -308,6 +328,16 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType, hls::StreamType::get(lowElementType, lowIterTripCounts, lowIterSteps, lowIterMap, lowType.getNumElements()); + // Construct the low-to-high stream type. + auto lowToHighIterMap = + AffineMap::get(lowToHighIterTripCounts.size(), 0, lowToHighIterExprs, + reshapeOp.getContext()); + auto lowToHighElementType = + RankedTensorType::get(highElementShape, lowType.getElementType()); + auto lowToHighStreamType = hls::StreamType::get( + lowToHighElementType, lowToHighIterTripCounts, lowToHighIterSteps, + lowToHighIterMap, lowType.getNumElements()); + // Construct the high stream type. 
auto highIterMap = AffineMap::get(highIterTripCounts.size(), 0, highIterExprs, reshapeOp.getContext()); @@ -317,7 +347,7 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType, hls::StreamType::get(highElementType, highIterTripCounts, highIterSteps, highIterMap, highType.getNumElements()); - return std::make_pair(lowStreamType, highStreamType); + return std::make_tuple(lowStreamType, lowToHighStreamType, highStreamType); } DiagnosedSilenceableFailure @@ -326,21 +356,25 @@ transform::HLSConvertExpandShapeToStreamOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { // Construct the source and result stream types. - auto streamTypes = getLowAndHighRankStreamTypes( + auto streamTypes = getReshapeStreamTypes( expandShape, expandShape.getSrcType(), expandShape.getResultType(), getSourceElementShape(), getResultElementShape()); if (!streamTypes) return emitDefaultSilenceableFailure(expandShape); // Convert the expand_shape op to stream ops and replace its uses. 
+ auto loc = expandShape.getLoc(); rewriter.setInsertionPoint(expandShape); auto sourceStream = rewriter.create( - rewriter.getUnknownLoc(), streamTypes->first, expandShape.getSrc()); + loc, std::get<0>(*streamTypes), expandShape.getSrc()); auto streamExpandShape = rewriter.create( - rewriter.getUnknownLoc(), streamTypes->second, sourceStream, + loc, std::get<1>(*streamTypes), sourceStream, + expandShape.getReassociation()); + auto streamSplitIteration = rewriter.create( + loc, std::get<2>(*streamTypes), streamExpandShape, expandShape.getReassociation()); auto resultTensor = rewriter.create( - rewriter.getUnknownLoc(), expandShape.getResultType(), streamExpandShape); + loc, expandShape.getResultType(), streamSplitIteration); rewriter.replaceAllUsesWith(expandShape, resultTensor); results.push_back(sourceStream); @@ -364,22 +398,25 @@ transform::HLSConvertCollapseShapeToStreamOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { // Construct the source and result stream types. - auto streamTypes = getLowAndHighRankStreamTypes( + auto streamTypes = getReshapeStreamTypes( collapseShape, collapseShape.getResultType(), collapseShape.getSrcType(), getResultElementShape(), getSourceElementShape()); if (!streamTypes) return emitDefaultSilenceableFailure(collapseShape); // Convert the expand_shape op to stream ops and replace its uses. 
+ auto loc = collapseShape.getLoc(); rewriter.setInsertionPoint(collapseShape); auto sourceStream = rewriter.create( - rewriter.getUnknownLoc(), streamTypes->second, collapseShape.getSrc()); + loc, std::get<2>(*streamTypes), collapseShape.getSrc()); + auto streamMergeIteration = rewriter.create( + loc, std::get<1>(*streamTypes), sourceStream, + collapseShape.getReassociation()); auto streamCollapseShape = rewriter.create( - rewriter.getUnknownLoc(), streamTypes->first, sourceStream, + loc, std::get<0>(*streamTypes), streamMergeIteration, collapseShape.getReassociation()); auto resultTensor = rewriter.create( - rewriter.getUnknownLoc(), collapseShape.getResultType(), - streamCollapseShape); + loc, collapseShape.getResultType(), streamCollapseShape); rewriter.replaceAllUsesWith(collapseShape, resultTensor); results.push_back(sourceStream);