Skip to content

Commit

Permalink
Update transform ops
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Feb 22, 2024
1 parent a43b01f commit 7ff6a09
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 133 deletions.
2 changes: 1 addition & 1 deletion include/scalehls/Dialect/HLS/IR/HLSTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class SlidingType<string name, list<Trait> traits = []>
TypeBuilderWithInferredContext<(ins "ArrayRef<int64_t>":$shape,
"mlir::Type":$elementType, "ArrayRef<int64_t>":$iterTripCounts,
"ArrayRef<int64_t>":$iterSteps, "mlir::AffineMap":$iterMap), [{
return $_get(elementType.getContext(), shape, elementType, iterTripCounts,
return $_get(elementType.getContext(), elementType, shape, iterTripCounts,
iterSteps, iterMap);
}]>
];
Expand Down
251 changes: 129 additions & 122 deletions lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,21 @@ using namespace affine;
// HLSConvertExtractSliceToTensorInitOp
//===----------------------------------------------------------------------===//

/// Starting from "source", walk upward through the iter_args of any enclosing
/// scf.for loops until the operand is no longer a loop-carried block argument,
/// and return the operand that feeds the outermost loop (the untiled source).
/// If "loops" is non-null, it is populated with the traversed loops ordered
/// from outermost to innermost.
static OpOperand *
getUntiledOperandAndSurroundingLoops(OpOperand *source,
                                     SmallVector<scf::ForOp> *loops = nullptr) {
  SmallVector<scf::ForOp> innerToOuter;
  for (;;) {
    auto blockArg = dyn_cast<BlockArgument>(source->get());
    if (!blockArg)
      break;
    // Only scf.for iter_args can be traced further; any other parent op
    // terminates the walk.
    auto forOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    source = forOp.getTiedLoopInit(blockArg);
    innerToOuter.push_back(forOp);
  }
  // Loops were collected innermost-first; report them outermost-first.
  if (loops)
    loops->assign(innerToOuter.rbegin(), innerToOuter.rend());
  return source;
}

