From 6a875ef15c215c4f802caad9aee1862aa4fd67e7 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Fri, 2 Feb 2024 15:34:13 -0600 Subject: [PATCH] Add LowerTensorToStreamOp --- .../HLS/TransformOps/HLSTransformOps.td | 19 ++++ .../HLS/TransformOps/HLSTransformOps.cpp | 93 +++++++++++-------- 2 files changed, 73 insertions(+), 39 deletions(-) diff --git a/include/scalehls/Dialect/HLS/TransformOps/HLSTransformOps.td b/include/scalehls/Dialect/HLS/TransformOps/HLSTransformOps.td index a0407bfe..cce5a0f0 100644 --- a/include/scalehls/Dialect/HLS/TransformOps/HLSTransformOps.td +++ b/include/scalehls/Dialect/HLS/TransformOps/HLSTransformOps.td @@ -112,4 +112,23 @@ def HLSConvertExtractSliceToStreamOp : Op { + let description = [{}]; + let arguments = (ins + Transform_ConcreteOpType<"hls.tensor_to_stream">:$tensor_to_stream); + let assemblyFormat = [{ + $tensor_to_stream attr-dict `:` type(operands) + }]; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::scalehls::hls::TensorToStreamOp tensorToStream, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // SCALEHLS_DIALECT_HLS_TRANSFORMOPS_HLSTRANSFORMOPS_TD diff --git a/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp b/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp index 111e6da1..50319ba8 100644 --- a/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp +++ b/lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp @@ -44,18 +44,18 @@ getUntiledOperandAndSurroundingLoops(OpOperand *source, DiagnosedSilenceableFailure transform::HLSConvertExtractSliceToTensorInitOp::applyToOne( - transform::TransformRewriter &rewriter, tensor::ExtractSliceOp target, + transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice, transform::ApplyToEachResultList &results, transform::TransformState &state) { auto untiledUse = - getUntiledOperandAndSurroundingLoops(&target.getSourceMutable()); + getUntiledOperandAndSurroundingLoops(&extractSlice.getSourceMutable()); auto tensorInit = untiledUse->get().getDefiningOp(); if (!tensorInit) - return emitDefaultSilenceableFailure(target); + return emitDefaultSilenceableFailure(extractSlice); - rewriter.setInsertionPoint(target); + rewriter.setInsertionPoint(extractSlice); auto localTensorInit = rewriter.replaceOpWithNewOp( - target, target.getType(), tensorInit.getInitValue()); + extractSlice, extractSlice.getType(), tensorInit.getInitValue()); results.push_back(localTensorInit); return DiagnosedSilenceableFailure::success(); } @@ -73,37 +73,37 @@ static scf::ForOp getAssociatedLoop(Value value) { } DiagnosedSilenceableFailure transform::HLSDemoteExtractSliceOp::applyToOne( - transform::TransformRewriter &rewriter, tensor::ExtractSliceOp target, + transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice, transform::ApplyToEachResultList &results, transform::TransformState &state) { // We first check whether the extract_slice op's source is an iter_args. - auto sourceArg = dyn_cast(target.getSource()); + auto sourceArg = dyn_cast(extractSlice.getSource()); if (sourceArg && isa(sourceArg.getOwner()->getParentOp())) - return emitDefaultSilenceableFailure(target); - results.push_back(target); + return emitDefaultSilenceableFailure(extractSlice); + results.push_back(extractSlice); // We then check if all offsets are loop induction variables, and collect // them into a set. llvm::SmallDenseSet offsets; - for (auto offset : target.getMixedOffsets()) + for (auto offset : extractSlice.getMixedOffsets()) if (auto offsetValue = offset.dyn_cast()) { // Here, we need to handle the case where the offset is defined by an // affine.apply op. if (auto apply = offsetValue.getDefiningOp()) { for (auto operand : apply.getOperands()) { if (!getAssociatedLoop(operand)) - return emitDefaultSilenceableFailure(target); + return emitDefaultSilenceableFailure(extractSlice); offsets.insert(operand); } } else { if (!getAssociatedLoop(offsetValue)) - return emitDefaultSilenceableFailure(target); + return emitDefaultSilenceableFailure(extractSlice); offsets.insert(offsetValue); } } // Then, we find the outermost loop that does not contain any of the offsets. - Operation *insertBefore = target; + Operation *insertBefore = extractSlice; while (auto loop = insertBefore->getParentOfType()) { if (!offsets.count(loop.getInductionVar())) insertBefore = loop; @@ -112,8 +112,8 @@ DiagnosedSilenceableFailure transform::HLSDemoteExtractSliceOp::applyToOne( } // Finally, we move the extract_slice op before the outermost loop. - if (insertBefore != target) - target->moveBefore(insertBefore); + if (insertBefore != extractSlice) + extractSlice->moveBefore(insertBefore); return DiagnosedSilenceableFailure::success(); } @@ -180,39 +180,39 @@ getIterationAffineMap(OpTy target, const SmallVector &loops) { DiagnosedSilenceableFailure transform::HLSConvertInsertSliceToStreamOp::applyToOne( - transform::TransformRewriter &rewriter, tensor::InsertSliceOp target, + transform::TransformRewriter &rewriter, tensor::InsertSliceOp insertSlice, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Check if the destination tensor of the insert_slice op has only one use, // which means no other operations have effect on the tensor. - if (!target.getDest().hasOneUse()) - return emitDefaultSilenceableFailure(target); + if (!insertSlice.getDest().hasOneUse()) + return emitDefaultSilenceableFailure(insertSlice); // Collect the surrounding loops of the insert_slice op. SmallVector loops; - auto untiledUse = - getUntiledOperandAndSurroundingLoops(&target.getDestMutable(), &loops); - if (untiledUse == &target.getDestMutable()) - return emitDefaultSilenceableFailure(target); + auto untiledUse = getUntiledOperandAndSurroundingLoops( + &insertSlice.getDestMutable(), &loops); + if (untiledUse == &insertSlice.getDestMutable()) + return emitDefaultSilenceableFailure(insertSlice); // Collect the iteration shape and affine map of the streaming channel. auto iterShape = getTripCounts(loops); - auto iterMap = getIterationAffineMap(target, loops); + auto iterMap = getIterationAffineMap(insertSlice, loops); if (!iterShape || !iterMap) - return emitDefaultSilenceableFailure(target); + return emitDefaultSilenceableFailure(insertSlice); // Create the streaming channel. rewriter.setInsertionPoint(loops.front()); auto channelType = - hls::StreamType::get(target.getSourceType(), *iterShape, *iterMap, - target.getDestType().getNumElements()); + hls::StreamType::get(insertSlice.getSourceType(), *iterShape, *iterMap, + insertSlice.getDestType().getNumElements()); auto channel = rewriter.create(rewriter.getUnknownLoc(), channelType); // Create the stream_write op. - rewriter.setInsertionPoint(target); + rewriter.setInsertionPoint(insertSlice); auto channelWrite = rewriter.create( - rewriter.getUnknownLoc(), channel, target.getSource()); + rewriter.getUnknownLoc(), channel, insertSlice.getSource()); // Create the stream_to_tensor op. rewriter.setInsertionPointAfter(loops.front()); @@ -245,38 +245,53 @@ static SmallVector getSurroundingLoops(Operation *target, DiagnosedSilenceableFailure transform::HLSConvertExtractSliceToStreamOp::applyToOne( - transform::TransformRewriter &rewriter, tensor::ExtractSliceOp target, + transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Collect the surrounding loops of the extract_slice op. - auto loops = getSurroundingLoops(target, target.getSource().getParentBlock()); + auto loops = getSurroundingLoops(extractSlice, + extractSlice.getSource().getParentBlock()); // Collect the iteration shape and affine map of the streaming channel. auto iterShape = getTripCounts(loops); - auto iterMap = getIterationAffineMap(target, loops); + auto iterMap = getIterationAffineMap(extractSlice, loops); if (!iterShape || !iterMap) - return emitDefaultSilenceableFailure(target); + return emitDefaultSilenceableFailure(extractSlice); // Create the tensor_to_stream op. - rewriter.setInsertionPointAfterValue(target.getSource()); + rewriter.setInsertionPointAfterValue(extractSlice.getSource()); auto channelType = - hls::StreamType::get(target.getResultType(), *iterShape, *iterMap, - target.getSourceType().getNumElements()); + hls::StreamType::get(extractSlice.getResultType(), *iterShape, *iterMap, + extractSlice.getSourceType().getNumElements()); auto channel = rewriter.create( - rewriter.getUnknownLoc(), channelType, target.getSource()); + rewriter.getUnknownLoc(), channelType, extractSlice.getSource()); // Create the stream_read op. - rewriter.setInsertionPoint(target); + rewriter.setInsertionPoint(extractSlice); auto channelRead = rewriter.create( - rewriter.getUnknownLoc(), target.getResultType(), channel); + rewriter.getUnknownLoc(), extractSlice.getResultType(), channel); // Create the stream_to_tensor op. - rewriter.replaceAllUsesWith(target.getResult(), channelRead.getResult()); + rewriter.replaceAllUsesWith(extractSlice.getResult(), + channelRead.getResult()); results.push_back(channel); results.push_back(channelRead); return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// HLSLowerTensorToStreamOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::HLSLowerTensorToStreamOp::applyToOne( + transform::TransformRewriter &rewriter, + hls::TensorToStreamOp tensorToStream, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===//