diff --git a/include/scalehls/Dialect/HLS/IR/HLSOps.td b/include/scalehls/Dialect/HLS/IR/HLSOps.td index 7db85685..9c5cbddf 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSOps.td @@ -107,6 +107,10 @@ def YieldOp : HLSOp<"yield", [NoMemoryEffect, ReturnLike, Terminator, let builders = [OpBuilder<(ins), "build($_builder, $_state, std::nullopt);">]; } +//===----------------------------------------------------------------------===// +// Stream Operations +//===----------------------------------------------------------------------===// + def TensorInitOp : HLSOp<"tensor_init", [NoMemoryEffect]> { let summary = "Initiate a tensor"; @@ -133,6 +137,7 @@ def TensorToStreamOp : HLSOp<"tensor_to_stream",[NoMemoryEffect]> { }]; let hasVerifier = 1; + let hasFolder = 1; } def StreamToTensorOp : HLSOp<"stream_to_tensor", [NoMemoryEffect]> { @@ -190,6 +195,54 @@ def StreamWriteOp : HLSOp<"stream_write", [ let hasVerifier = 1; } +def StreamExpandShapeOp : HLSOp<"stream_expand_shape", [NoMemoryEffect]> { + let summary = [{}]; + + let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation); + let results = (outs AnyStream:$output); + let assemblyFormat = [{ + $input attr-dict `:` functional-type($input, $output) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + SmallVector getReassociationIndices() { + SmallVector reassociationIndices; + for (auto attr : getReassociation()) + reassociationIndices.push_back(llvm::to_vector<2>( + llvm::map_range(::llvm::cast(attr), [&](Attribute indexAttr) { + return ::llvm::cast(indexAttr).getInt(); + }))); + return reassociationIndices; + } + }]; +} + +def StreamCollapseShapeOp : HLSOp<"stream_collapse_shape", [NoMemoryEffect]> { + let summary = [{}]; + + let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation); + let results = (outs AnyStream:$output); + let assemblyFormat = [{ + $input attr-dict `:` functional-type($input, $output) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + SmallVector getReassociationIndices() { + SmallVector reassociationIndices; + for (auto attr : getReassociation()) + reassociationIndices.push_back(llvm::to_vector<2>( + llvm::map_range(::llvm::cast(attr), [&](Attribute indexAttr) { + return ::llvm::cast(indexAttr).getInt(); + }))); + return reassociationIndices; + } + }]; +} + def StreamElementChunkOp : HLSOp<"stream_element_chunk", [NoMemoryEffect]> { let summary = [{ Chunk the input stream channel's element to smaller elements and pass to the @@ -224,8 +277,7 @@ def StreamElementConcatOp : HLSOp<"stream_element_concat", [NoMemoryEffect]> { def StreamCastOp : HLSOp<"stream_cast", [NoMemoryEffect]> { let summary = [{ - Cast a stream channel to another type. The input and output stream channel - must be compatible with each other. + Cast a stream channel to another type. }]; let arguments = (ins AnyStream:$input); diff --git a/include/scalehls/Dialect/HLS/IR/HLSTypes.td b/include/scalehls/Dialect/HLS/IR/HLSTypes.td index 879bc04e..45ea47af 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSTypes.td +++ b/include/scalehls/Dialect/HLS/IR/HLSTypes.td @@ -62,6 +62,8 @@ def StreamType : HLSType<"Stream", [ShapedTypeInterface]> { }]> ]; + let genVerifyDecl = 1; + let extraClassDeclaration = [{ /// Returns if this type is ranked. bool hasRank() const { return true; } @@ -76,11 +78,17 @@ def StreamType : HLSType<"Stream", [ShapedTypeInterface]> { using ShapedType::Trait::getNumElements; using ShapedType::Trait::getDimSize; + /// Infer the integral shape of the data this stream type represents. + SmallVector inferIntegralShape() const; + /// Return whether the "other" stream type is compatible with this stream /// type. By being compatible, it means that the two stream types have the /// element type and iteration order, but not necessarily the same iteration - /// shape and layout. Compatible stream types can be casted to each other. - bool isCompatibleWith(StreamType other); + /// shape and layout. + bool isCompatibleWith(StreamType other) const; + + /// Return whether this stream type can be converted to the "tensor" type. + bool isConvertableWith(RankedTensorType tensor) const; }]; } diff --git a/lib/Dialect/HLS/IR/HLSOps.cpp b/lib/Dialect/HLS/IR/HLSOps.cpp index 8a9a58da..99d37c23 100644 --- a/lib/Dialect/HLS/IR/HLSOps.cpp +++ b/lib/Dialect/HLS/IR/HLSOps.cpp @@ -201,13 +201,28 @@ LogicalResult TensorInitOp::verify() { // TensorToStreamOp //===----------------------------------------------------------------------===// -LogicalResult TensorToStreamOp::verify() { return success(); } +LogicalResult TensorToStreamOp::verify() { + if (getStream().getType().isConvertableWith(getTensor().getType())) + return emitOpError("stream type is not convertable with tensor type"); + return success(); +} + +OpFoldResult TensorToStreamOp::fold(FoldAdaptor adaptor) { + if (auto streamToTensor = getTensor().getDefiningOp()) + if (streamToTensor.getStream().getType() == getStream().getType()) + return streamToTensor.getStream(); + return {}; +} //===----------------------------------------------------------------------===// // StreamToTensorOp //===----------------------------------------------------------------------===// -LogicalResult StreamToTensorOp::verify() { return success(); } +LogicalResult StreamToTensorOp::verify() { + if (getStream().getType().isConvertableWith(getTensor().getType())) + return emitOpError("stream type is not convertable with tensor type"); + return success(); +} //===----------------------------------------------------------------------===// // StreamOp @@ -269,6 +284,18 @@ LogicalResult StreamElementChunkOp::verify() { return success(); } LogicalResult StreamElementConcatOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// StreamExpandShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult StreamExpandShapeOp::verify() { return success(); } + +//===----------------------------------------------------------------------===// +// StreamCollapseShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult StreamCollapseShapeOp::verify() { return success(); } + //===----------------------------------------------------------------------===// // StreamCastOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/HLS/IR/HLSTypes.cpp b/lib/Dialect/HLS/IR/HLSTypes.cpp index f3fbe1b7..c3836727 100644 --- a/lib/Dialect/HLS/IR/HLSTypes.cpp +++ b/lib/Dialect/HLS/IR/HLSTypes.cpp @@ -9,6 +9,24 @@ using namespace mlir; using namespace scalehls; using namespace hls; +LogicalResult +hls::StreamType::verify(function_ref emitError, + Type elementType, ArrayRef shape, + MemRefLayoutAttrInterface iterLayout, int64_t depth) { + if (iterLayout.getAffineMap().getNumSymbols()) + return emitError() << "iteration layout cannot have symbols"; + if (shape.size() != iterLayout.getAffineMap().getNumDims()) + return emitError() << "shape size and iteration layout mismatch"; + if (auto shapedElementType = llvm::dyn_cast(elementType)) { + if (!shapedElementType.hasRank()) + return emitError() << "element type must be ranked"; + if (iterLayout.getAffineMap().getNumResults() != + shapedElementType.getRank()) + return emitError() << "iteration layout and element type rank mismatch"; + } + return success(); +} + /// Clone this type with the given shape and element type. If the provided /// shape is `std::nullopt`, the current shape of the type is used. StreamType hls::StreamType::cloneWith(std::optional> shape, @@ -17,12 +35,48 @@ StreamType hls::StreamType::cloneWith(std::optional> shape, getIterLayout(), getDepth()); } +/// Infer the integral shape of the data this stream type represents. +SmallVector hls::StreamType::inferIntegralShape() const { + SmallVector iterSizeInputs; + for (auto iterSize : getShape()) + iterSizeInputs.push_back(getAffineConstantExpr(iterSize, getContext())); + auto iterSizeMap = getIterLayout().getAffineMap().replaceDimsAndSymbols( + iterSizeInputs, {}, 0, 0); + auto shapedElementType = llvm::dyn_cast(getElementType()); + + SmallVector integralShape; + for (auto [index, iterSize] : llvm::enumerate(iterSizeMap.getResults())) { + auto constIterSize = llvm::dyn_cast(iterSize); + assert(constIterSize && "non-constant size in the iteration layout"); + + if (!shapedElementType) + integralShape.push_back(constIterSize.getValue()); + else + integralShape.push_back(constIterSize.getValue() * + shapedElementType.getDimSize(index)); + } + return integralShape; +} + /// Return whether the "other" stream type is compatible with this stream type. /// By being compatible, it means that the two stream types have the element /// type and iteration order, but not necessarily the same iteration shape and -/// layout. Compatible stream types can be casted to each other. -bool hls::StreamType::isCompatibleWith(StreamType other) { +/// layout. +bool hls::StreamType::isCompatibleWith(StreamType other) const { if (*this == other) return true; return false; } + +/// Return whether this stream type can be converted to the "tensor" type. +bool hls::StreamType::isConvertableWith(RankedTensorType tensor) const { + if (tensor.getElementType() != getElementType()) + return false; + if (tensor.getRank() != getRank()) + return false; + if (llvm::any_of(llvm::zip(tensor.getShape(), getShape()), [](auto pair) { + return std::get<0>(pair) != std::get<1>(pair); + })) + return false; + return true; +}