DiagnosedSilenceableFailure
transform::HLSConvertExtractSliceToTensorInitOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // NOTE(review): the original span contained both the pre- and post-change
  // statements of a diff with its +/- markers stripped (two init lookups and
  // two replaceOpWithNewOp calls); this is the coherent new-side version.
  //
  // Trace the extract_slice source through any enclosing loop iter_args to
  // its untiled origin, which must be produced by a hls.tensor_init op.
  auto init = getUntiledSource(extractSlice.getSource())
                  .getDefiningOp<hls::TensorInitOp>();
  if (!init)
    return emitDefaultSilenceableFailure(extractSlice)
           << "extract_slice source is not initialized by a tensor_init op";

  // A slice of an initialized tensor is itself just an initialized tensor:
  // replace the extract_slice with a local tensor_init of the sliced type,
  // reusing the original init value.
  rewriter.setInsertionPoint(extractSlice);
  auto localInit = rewriter.replaceOpWithNewOp<hls::TensorInitOp>(
      extractSlice, extractSlice.getType(), init.getInitValue());
  results.push_back(localInit);
  return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -165,46 +149,68 @@ transform::HLSConvertInsertSliceToStreamOp::applyToOne(
// Check if the destination tensor of the insert_slice op has only one use,
// which means no other operations have effect on the tensor.
if (!insertSlice.getDest().hasOneUse())
return emitDefaultSilenceableFailure(insertSlice);
return emitDefaultSilenceableFailure(insertSlice)
<< "dest tensor of insert_slice has more than one use";

// Collect the surrounding loops of the insert_slice op.
SmallVector<scf::ForOp> loops;
auto untiledUse = getUntiledOperandAndSurroundingLoops(
&insertSlice.getDestMutable(), &loops);
if (untiledUse == &insertSlice.getDestMutable())
return emitDefaultSilenceableFailure(insertSlice);
auto untiledDest = getUntiledSource(insertSlice.getDest());
auto loops = getSurroundingLoops(insertSlice, untiledDest.getParentBlock());
if (loops.empty())
return emitDefaultSilenceableFailure(insertSlice)
<< "no surrounding loops found for insert_slice";

// Find the tensor_init op that initializes the destination tensor.
auto tensorInit = untiledDest.getDefiningOp<hls::TensorInitOp>();
if (!tensorInit)
return emitDefaultSilenceableFailure(insertSlice)
<< "dest tensor is not initialized by a tensor_init op";

// Collect the iteration shape and affine map of the streaming channel.
// Collect the iteration shape and affine map of the sliding tensor.
auto iterTripCounts = getLoopTripCounts(loops);
auto iterSteps = getLoopSteps(loops);
auto iterMap = getIterationAffineMap(insertSlice, loops);
if (!iterTripCounts || !iterSteps || !iterMap)
return emitDefaultSilenceableFailure(insertSlice);

// Create the streaming channel.
// Construct the sliding tensor type.
auto tensorType = tensorInit.getType();
auto sTensorType = hls::SlidingTensorType::get(
insertSlice.getResultType().getShape(), tensorType.getElementType(),
*iterTripCounts, *iterSteps, *iterMap);

// Create the sliding tensor.
auto loc = insertSlice.getLoc();
rewriter.setInsertionPoint(loops.front());
auto channelType = hls::StreamType::get(
insertSlice.getSourceType(), *iterTripCounts, *iterSteps, *iterMap,
insertSlice.getDestType().getNumElements());
auto channel =
rewriter.create<hls::StreamOp>(rewriter.getUnknownLoc(), channelType);
auto sTensorInit =
rewriter.create<hls::SlidingTensorInitOp>(loc, sTensorType);

// Update the loop iteration arguments.
auto index = cast<BlockArgument>(insertSlice.getDest()).getArgNumber();
loops.front().setOperand(index, sTensorInit.getResult());
for (auto loop : loops) {
loop.getRegionIterArg(index).setType(sTensorType);
loop.getResult(index).setType(sTensorType);
}

// Create the stream_write op.
// Create the stensor_push op.
rewriter.setInsertionPoint(insertSlice);
auto channelWrite = rewriter.create<hls::StreamWriteOp>(
rewriter.getUnknownLoc(), channel, insertSlice.getSource());
auto sTensorPush = rewriter.create<hls::SlidingTensorPushOp>(
loc, sTensorType, insertSlice.getDest(), insertSlice.getSource());
rewriter.replaceOp(insertSlice, sTensorPush.getResult());

// Create the stream_to_tensor op.
rewriter.setInsertionPointAfter(loops.front());
auto tensorToReplace = loops.front().getTiedLoopResult(untiledUse);
auto channelTensor = rewriter.create<hls::StreamToTensorOp>(
rewriter.getUnknownLoc(), tensorToReplace.getType(), channel);
auto sTensorResult = loops.front().getResult(index);
auto tensorResult = rewriter.create<hls::SlidingTensorToTensorOp>(
loc, tensorType, sTensorResult);
rewriter.replaceUsesWithIf(
sTensorResult, tensorResult.getTensor(),
[&](OpOperand &use) { return use.getOwner() != tensorResult; });

rewriter.replaceAllUsesWith(tensorToReplace, channelTensor);
rewriter.replaceOp(insertSlice, insertSlice.getDest());
results.push_back(channel);
results.push_back(channelWrite);
results.push_back(channelTensor);
results.push_back(sTensorInit);
results.push_back(sTensorPush);
results.push_back(tensorResult);
return DiagnosedSilenceableFailure::success();
}

Expand All @@ -217,35 +223,35 @@ transform::HLSConvertExtractSliceToStreamOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Collect the surrounding loops of the extract_slice op.
auto loops = getSurroundingLoops(extractSlice,
extractSlice.getSource().getParentBlock());

// Collect the iteration shape and affine map of the streaming channel.
auto iterTripCounts = getLoopTripCounts(loops);
auto iterSteps = getLoopSteps(loops);
auto iterMap = getIterationAffineMap(extractSlice, loops);
if (!iterTripCounts || !iterSteps || !iterMap)
return emitDefaultSilenceableFailure(extractSlice);

// Create the tensor_to_stream op.
rewriter.setInsertionPointAfterValue(extractSlice.getSource());
auto channelType = hls::StreamType::get(
extractSlice.getResultType(), *iterTripCounts, *iterSteps, *iterMap,
extractSlice.getSourceType().getNumElements());
auto channel = rewriter.create<hls::TensorToStreamOp>(
rewriter.getUnknownLoc(), channelType, extractSlice.getSource());

// Create the stream_read op.
rewriter.setInsertionPoint(extractSlice);
auto channelRead = rewriter.create<hls::StreamReadOp>(
rewriter.getUnknownLoc(), extractSlice.getResultType(), channel);

// Create the stream_to_tensor op.
rewriter.replaceAllUsesWith(extractSlice.getResult(),
channelRead.getResult());
results.push_back(channel);
results.push_back(channelRead);
// // Collect the surrounding loops of the extract_slice op.
// auto loops = getSurroundingLoops(extractSlice,
// extractSlice.getSource().getParentBlock());

