Skip to content

Commit

Permalink
Add StreamExpand/CollapseShapeOp; Add several verifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchenye committed Feb 2, 2024
1 parent be0721b commit 7040069
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 8 deletions.
56 changes: 54 additions & 2 deletions include/scalehls/Dialect/HLS/IR/HLSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def YieldOp : HLSOp<"yield", [NoMemoryEffect, ReturnLike, Terminator,
let builders = [OpBuilder<(ins), "build($_builder, $_state, std::nullopt);">];
}

//===----------------------------------------------------------------------===//
// Stream Operations
//===----------------------------------------------------------------------===//

def TensorInitOp : HLSOp<"tensor_init", [NoMemoryEffect]> {
let summary = "Initiate a tensor";

Expand All @@ -133,6 +137,7 @@ def TensorToStreamOp : HLSOp<"tensor_to_stream",[NoMemoryEffect]> {
}];

let hasVerifier = 1;
let hasFolder = 1;
}

def StreamToTensorOp : HLSOp<"stream_to_tensor", [NoMemoryEffect]> {
Expand Down Expand Up @@ -190,6 +195,54 @@ def StreamWriteOp : HLSOp<"stream_write", [
let hasVerifier = 1;
}

def StreamExpandShapeOp : HLSOp<"stream_expand_shape", [NoMemoryEffect]> {
  let summary = [{
    Expand the iteration shape of the input stream channel with the given
    reassociation indices and pass the elements to the output stream channel.
  }];

  let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation);
  let results = (outs AnyStream:$output);
  let assemblyFormat = [{
    $input attr-dict `:` functional-type($input, $output)
  }];

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Decode the `reassociation` array-of-arrays attribute into integer
    /// index groups, one group per output dimension.
    SmallVector<ReassociationIndices, 4> getReassociationIndices() {
      SmallVector<ReassociationIndices, 4> reassociationIndices;
      for (auto attr : getReassociation())
        reassociationIndices.push_back(llvm::to_vector<2>(
            llvm::map_range(::llvm::cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
              return ::llvm::cast<IntegerAttr>(indexAttr).getInt();
            })));
      return reassociationIndices;
    }
  }];
}

def StreamCollapseShapeOp : HLSOp<"stream_collapse_shape", [NoMemoryEffect]> {
  let summary = [{
    Collapse the iteration shape of the input stream channel with the given
    reassociation indices and pass the elements to the output stream channel.
  }];

  let arguments = (ins AnyStream:$input, IndexListArrayAttr:$reassociation);
  let results = (outs AnyStream:$output);
  let assemblyFormat = [{
    $input attr-dict `:` functional-type($input, $output)
  }];

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Decode the `reassociation` array-of-arrays attribute into integer
    /// index groups, one group per collapsed dimension.
    SmallVector<ReassociationIndices, 4> getReassociationIndices() {
      SmallVector<ReassociationIndices, 4> reassociationIndices;
      for (auto attr : getReassociation())
        reassociationIndices.push_back(llvm::to_vector<2>(
            llvm::map_range(::llvm::cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
              return ::llvm::cast<IntegerAttr>(indexAttr).getInt();
            })));
      return reassociationIndices;
    }
  }];
}

