Skip to content

Commit

Permalink
Add LowerTensorToStreamOp
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Feb 2, 2024
1 parent 7040069 commit 6a875ef
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 39 deletions.
19 changes: 19 additions & 0 deletions include/scalehls/Dialect/HLS/TransformOps/HLSTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,23 @@ def HLSConvertExtractSliceToStreamOp : Op<Transform_Dialect,
}];
}

def HLSLowerTensorToStreamOp : Op<Transform_Dialect,
"hls.lower_tensor_to_stream",
[FunctionalStyleTransformOpTrait, TransformEachOpTrait,
TransformOpInterface, MemoryEffectsOpInterface]> {
let description = [{}];
let arguments = (ins
Transform_ConcreteOpType<"hls.tensor_to_stream">:$tensor_to_stream);
let assemblyFormat = [{
$tensor_to_stream attr-dict `:` type(operands)
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::scalehls::hls::TensorToStreamOp tensorToStream,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // SCALEHLS_DIALECT_HLS_TRANSFORMOPS_HLSTRANSFORMOPS_TD
93 changes: 54 additions & 39 deletions lib/Dialect/HLS/TransformOps/HLSTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ getUntiledOperandAndSurroundingLoops(OpOperand *source,

DiagnosedSilenceableFailure
transform::HLSConvertExtractSliceToTensorInitOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp target,
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
auto untiledUse =
getUntiledOperandAndSurroundingLoops(&target.getSourceMutable());
getUntiledOperandAndSurroundingLoops(&extractSlice.getSourceMutable());
auto tensorInit = untiledUse->get().getDefiningOp<hls::TensorInitOp>();
if (!tensorInit)
return emitDefaultSilenceableFailure(target);
return emitDefaultSilenceableFailure(extractSlice);

rewriter.setInsertionPoint(target);
rewriter.setInsertionPoint(extractSlice);
auto localTensorInit = rewriter.replaceOpWithNewOp<hls::TensorInitOp>(
target, target.getType(), tensorInit.getInitValue());
extractSlice, extractSlice.getType(), tensorInit.getInitValue());
results.push_back(localTensorInit);
return DiagnosedSilenceableFailure::success();
}
Expand All @@ -73,37 +73,37 @@ static scf::ForOp getAssociatedLoop(Value value) {
}

DiagnosedSilenceableFailure transform::HLSDemoteExtractSliceOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp target,
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// We first check whether the extract_slice op's source is an iter_args.
auto sourceArg = dyn_cast<BlockArgument>(target.getSource());
auto sourceArg = dyn_cast<BlockArgument>(extractSlice.getSource());
if (sourceArg && isa<scf::ForOp>(sourceArg.getOwner()->getParentOp()))
return emitDefaultSilenceableFailure(target);
results.push_back(target);
return emitDefaultSilenceableFailure(extractSlice);
results.push_back(extractSlice);

// We then check if all offsets are loop induction variables, and collect
// them into a set.
llvm::SmallDenseSet<Value> offsets;
for (auto offset : target.getMixedOffsets())
for (auto offset : extractSlice.getMixedOffsets())
if (auto offsetValue = offset.dyn_cast<Value>()) {
// Here, we need to handle the case where the offset is defined by an
// affine.apply op.
if (auto apply = offsetValue.getDefiningOp<affine::AffineApplyOp>()) {
for (auto operand : apply.getOperands()) {
if (!getAssociatedLoop(operand))
return emitDefaultSilenceableFailure(target);
return emitDefaultSilenceableFailure(extractSlice);
offsets.insert(operand);
}
} else {
if (!getAssociatedLoop(offsetValue))
return emitDefaultSilenceableFailure(target);
return emitDefaultSilenceableFailure(extractSlice);
offsets.insert(offsetValue);
}
}

// Then, we find the outermost loop that does not contain any of the offsets.
Operation *insertBefore = target;
Operation *insertBefore = extractSlice;
while (auto loop = insertBefore->getParentOfType<scf::ForOp>()) {
if (!offsets.count(loop.getInductionVar()))
insertBefore = loop;
Expand All @@ -112,8 +112,8 @@ DiagnosedSilenceableFailure transform::HLSDemoteExtractSliceOp::applyToOne(
}

// Finally, we move the extract_slice op before the outermost loop.
if (insertBefore != target)
target->moveBefore(insertBefore);
if (insertBefore != extractSlice)
extractSlice->moveBefore(insertBefore);
return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -180,39 +180,39 @@ getIterationAffineMap(OpTy target, const SmallVector<scf::ForOp> &loops) {

DiagnosedSilenceableFailure
transform::HLSConvertInsertSliceToStreamOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::InsertSliceOp target,
transform::TransformRewriter &rewriter, tensor::InsertSliceOp insertSlice,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Check if the destination tensor of the insert_slice op has only one use,
// which means no other operations have effect on the tensor.
if (!target.getDest().hasOneUse())
return emitDefaultSilenceableFailure(target);
if (!insertSlice.getDest().hasOneUse())
return emitDefaultSilenceableFailure(insertSlice);

// Collect the surrounding loops of the insert_slice op.
SmallVector<scf::ForOp> loops;
auto untiledUse =
getUntiledOperandAndSurroundingLoops(&target.getDestMutable(), &loops);
if (untiledUse == &target.getDestMutable())
return emitDefaultSilenceableFailure(target);
auto untiledUse = getUntiledOperandAndSurroundingLoops(
&insertSlice.getDestMutable(), &loops);
if (untiledUse == &insertSlice.getDestMutable())
return emitDefaultSilenceableFailure(insertSlice);

// Collect the iteration shape and affine map of the streaming channel.
auto iterShape = getTripCounts(loops);
auto iterMap = getIterationAffineMap(target, loops);
auto iterMap = getIterationAffineMap(insertSlice, loops);
if (!iterShape || !iterMap)
return emitDefaultSilenceableFailure(target);
return emitDefaultSilenceableFailure(insertSlice);

// Create the streaming channel.
rewriter.setInsertionPoint(loops.front());
auto channelType =
hls::StreamType::get(target.getSourceType(), *iterShape, *iterMap,
target.getDestType().getNumElements());
hls::StreamType::get(insertSlice.getSourceType(), *iterShape, *iterMap,
insertSlice.getDestType().getNumElements());
auto channel =
rewriter.create<hls::StreamOp>(rewriter.getUnknownLoc(), channelType);

// Create the stream_write op.
rewriter.setInsertionPoint(target);
rewriter.setInsertionPoint(insertSlice);
auto channelWrite = rewriter.create<hls::StreamWriteOp>(
rewriter.getUnknownLoc(), channel, target.getSource());
rewriter.getUnknownLoc(), channel, insertSlice.getSource());

// Create the stream_to_tensor op.
rewriter.setInsertionPointAfter(loops.front());
Expand Down Expand Up @@ -245,38 +245,53 @@ static SmallVector<scf::ForOp> getSurroundingLoops(Operation *target,

DiagnosedSilenceableFailure
transform::HLSConvertExtractSliceToStreamOp::applyToOne(
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp target,
transform::TransformRewriter &rewriter, tensor::ExtractSliceOp extractSlice,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Collect the surrounding loops of the extract_slice op.
auto loops = getSurroundingLoops(target, target.getSource().getParentBlock());
auto loops = getSurroundingLoops(extractSlice,
extractSlice.getSource().getParentBlock());

// Collect the iteration shape and affine map of the streaming channel.
auto iterShape = getTripCounts(loops);
auto iterMap = getIterationAffineMap(target, loops);
auto iterMap = getIterationAffineMap(extractSlice, loops);
if (!iterShape || !iterMap)
return emitDefaultSilenceableFailure(target);
return emitDefaultSilenceableFailure(extractSlice);

// Create the tensor_to_stream op.
rewriter.setInsertionPointAfterValue(target.getSource());
rewriter.setInsertionPointAfterValue(extractSlice.getSource());
auto channelType =
hls::StreamType::get(target.getResultType(), *iterShape, *iterMap,
target.getSourceType().getNumElements());
hls::StreamType::get(extractSlice.getResultType(), *iterShape, *iterMap,
extractSlice.getSourceType().getNumElements());
auto channel = rewriter.create<hls::TensorToStreamOp>(
rewriter.getUnknownLoc(), channelType, target.getSource());
rewriter.getUnknownLoc(), channelType, extractSlice.getSource());

// Create the stream_read op.
rewriter.setInsertionPoint(target);
rewriter.setInsertionPoint(extractSlice);
auto channelRead = rewriter.create<hls::StreamReadOp>(
rewriter.getUnknownLoc(), target.getResultType(), channel);
rewriter.getUnknownLoc(), extractSlice.getResultType(), channel);

// Create the stream_to_tensor op.
rewriter.replaceAllUsesWith(target.getResult(), channelRead.getResult());
rewriter.replaceAllUsesWith(extractSlice.getResult(),
channelRead.getResult());
results.push_back(channel);
results.push_back(channelRead);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// HLSLowerTensorToStreamOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::HLSLowerTensorToStreamOp::applyToOne(
transform::TransformRewriter &rewriter,
hls::TensorToStreamOp tensorToStream,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 6a875ef

Please sign in to comment.