// // Collect the iteration shape and affine map of the streaming channel.
// auto iterTripCounts = getLoopTripCounts(loops);
// auto iterSteps = getLoopSteps(loops);
// auto iterMap = getIterationAffineMap(extractSlice, loops);
// if (!iterTripCounts || !iterSteps || !iterMap)
// return emitDefaultSilenceableFailure(extractSlice);

// // Create the tensor_to_stream op.
// rewriter.setInsertionPointAfterValue(extractSlice.getSource());
// auto channelType = hls::StreamType::get(
// extractSlice.getResultType(), *iterTripCounts, *iterSteps, *iterMap,
// extractSlice.getSourceType().getNumElements());
// auto channel = rewriter.create<hls::TensorToStreamOp>(
// rewriter.getUnknownLoc(), channelType, extractSlice.getSource());

// // Create the stream_read op.
// rewriter.setInsertionPoint(extractSlice);
// auto channelRead = rewriter.create<hls::StreamReadOp>(
// rewriter.getUnknownLoc(), extractSlice.getResultType(), channel);

// // Create the stream_to_tensor op.
// rewriter.replaceAllUsesWith(extractSlice.getResult(),
// channelRead.getResult());
// results.push_back(channel);
// results.push_back(channelRead);
return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -327,29 +333,29 @@ transform::HLSConvertExpandShapeToStreamOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::ExpandShapeOp expandShape,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Construct the source and result stream types.
auto streamTypes = getReassociateStreamTypes(
expandShape, expandShape.getSrcType(), expandShape.getResultType(),
getInputElementShape(), getOutputElementShape());
if (!streamTypes)
return emitDefaultSilenceableFailure(expandShape);

// Convert the expand_shape op to stream ops and replace its uses.
auto loc = expandShape.getLoc();
rewriter.setInsertionPoint(expandShape);
auto sourceStream = rewriter.create<hls::TensorToStreamOp>(
loc, streamTypes->first, expandShape.getSrc());
auto streamReassociate = rewriter.create<hls::StreamReassociateOp>(
loc, streamTypes->second, sourceStream, /*expandShape=*/true,
expandShape.getReassociation(), /*expandIteration=*/true,
expandShape.getReassociation());
auto resultTensor = rewriter.create<hls::StreamToTensorOp>(
loc, expandShape.getResultType(), streamReassociate);
rewriter.replaceAllUsesWith(expandShape, resultTensor);

results.push_back(sourceStream);
results.push_back(streamReassociate);
results.push_back(resultTensor);
// // Construct the source and result stream types.
// auto streamTypes = getReassociateStreamTypes(
// expandShape, expandShape.getSrcType(), expandShape.getResultType(),
// getInputElementShape(), getOutputElementShape());
// if (!streamTypes)
// return emitDefaultSilenceableFailure(expandShape);

// // Convert the expand_shape op to stream ops and replace its uses.
// auto loc = expandShape.getLoc();
// rewriter.setInsertionPoint(expandShape);
// auto sourceStream = rewriter.create<hls::TensorToStreamOp>(
// loc, streamTypes->first, expandShape.getSrc());
// auto streamReassociate = rewriter.create<hls::StreamReassociateOp>(
// loc, streamTypes->second, sourceStream, /*expandShape=*/true,
// expandShape.getReassociation(), /*expandIteration=*/true,
// expandShape.getReassociation());
// auto resultTensor = rewriter.create<hls::StreamToTensorOp>(
// loc, expandShape.getResultType(), streamReassociate);
// rewriter.replaceAllUsesWith(expandShape, resultTensor);

// results.push_back(sourceStream);
// results.push_back(streamReassociate);
// results.push_back(resultTensor);
return DiagnosedSilenceableFailure::success();
}

