Use Split/MergeIteration and Expand/CollapseShape op in the transformation
hanchenye committed Feb 20, 2024
1 parent 8783e27 commit c6580f2
Showing 2 changed files with 103 additions and 53 deletions.
79 changes: 46 additions & 33 deletions lib/Dialect/HLS/IR/HLSOps.cpp
@@ -343,8 +343,12 @@ verifyIterationReassociation(ArrayRef<ReassociationIndices> reassociation,
 }
 
 LogicalResult StreamSplitIterationOp::verify() {
-  if (getInputType().isCastableWith(getOutputType()))
-    return emitOpError("input and output are not castable");
+  if (!getInputType().isCastableWith(getOutputType())) {
+    auto diag = emitOpError("input and output are not castable");
+    diag << ", input shape: " << getInputType().getShape()
+         << ", output shape: " << getOutputType().getShape();
+    return diag;
+  }
   return verifyIterationReassociation(getReassociationIndices(), getInputType(),
                                       getOutputType(), *this);
 }
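Note on the pattern: emitOpError returns MLIR's InFlightDiagnostic, which supports operator<< for appending detail and converts implicitly into a failed LogicalResult when returned. A minimal sketch of the idiom outside this commit (MyOp and its accessors are hypothetical):

LogicalResult MyOp::verify() {
  if (!getInputType().isCastableWith(getOutputType())) {
    // The diagnostic stays in flight until returned, so extra context can
    // still be streamed into it.
    auto diag = emitOpError("input and output are not castable");
    diag << ", input type: " << getInputType()
         << ", output type: " << getOutputType();
    return diag; // Implicitly converts to failure().
  }
  return success();
}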
@@ -367,10 +371,14 @@ OpFoldResult StreamSplitIterationOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult StreamMergeIterationOp::verify() {
-  if (getInputType().isCastableWith(getOutputType()))
-    return emitOpError("input and output are not castable");
-  return verifyIterationReassociation(getReassociationIndices(), getInputType(),
-                                      getOutputType(), *this);
+  if (!getInputType().isCastableWith(getOutputType())) {
+    auto diag = emitOpError("input and output are not castable");
+    diag << ", input shape: " << getInputType().getShape()
+         << ", output shape: " << getOutputType().getShape();
+    return diag;
+  }
+  return verifyIterationReassociation(getReassociationIndices(),
+                                      getOutputType(), getInputType(), *this);
 }
 
 OpFoldResult StreamMergeIterationOp::fold(FoldAdaptor adaptor) {
@@ -385,27 +393,24 @@ static LogicalResult
 verifyShapeReassociation(ArrayRef<ReassociationIndices> reassociation,
                          StreamType lowType, StreamType highType,
                          Operation *op) {
-  // if (lowType.getIterTripCounts() != lowType.getIterTripCounts() ||
-  //     lowType.getIterSteps() != lowType.getIterSteps())
-  //   return op->emitOpError("input and output iteration trip counts or steps "
-  //                          "doesn't match");
-
-  // unsigned highIndex = 0;
-  // for (auto lowExpr : lowType.getIterMap().getResults()) {
-  //   if (auto lowDimExpr = dyn_cast<AffineDimExpr>(lowExpr)) {
-  //     auto indices = reassociation[lowDimExpr.getPosition()];
-  //     for (auto index : indices) {
-  //       auto highDimExpr =
-  //           dyn_cast<AffineDimExpr>(highType.getIterMap().getResult(highIndex));
-  //       if (!highDimExpr || highDimExpr.getPosition() != index)
-  //         return op->emitOpError(
-  //             "reassociation doesn't align with input/output iteration "
-  //             "maps");
-  //       highIndex++;
-  //     }
-  //   } else
-  //     highIndex++;
-  // }
+  if (lowType.getIterTripCounts() != highType.getIterTripCounts() ||
+      lowType.getIterSteps() != highType.getIterSteps())
+    return op->emitOpError("input and output iteration trip counts or steps "
+                           "don't match");
+
+  auto lowShape = lowType.getShape();
+  auto highShape = highType.getShape();
+  if (reassociation.size() != lowShape.size())
+    return op->emitOpError("reassociation size doesn't align with input type");
+
+  for (auto [indices, lowDimSize] : llvm::zip(reassociation, lowShape)) {
+    int64_t highDimSizeProduct = 1;
+    for (auto index : indices)
+      highDimSizeProduct *= highShape[index];
+    if (lowDimSize != highDimSizeProduct)
+      return op->emitOpError("reassociation doesn't align with input/output "
+                             "tensor shape");
+  }
   return success();
 }
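To make the new check concrete: with reassociation [[0], [1, 2]], low-rank shape [4, 6], and high-rank shape [4, 2, 3], verification passes, since low dim 0 maps to the single high dim of size 4 and low dim 1 maps to the product 2 * 3 = 6. The same logic with the MLIR types stripped out, as a self-contained sketch (the helper name is invented):

#include <cstdint>
#include <vector>

// Returns true if every low-rank dim equals the product of the high-rank
// dims grouped into it by the reassociation.
static bool shapesReassociate(const std::vector<std::vector<int64_t>> &reassociation,
                              const std::vector<int64_t> &lowShape,
                              const std::vector<int64_t> &highShape) {
  if (reassociation.size() != lowShape.size())
    return false;
  for (size_t dim = 0; dim < lowShape.size(); ++dim) {
    int64_t product = 1;
    for (int64_t index : reassociation[dim])
      product *= highShape[index];
    if (product != lowShape[dim])
      return false;
  }
  return true;
}

// shapesReassociate({{0}, {1, 2}}, {4, 6}, {4, 2, 3}) == true
// shapesReassociate({{0}, {1, 2}}, {4, 8}, {4, 2, 3}) == false, as 2 * 3 != 8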

@@ -427,8 +432,8 @@ OpFoldResult StreamExpandShapeOp::fold(FoldAdaptor adaptor) {
 LogicalResult StreamCollapseShapeOp::verify() {
   if (getInputType().getDataType() != getOutputType().getDataType())
     return emitOpError("input and output data type doesn't match");
-  return verifyShapeReassociation(getReassociationIndices(), getInputType(),
-                                  getOutputType(), *this);
+  return verifyShapeReassociation(getReassociationIndices(), getOutputType(),
+                                  getInputType(), *this);
 }
 
 OpFoldResult StreamCollapseShapeOp::fold(FoldAdaptor adaptor) {
@@ -442,8 +447,12 @@ OpFoldResult StreamCollapseShapeOp::fold(FoldAdaptor adaptor) {
 LogicalResult StreamBufferOp::verify() {
   auto inputType = getInput().getType();
   auto outputType = getOutput().getType();
-  if (!inputType.isCastableWith(outputType))
-    return emitOpError("input and output are not castable");
+  if (!inputType.isCastableWith(outputType)) {
+    auto diag = emitOpError("input and output are not castable");
+    diag << ", input shape: " << inputType.getShape()
+         << ", output shape: " << outputType.getShape();
+    return diag;
+  }
 
   if (getLoopIndex() > inputType.getIterTripCounts().size())
     return emitOpError("buffer loop index is out of loop range");
@@ -474,8 +483,12 @@ OpFoldResult StreamBufferOp::fold(FoldAdaptor adaptor) {
 LogicalResult StreamCastOp::verify() {
   auto inputType = getInput().getType();
   auto outputType = getOutput().getType();
-  if (!inputType.isCastableWith(outputType))
-    return emitOpError("input and output are not castable");
+  if (!inputType.isCastableWith(outputType)) {
+    auto diag = emitOpError("input and output are not castable");
+    diag << ", input shape: " << inputType.getShape()
+         << ", output shape: " << outputType.getShape();
+    return diag;
+  }
   return success();
 }

77 changes: 57 additions & 20 deletions lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp
@@ -257,21 +257,28 @@ LogicalResult transform::HLSConvertExpandShapeToStreamOp::verify() {
   return success();
 }
 
+/// Given the low-ranked and high-ranked tensor types, and the resulting
+/// low-ranked and high-ranked element shapes, this function constructs three
+/// stream types:
+/// 1. The low-element-ranked and low-iteration-ranked stream type.
+/// 2. The high-element-ranked and low-iteration-ranked stream type.
+/// 3. The high-element-ranked and high-iteration-ranked stream type.
 template <typename OpTy>
-static std::optional<std::pair<hls::StreamType, hls::StreamType>>
-getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType,
-                             RankedTensorType highType,
-                             ArrayRef<int64_t> lowElementShape,
-                             ArrayRef<int64_t> highElementShape) {
+static std::optional<
+    std::tuple<hls::StreamType, hls::StreamType, hls::StreamType>>
+getReshapeStreamTypes(OpTy reshapeOp, RankedTensorType lowType,
+                      RankedTensorType highType,
+                      ArrayRef<int64_t> lowElementShape,
+                      ArrayRef<int64_t> highElementShape) {
   // The low and high types must have static shapes.
   if (!lowType.hasStaticShape() || !highType.hasStaticShape())
     return std::nullopt;
 
-  SmallVector<int64_t> lowIterTripCounts;
-  SmallVector<int64_t> lowIterSteps;
+  SmallVector<int64_t> lowIterTripCounts, lowIterSteps;
   SmallVector<AffineExpr> lowIterExprs;
-  SmallVector<int64_t> highIterTripCounts;
-  SmallVector<int64_t> highIterSteps;
+  SmallVector<int64_t> lowToHighIterTripCounts, lowToHighIterSteps;
+  SmallVector<AffineExpr> lowToHighIterExprs;
+  SmallVector<int64_t> highIterTripCounts, highIterSteps;
   SmallVector<AffineExpr> highIterExprs;
 
   // Collect the iteration shape and affine map of the streaming channel.
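A concrete (invented) example of the three stream types: expanding tensor<24xf32> into tensor<2x12xf32> with reassociation [[0, 1]], lowElementShape = [4], and highElementShape = [1, 4] would produce, in schematic notation rather than the dialect's actual type syntax:

// 1. Low element rank, low iteration rank:
//      element tensor<4xf32>,   trip counts [6],    steps [4],
//      iteration map (d0) -> (d0)
// 2. High element rank, low iteration rank:
//      element tensor<1x4xf32>, trip counts [6],    steps [4],
//      iteration map (d0) -> (d0 floordiv 12, d0 mod 12)
// 3. High element rank, high iteration rank:
//      element tensor<1x4xf32>, trip counts [2, 3], steps [1, 4],
//      iteration map (d0, d1) -> (d0, d1)

StreamExpandShapeOp bridges 1 -> 2 (same iteration space, higher element rank), and StreamSplitIterationOp bridges 2 -> 3 (same element shape, higher-rank iteration space); the collapse path walks the same chain in reverse.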
@@ -286,7 +293,10 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType,
     lowIterSteps.push_back(lowElementDimSize);
     lowIterExprs.push_back(getAffineDimExpr(lowDim, reshapeOp.getContext()));
 
-    for (auto highDim : highDims) {
+    auto localIterExprs = SmallVector<AffineExpr>(
+        highDims.size(), getAffineDimExpr(lowDim, reshapeOp.getContext()));
+
+    for (auto [index, highDim] : llvm::enumerate(highDims)) {
       auto highDimSize = highType.getDimSize(highDim);
       auto highElementDimSize = highElementShape[highDim];
       if (highDimSize % highElementDimSize != 0)
@@ -296,7 +306,17 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType,
       highIterSteps.push_back(highElementDimSize);
       highIterExprs.push_back(
           getAffineDimExpr(highDim, reshapeOp.getContext()));
+
+      for (auto &localIterExpr :
+           llvm::drop_end(localIterExprs, highDims.size() - index))
+        localIterExpr = localIterExpr.floorDiv(highDimSize);
+      if (index != 0)
+        localIterExprs[index] = localIterExprs[index] % highDimSize;
     }
+
+    lowToHighIterTripCounts.push_back(lowDimSize / lowElementDimSize);
+    lowToHighIterSteps.push_back(lowElementDimSize);
+    lowToHighIterExprs.append(localIterExprs);
   }
 
   // Construct the low stream type.
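The floorDiv/mod bookkeeping above is a row-major delinearization of the low iteration variable across the high dims grouped into it. Tracing localIterExprs for one low dim d0 split into two high dims of full sizes (S0, S1), as a hypothetical walk-through of the loop:

//   init:      [d0, d0]
//   index = 0: drop_end skips every entry, and index == 0 suppresses the mod
//              -> [d0, d0]
//   index = 1: drop_end visits entry 0, which becomes d0 floordiv S1, and
//              entry 1 becomes d0 mod S1
//              -> [d0 floordiv S1, d0 mod S1]

Because the low iteration variable advances in element-sized steps, dividing by the full high dim sizes (rather than by trip counts) lands each expression on the matching high-rank step grid.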
@@ -308,6 +328,16 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType,
       hls::StreamType::get(lowElementType, lowIterTripCounts, lowIterSteps,
                            lowIterMap, lowType.getNumElements());
 
+  // Construct the low-to-high stream type.
+  auto lowToHighIterMap =
+      AffineMap::get(lowToHighIterTripCounts.size(), 0, lowToHighIterExprs,
+                     reshapeOp.getContext());
+  auto lowToHighElementType =
+      RankedTensorType::get(highElementShape, lowType.getElementType());
+  auto lowToHighStreamType = hls::StreamType::get(
+      lowToHighElementType, lowToHighIterTripCounts, lowToHighIterSteps,
+      lowToHighIterMap, lowType.getNumElements());
+
   // Construct the high stream type.
   auto highIterMap = AffineMap::get(highIterTripCounts.size(), 0, highIterExprs,
                                     reshapeOp.getContext());
@@ -317,7 +347,7 @@ getLowAndHighRankStreamTypes(OpTy reshapeOp, RankedTensorType lowType,
       hls::StreamType::get(highElementType, highIterTripCounts, highIterSteps,
                            highIterMap, highType.getNumElements());
 
-  return std::make_pair(lowStreamType, highStreamType);
+  return std::make_tuple(lowStreamType, lowToHighStreamType, highStreamType);
 }
 
 DiagnosedSilenceableFailure
@@ -326,21 +356,25 @@ transform::HLSConvertExpandShapeToStreamOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   // Construct the source and result stream types.
-  auto streamTypes = getLowAndHighRankStreamTypes(
+  auto streamTypes = getReshapeStreamTypes(
       expandShape, expandShape.getSrcType(), expandShape.getResultType(),
       getSourceElementShape(), getResultElementShape());
   if (!streamTypes)
     return emitDefaultSilenceableFailure(expandShape);
 
   // Convert the expand_shape op to stream ops and replace its uses.
+  auto loc = expandShape.getLoc();
   rewriter.setInsertionPoint(expandShape);
   auto sourceStream = rewriter.create<hls::TensorToStreamOp>(
-      rewriter.getUnknownLoc(), streamTypes->first, expandShape.getSrc());
+      loc, std::get<0>(*streamTypes), expandShape.getSrc());
   auto streamExpandShape = rewriter.create<hls::StreamExpandShapeOp>(
-      rewriter.getUnknownLoc(), streamTypes->second, sourceStream,
+      loc, std::get<1>(*streamTypes), sourceStream,
       expandShape.getReassociation());
+  auto streamSplitIteration = rewriter.create<hls::StreamSplitIterationOp>(
+      loc, std::get<2>(*streamTypes), streamExpandShape,
+      expandShape.getReassociation());
   auto resultTensor = rewriter.create<hls::StreamToTensorOp>(
-      rewriter.getUnknownLoc(), expandShape.getResultType(), streamExpandShape);
+      loc, expandShape.getResultType(), streamSplitIteration);
   rewriter.replaceAllUsesWith(expandShape, resultTensor);
 
   results.push_back(sourceStream);
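The rewrite replaces the expand_shape with a four-op chain. In pseudo-IR (mnemonics and types are schematic, not the dialect's exact textual form), with the stream types numbered as in the comment on getReshapeStreamTypes:

//   %s0 = hls.tensor_to_stream %src         // low stream type         (1)
//   %s1 = hls.stream_expand_shape %s0       // low-to-high stream type (2)
//   %s2 = hls.stream_split_iteration %s1    // high stream type        (3)
//   %t  = hls.stream_to_tensor %s2          // high-ranked tensor

The collapse conversion below mirrors this chain in reverse: tensor_to_stream yields the high stream type (3), StreamMergeIterationOp lowers it to (2), and StreamCollapseShapeOp produces the low stream type (1).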
@@ -364,22 +398,25 @@ transform::HLSConvertCollapseShapeToStreamOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   // Construct the source and result stream types.
-  auto streamTypes = getLowAndHighRankStreamTypes(
+  auto streamTypes = getReshapeStreamTypes(
       collapseShape, collapseShape.getResultType(), collapseShape.getSrcType(),
       getResultElementShape(), getSourceElementShape());
   if (!streamTypes)
     return emitDefaultSilenceableFailure(collapseShape);
 
   // Convert the collapse_shape op to stream ops and replace its uses.
+  auto loc = collapseShape.getLoc();
   rewriter.setInsertionPoint(collapseShape);
   auto sourceStream = rewriter.create<hls::TensorToStreamOp>(
-      rewriter.getUnknownLoc(), streamTypes->second, collapseShape.getSrc());
+      loc, std::get<2>(*streamTypes), collapseShape.getSrc());
+  auto streamMergeIteration = rewriter.create<hls::StreamMergeIterationOp>(
+      loc, std::get<1>(*streamTypes), sourceStream,
+      collapseShape.getReassociation());
   auto streamCollapseShape = rewriter.create<hls::StreamCollapseShapeOp>(
-      rewriter.getUnknownLoc(), streamTypes->first, sourceStream,
+      loc, std::get<0>(*streamTypes), streamMergeIteration,
       collapseShape.getReassociation());
   auto resultTensor = rewriter.create<hls::StreamToTensorOp>(
-      rewriter.getUnknownLoc(), collapseShape.getResultType(),
-      streamCollapseShape);
+      loc, collapseShape.getResultType(), streamCollapseShape);
   rewriter.replaceAllUsesWith(collapseShape, resultTensor);
 
   results.push_back(sourceStream);