def StreamElementChunkOp : HLSOp<"stream_element_chunk", [NoMemoryEffect]> {
let summary = [{
Chunk the input stream channel's element to smaller elements and pass to the
Expand Down Expand Up @@ -224,8 +277,7 @@ def StreamElementConcatOp : HLSOp<"stream_element_concat", [NoMemoryEffect]> {

def StreamCastOp : HLSOp<"stream_cast", [NoMemoryEffect]> {
let summary = [{
Cast a stream channel to another type. The input and output stream channel
must be compatible with each other.
Cast a stream channel to another type.
}];

let arguments = (ins AnyStream:$input);
Expand Down
12 changes: 10 additions & 2 deletions include/scalehls/Dialect/HLS/IR/HLSTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def StreamType : HLSType<"Stream", [ShapedTypeInterface]> {
}]>
];

let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Returns if this type is ranked.
bool hasRank() const { return true; }
Expand All @@ -76,11 +78,17 @@ def StreamType : HLSType<"Stream", [ShapedTypeInterface]> {
using ShapedType::Trait<StreamType>::getNumElements;
using ShapedType::Trait<StreamType>::getDimSize;

/// Infer the integral shape of the data this stream type represents.
SmallVector<int64_t> inferIntegralShape() const;

/// Return whether the "other" stream type is compatible with this stream
/// type. By being compatible, it means that the two stream types have the
/// element type and iteration order, but not necessarily the same iteration
/// shape and layout. Compatible stream types can be casted to each other.
bool isCompatibleWith(StreamType other);
/// shape and layout.
bool isCompatibleWith(StreamType other) const;

/// Return whether this stream type can be converted to the "tensor" type.
bool isConvertableWith(RankedTensorType tensor) const;
}];
}

Expand Down
31 changes: 29 additions & 2 deletions lib/Dialect/HLS/IR/HLSOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,28 @@ LogicalResult TensorInitOp::verify() {
// TensorToStreamOp
//===----------------------------------------------------------------------===//

/// Verify that the stream type can represent the tensor operand.
LogicalResult TensorToStreamOp::verify() {
  // BUG FIX: the condition was inverted — the op rejected exactly the cases
  // where the stream type IS convertable with the tensor type. Error out only
  // when the conversion is NOT possible.
  if (!getStream().getType().isConvertableWith(getTensor().getType()))
    return emitOpError("stream type is not convertable with tensor type");
  return success();
}

/// Fold a stream_to_tensor -> tensor_to_stream round trip back to the
/// original stream, provided the stream types match exactly.
OpFoldResult TensorToStreamOp::fold(FoldAdaptor adaptor) {
  auto roundTrip = getTensor().getDefiningOp<StreamToTensorOp>();
  if (!roundTrip || roundTrip.getStream().getType() != getStream().getType())
    return {};
  return roundTrip.getStream();
}

//===----------------------------------------------------------------------===//
// StreamToTensorOp
//===----------------------------------------------------------------------===//

/// Verify that the stream type can be materialized as the result tensor.
LogicalResult StreamToTensorOp::verify() {
  // BUG FIX: the condition was inverted — the op rejected exactly the cases
  // where the stream type IS convertable with the tensor type. Error out only
  // when the conversion is NOT possible.
  if (!getStream().getType().isConvertableWith(getTensor().getType()))
    return emitOpError("stream type is not convertable with tensor type");
  return success();
}

//===----------------------------------------------------------------------===//
// StreamOp
Expand Down Expand Up @@ -269,6 +284,18 @@ LogicalResult StreamElementChunkOp::verify() { return success(); }

LogicalResult StreamElementConcatOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// StreamExpandShapeOp
//===----------------------------------------------------------------------===//

// NOTE(review): no structural checks yet — presumably the reassociation
// indices should be validated against the input/output stream shapes; confirm
// intended invariants before tightening.
LogicalResult StreamExpandShapeOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// StreamCollapseShapeOp
//===----------------------------------------------------------------------===//

// NOTE(review): no structural checks yet — presumably the reassociation
// indices should be validated against the input/output stream shapes; confirm
// intended invariants before tightening.
LogicalResult StreamCollapseShapeOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// StreamCastOp
//===----------------------------------------------------------------------===//
Expand Down
58 changes: 56 additions & 2 deletions lib/Dialect/HLS/IR/HLSTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@ using namespace mlir;
using namespace scalehls;
using namespace hls;

/// Check the structural invariants of a stream type. The order of the checks
/// determines which diagnostic is emitted first.
LogicalResult
hls::StreamType::verify(function_ref<InFlightDiagnostic()> emitError,
                        Type elementType, ArrayRef<int64_t> shape,
                        MemRefLayoutAttrInterface iterLayout, int64_t depth) {
  AffineMap layoutMap = iterLayout.getAffineMap();

  // The iteration layout must be purely dimensional.
  if (layoutMap.getNumSymbols() != 0)
    return emitError() << "iteration layout cannot have symbols";

  // Each iteration dimension must map to one layout input dimension.
  if (layoutMap.getNumDims() != shape.size())
    return emitError() << "shape size and iteration layout mismatch";

  // A shaped element type must be ranked, and its rank must agree with the
  // number of layout results.
  if (auto shapedElement = llvm::dyn_cast<ShapedType>(elementType)) {
    if (!shapedElement.hasRank())
      return emitError() << "element type must be ranked";
    if (layoutMap.getNumResults() != shapedElement.getRank())
      return emitError() << "iteration layout and element type rank mismatch";
  }
  return success();
}

/// Clone this type with the given shape and element type. If the provided
/// shape is `std::nullopt`, the current shape of the type is used.
StreamType hls::StreamType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Expand All @@ -17,12 +35,48 @@ StreamType hls::StreamType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
getIterLayout(), getDepth());
}

/// Infer the integral shape of the data this stream type represents: fold the
/// iteration layout map on the (constant) iteration sizes and, when the
/// element type is itself shaped, scale each dimension by the element's size.
SmallVector<int64_t> hls::StreamType::inferIntegralShape() const {
  // Materialize every iteration size as a constant affine expression.
  SmallVector<AffineExpr> constIterSizes;
  for (int64_t dimSize : getShape())
    constIterSizes.push_back(getAffineConstantExpr(dimSize, getContext()));

  // Substitute the constants into the layout map so each result folds to a
  // constant expression.
  auto foldedMap = getIterLayout().getAffineMap().replaceDimsAndSymbols(
      constIterSizes, {}, 0, 0);
  auto shapedElement = llvm::dyn_cast<ShapedType>(getElementType());

  SmallVector<int64_t> integralShape;
  for (auto [dim, result] : llvm::enumerate(foldedMap.getResults())) {
    auto constResult = llvm::dyn_cast<AffineConstantExpr>(result);
    assert(constResult && "non-constant size in the iteration layout");

    int64_t size = constResult.getValue();
    if (shapedElement)
      size *= shapedElement.getDimSize(dim);
    integralShape.push_back(size);
  }
  return integralShape;
}

/// Return whether the "other" stream type is compatible with this stream type.
/// By being compatible, it means that the two stream types have the element
/// type and iteration order, but not necessarily the same iteration shape and
/// layout.
bool hls::StreamType::isCompatibleWith(StreamType other) const {
  // NOTE(review): only exact type equality is accepted for now, which is
  // stricter than the documented contract — confirm whether a relaxed
  // compatibility check is intended.
  return *this == other;
}

/// Return whether this stream type can be converted to the "tensor" type:
/// element type, rank, and every dimension size must all line up.
bool hls::StreamType::isConvertableWith(RankedTensorType tensor) const {
  if (tensor.getElementType() != getElementType())
    return false;
  if (tensor.getRank() != getRank())
    return false;
  // Ranks are equal here, so an element-wise ArrayRef comparison is exactly
  // the original zip-based dimension check.
  return tensor.getShape() == getShape();
}

0 comments on commit 7040069

Please sign in to comment.