Expand All @@ -367,29 +373,30 @@ transform::HLSConvertCollapseShapeToStreamOp::applyToOne(
tensor::CollapseShapeOp collapseShape,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Construct the source and result stream types.
auto streamTypes = getReassociateStreamTypes(
collapseShape, collapseShape.getResultType(), collapseShape.getSrcType(),
getOutputElementShape(), getInputElementShape());
if (!streamTypes)
return emitDefaultSilenceableFailure(collapseShape);

// Convert the expand_shape op to stream ops and replace its uses.
auto loc = collapseShape.getLoc();
rewriter.setInsertionPoint(collapseShape);
auto sourceStream = rewriter.create<hls::TensorToStreamOp>(
loc, streamTypes->second, collapseShape.getSrc());
auto streamReassociate = rewriter.create<hls::StreamReassociateOp>(
loc, streamTypes->first, sourceStream, /*expandShape=*/false,
collapseShape.getReassociation(), /*expandIteration=*/false,
collapseShape.getReassociation());
auto resultTensor = rewriter.create<hls::StreamToTensorOp>(
loc, collapseShape.getResultType(), streamReassociate);
rewriter.replaceAllUsesWith(collapseShape, resultTensor);

results.push_back(sourceStream);
results.push_back(streamReassociate);
results.push_back(resultTensor);
// // Construct the source and result stream types.
// auto streamTypes = getReassociateStreamTypes(
// collapseShape, collapseShape.getResultType(),
// collapseShape.getSrcType(), getOutputElementShape(),
// getInputElementShape());
// if (!streamTypes)
// return emitDefaultSilenceableFailure(collapseShape);

// // Convert the expand_shape op to stream ops and replace its uses.
// auto loc = collapseShape.getLoc();
// rewriter.setInsertionPoint(collapseShape);
// auto sourceStream = rewriter.create<hls::TensorToStreamOp>(
// loc, streamTypes->second, collapseShape.getSrc());
// auto streamReassociate = rewriter.create<hls::StreamReassociateOp>(
// loc, streamTypes->first, sourceStream, /*expandShape=*/false,
// collapseShape.getReassociation(), /*expandIteration=*/false,
// collapseShape.getReassociation());
// auto resultTensor = rewriter.create<hls::StreamToTensorOp>(
// loc, collapseShape.getResultType(), streamReassociate);
// rewriter.replaceAllUsesWith(collapseShape, resultTensor);

// results.push_back(sourceStream);
// results.push_back(streamReassociate);
// results.push_back(resultTensor);
return DiagnosedSilenceableFailure::success();
}

Expand Down
14 changes: 9 additions & 5 deletions lib/Dialect/HLS/Transforms/MaterializeStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using namespace mlir;
using namespace scalehls;
using namespace hls;

/*
/// Return the offsets, sizes, and strides of a slice given the loop induction
/// variables "ivs", the index expressions "indexExprs", the element shape
/// "elementShape", and the packing flag "packing". If "packing" is true, the
Expand Down Expand Up @@ -357,6 +359,8 @@ struct FoldPackOpIntoConstantOp : public OpRewritePattern<tensor::PackOp> {
};
} // namespace
*/

namespace {
struct MaterializeStream : public MaterializeStreamBase<MaterializeStream> {
MaterializeStream() = default;
Expand All @@ -369,11 +373,11 @@ struct MaterializeStream : public MaterializeStreamBase<MaterializeStream> {
auto context = op->getContext();

mlir::RewritePatternSet patterns(context);
patterns.add<LowerTensorToStreamConversionOp>(context, enablePacking);
patterns.add<LowerStreamToTensorConversionOp>(context, enablePacking);
patterns.add<LowerStreamBufferOp>(context, enablePacking);
if (enablePacking)
patterns.add<FoldPackOpIntoConstantOp>(context);
// patterns.add<LowerTensorToStreamConversionOp>(context, enablePacking);
// patterns.add<LowerStreamToTensorConversionOp>(context, enablePacking);
// patterns.add<LowerStreamBufferOp>(context, enablePacking);
// if (enablePacking)
// patterns.add<FoldPackOpIntoConstantOp>(context);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
Expand Down
6 changes: 5 additions & 1 deletion lib/Dialect/HLS/Transforms/ReduceTensorToStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using namespace mlir;
using namespace scalehls;
using namespace hls;

/*
namespace {
struct ConvertToStreamBuffer : public OpRewritePattern<hls::TensorToStreamOp> {
using OpRewritePattern<hls::TensorToStreamOp>::OpRewritePattern;
Expand Down Expand Up @@ -94,6 +96,8 @@ struct ConvertToStreamBuffer : public OpRewritePattern<hls::TensorToStreamOp> {
};
} // namespace
*/

namespace {
struct ReduceTensorToStream
: public ReduceTensorToStreamBase<ReduceTensorToStream> {
Expand All @@ -102,7 +106,7 @@ struct ReduceTensorToStream
auto context = op->getContext();

mlir::RewritePatternSet patterns(context);
patterns.add<ConvertToStreamBuffer>(context);
// patterns.add<ConvertToStreamBuffer>(context);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
Expand Down
Loading

0 comments on commit 7ff6a09

Please sign in to comment.