Replace itensor/tensor_init op with itensor/tensor_instance op after dataflow scheduling
hanchenye committed Mar 30, 2024
1 parent 245fd0f commit db9ce64
Showing 5 changed files with 90 additions and 60 deletions.
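In IR terms, the contract this commit enforces: dataflow scheduling must leave no hls.tensor_init/hls.itensor_init behind; each becomes an hls.tensor_instance/hls.itensor_instance that the lowering and bufferization changes below consume. A minimal sketch, assuming the op mnemonics follow the C++ class names and with illustrative types:

    %t = hls.tensor_init : tensor<16xf32>        // before dataflow scheduling
    %t = hls.tensor_instance : tensor<16xf32>    // after dataflow scheduling (itensor ops likewise)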
26 changes: 16 additions & 10 deletions include/scalehls/Dialect/HLS/IR/HLSOps.td
@@ -343,11 +343,18 @@ def StreamOp : HLSOp<"stream", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Instantiate a stream channel";
 
+  let arguments = (ins OptionalAttr<StrAttr>:$location);
   let results = (outs AnyStream:$stream);
-  let assemblyFormat = "attr-dict `:` type(results)";
+  let assemblyFormat = [{
+    (`location` $location^)? attr-dict `:` type(results)
+  }];
 
   let hasVerifier = 1;
   let hasCanonicalizeMethod = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Type":$stream),
+      "build($_builder, $_state, stream, nullptr);">
+  ];
 
   let extraClassDeclaration = [{
     SmallVector<OpOperand *> getReadUses();
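Per the new format, a stream instantiation can now carry an optional location string. A sketch, assuming an illustrative !hls.stream type syntax (not part of this diff):

    %ch = hls.stream : !hls.stream<...>                   // default form, unchanged
    %ch = hls.stream location "task0" : !hls.stream<...>  // with the new optional clause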
@@ -411,7 +418,7 @@ def BufferOp : HLSOp<"buffer", [
                    OptionalAttr<StrAttr>:$location);
   let results = (outs AnyBuffer:$memref);
   let assemblyFormat = [{
-    (`init` $initValue^)? attr-dict `:` type(results)
+    (`init` $initValue^)? (`location` $location^)? attr-dict `:` type(results)
   }];
 
   let hasVerifier = 1;
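The buffer op gains the same optional clause after its existing init clause; schematically (init value and memref type illustrative only):

    %buf = hls.buffer init 0.0 location "task1" : memref<16x16xf32>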
@@ -435,6 +442,13 @@ def TaskOp : HLSOp<"task", [SingleBlockImplicitTerminator<"YieldOp">]> {
                    OptionalAttr<StrAttr>:$location);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    $name (`location` $location^)? (`inits` $inits^)? $body attr-dict `:`
+    functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
   let builders = [
     OpBuilder<(ins "mlir::ValueRange":$inits, "mlir::StringAttr":$name,
                    "mlir::StringAttr":$location),
@@ -445,14 +459,6 @@ def TaskOp : HLSOp<"task", [SingleBlockImplicitTerminator<"YieldOp">]> {
"build($_builder, $_state, inits, nullptr);">
];

let assemblyFormat = [{
$name (`location` $location^)? (`inits` $inits^)? $body attr-dict `:`
functional-type(operands, results)
}];

let hasVerifier = 1;
let hasCanonicalizer = 1;

let extraClassDeclaration = [{
TypeRange getInitTypes() { return TypeRange(getInits()); }

47 changes: 17 additions & 30 deletions lib/Dialect/HLS/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -448,9 +448,9 @@ struct YieldOpInterface
   }
 };
 
-struct TensorInitOpInterface
-    : public BufferizableOpInterface::ExternalModel<TensorInitOpInterface,
-                                                    hls::TensorInitOp> {
+struct TensorInstanceOpInterface
+    : public BufferizableOpInterface::ExternalModel<TensorInstanceOpInterface,
+                                                    hls::TensorInstanceOp> {
   bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
 
   bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
@@ -462,53 +462,40 @@ struct TensorInitOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(rewriter);
-    auto tensorInit = cast<hls::TensorInitOp>(op);
+    auto tensorInst = cast<hls::TensorInstanceOp>(op);
 
-    // Nothing to do for dead TensorInitOps.
-    if (tensorInit->getUses().empty()) {
-      rewriter.eraseOp(tensorInit);
+    // Nothing to do for dead TensorInstanceOps.
+    if (tensorInst->getUses().empty()) {
+      rewriter.eraseOp(tensorInst);
       return success();
     }
 
     // Create memory allocation.
     auto maybeType =
-        bufferization::getBufferType(tensorInit.getResult(), options);
+        bufferization::getBufferType(tensorInst.getResult(), options);
     if (failed(maybeType))
       return failure();
 
-    for (auto &use : llvm::make_early_inc_range(tensorInit->getUses())) {
-      rewriter.setInsertionPoint(use.getOwner());
-      FailureOr<Value> buffer = options.createAlloc(
-          rewriter, tensorInit.getLoc(), maybeType->cast<MemRefType>(), {});
-      if (failed(buffer))
-        return failure();
-
-      // Handle initial value.
-      auto bufferOp = buffer->getDefiningOp<BufferOp>();
-      bufferOp.setInitValueAttr(tensorInit.getInitValueAttr());
-
-      auto repl = rewriter.create<bufferization::ToTensorOp>(
-          tensorInit.getLoc(), *buffer);
-      rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(repl); });
-    }
-    rewriter.eraseOp(tensorInit);
+    auto buffer = rewriter.create<hls::BufferOp>(
+        tensorInst.getLoc(), *maybeType, tensorInst.getInitValueAttr(),
+        tensorInst.getLocationAttr());
+    replaceOpWithBufferizedValues(rewriter, op, buffer.getResult());
     return success();
   }
 
   FailureOr<BaseMemRefType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
-    auto tensorInit = cast<hls::TensorInitOp>(op);
-    assert(value == tensorInit.getResult() && "invalid value");
+    auto tensorInst = cast<hls::TensorInstanceOp>(op);
+    assert(value == tensorInst.getResult() && "invalid value");
 
     // Compute memory space of this allocation.
     Attribute memorySpace;
     if (options.defaultMemorySpace.has_value())
      memorySpace = *options.defaultMemorySpace;
     else
-      return tensorInit.emitError("could not infer memory space");
+      return tensorInst.emitError("could not infer memory space");
 
-    return getMemRefTypeWithStaticIdentityLayout(tensorInit.getType(),
+    return getMemRefTypeWithStaticIdentityLayout(tensorInst.getType(),
                                                  memorySpace);
   }
 };
@@ -518,6 +505,6 @@ void mlir::scalehls::hls::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, HLSDialect *dialect) {
     hls::TaskOp::attachInterface<TaskOpInterface>(*ctx);
     hls::YieldOp::attachInterface<YieldOpInterface>(*ctx);
-    hls::TensorInitOp::attachInterface<TensorInitOpInterface>(*ctx);
+    hls::TensorInstanceOp::attachInterface<TensorInstanceOpInterface>(*ctx);
   });
 }
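Net effect of the rewritten bufferize hook: instead of one allocation per use patched in via bufferization::ToTensorOp, a tensor instance bufferizes to a single hls.buffer that inherits its init value and location attributes. A rough before/after, with assumed mnemonics and illustrative types:

    %t = hls.tensor_instance : tensor<16xf32>                   // before bufferization
    %m = hls.buffer init 0.0 location "task0" : memref<16xf32>  // single buffer afterwards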
26 changes: 18 additions & 8 deletions lib/Dialect/HLS/Transforms/LowerITensorToStream.cpp
@@ -25,6 +25,23 @@ static StreamType getStreamType(hls::ITensorType iTensorType) {
                        iTensorType.getDepth());
 }
 
+namespace {
+struct LowerITensorInstanceOp
+    : public OpRewritePattern<hls::ITensorInstanceOp> {
+  using OpRewritePattern<hls::ITensorInstanceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(hls::ITensorInstanceOp inst,
+                                PatternRewriter &rewriter) const override {
+    auto streamType = getStreamType(inst.getType());
+    auto stream = rewriter.create<hls::StreamOp>(inst.getLoc(), streamType,
+                                                 inst.getLocationAttr());
+    rewriter.replaceOpWithNewOp<hls::StreamToITensorOp>(inst, inst.getType(),
+                                                        stream);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 struct LowerITensorReadOp : public OpRewritePattern<hls::ITensorReadOp> {
   using OpRewritePattern<hls::ITensorReadOp>::OpRewritePattern;
@@ -144,14 +161,6 @@ struct RemoveITensorToStreamOp
     if (toStream.use_empty()) {
       rewriter.eraseOp(toStream);
       return success();
     }
 
-    if (auto init = toStream.getITensor().getDefiningOp<hls::ITensorInitOp>()) {
-      if (!init.getInitValue()) {
-        rewriter.replaceOpWithNewOp<hls::StreamOp>(toStream,
-                                                   toStream.getStreamType());
-        return success();
-      }
-    } else if (auto toITensor = toStream.getITensor()
+    if (auto toITensor = toStream.getITensor()
                    .getDefiningOp<hls::StreamToITensorOp>()) {
       if (toStream.getStreamType() == toITensor.getStreamType()) {
@@ -217,6 +226,7 @@ struct LowerITensorToStream

     // Apply lowering patterns.
     mlir::RewritePatternSet patterns(context);
+    patterns.add<LowerITensorInstanceOp>(context);
     patterns.add<LowerITensorReadOp>(context);
     patterns.add<LowerITensorWriteOp>(context);
     patterns.add<LowerITensorViewLikeOpInterface>(context);
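The added LowerITensorInstanceOp pattern materializes each itensor instance as a stream channel plus a cast back to the itensor type, forwarding the location attribute. Roughly, with mnemonics assumed from the StreamOp/StreamToITensorOp class names and illustrative types:

    %i = hls.itensor_instance location "task0" : !hls.itensor<...>
    // lowers to:
    %s = hls.stream location "task0" : !hls.stream<...>
    %i = hls.stream_to_itensor %s : !hls.itensor<...>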
21 changes: 11 additions & 10 deletions lib/Dialect/HLS/Transforms/ScalarizeITensor.cpp
@@ -50,22 +50,23 @@ static ITensorType getScalarITensorType(ITensorType iTensor) {
 }
 
 namespace {
-struct ScalarizeITensorInitOp : public OpRewritePattern<hls::ITensorInitOp> {
-  using OpRewritePattern<hls::ITensorInitOp>::OpRewritePattern;
+struct ScalarizeITensorInstanceOp
+    : public OpRewritePattern<hls::ITensorInstanceOp> {
+  using OpRewritePattern<hls::ITensorInstanceOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(hls::ITensorInitOp init,
+  LogicalResult matchAndRewrite(hls::ITensorInstanceOp inst,
                                 PatternRewriter &rewriter) const override {
-    auto iTensorType = init.getType();
+    auto iTensorType = inst.getType();
     if (!iTensorType.hasShapedElementType())
       return failure();
 
-    rewriter.updateRootInPlace(init, [&]() {
-      init.getResult().setType(getScalarITensorType(iTensorType));
+    rewriter.updateRootInPlace(inst, [&]() {
+      inst.getResult().setType(getScalarITensorType(iTensorType));
     });
-    rewriter.setInsertionPointAfter(init);
+    rewriter.setInsertionPointAfter(inst);
     auto cast =
-        rewriter.create<hls::ITensorCastOp>(init.getLoc(), iTensorType, init);
-    rewriter.replaceAllUsesExcept(init, cast.getResult(), cast);
+        rewriter.create<hls::ITensorCastOp>(inst.getLoc(), iTensorType, inst);
+    rewriter.replaceAllUsesExcept(inst, cast.getResult(), cast);
     return success();
   }
 };
@@ -279,7 +280,7 @@ struct ScalarizeITensor

     // Apply scalarization patterns.
     mlir::RewritePatternSet patterns(context);
-    patterns.add<ScalarizeITensorInitOp>(context);
+    patterns.add<ScalarizeITensorInstanceOp>(context);
     patterns.add<ScalarizeITensorReadOp>(context);
     patterns.add<ScalarizeITensorWriteOp>(context);
     patterns.add<ScalarizeITensorReassociateOp>(context);
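As with the neighboring scalarization patterns, the renamed pattern retypes an instance whose element type is shaped into its scalar form, then casts back so existing uses keep the original type. Schematically (itensor and cast syntax illustrative only):

    %i = hls.itensor_instance : !hls.itensor<tensor<4xf32>, ...>
    // becomes:
    %s = hls.itensor_instance : !hls.itensor<f32, ...>
    %i = hls.itensor_cast %s : !hls.itensor<f32, ...> -> !hls.itensor<tensor<4xf32>, ...>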
30 changes: 28 additions & 2 deletions lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp
@@ -132,6 +132,19 @@ struct ScheduleDataflow
     }
   }
 
+  void applyDefaultInstanceLocations() {
+    SmallVector<Operation *> instances;
+    getOperation().walk([&](Operation *op) {
+      if (isa<TensorInstanceOp, ITensorInstanceOp>(op))
+        instances.push_back(op);
+    });
+    for (auto instance : instances) {
+      assert(llvm::hasSingleElement(instance->getUsers()) &&
+             "instance should have a single user");
+      instance->moveBefore(*instance->user_begin());
+    }
+  }
+
   // Recursively add the task to the group, and create a new group if the task
   // is not in the group.
   void dfsScheduleDefiningOp(Value value, size_t prevLevel) {
@@ -142,8 +155,8 @@
         opToLevelMap.lookup(definingOp) > prevLevel)
       return;
 
-    assert(!isa<hls::TensorInitOp>(definingOp) &&
-           !isa<hls::ITensorInitOp>(definingOp) &&
+    assert(!isa<hls::TensorInstanceOp>(definingOp) &&
+           !isa<hls::ITensorInstanceOp>(definingOp) &&
            "tensor/itensor init op should not be scheduled at all");
 
     if (auto task = dyn_cast<hls::TaskOp>(definingOp)) {
@@ -177,6 +190,17 @@
     auto func = getOperation();
     OpBuilder builder(&getContext());
 
+    // Check if all itensor/tensor init ops have been converted.
+    auto checkResult = func.walk([&](Operation *op) {
+      if (isa<hls::TensorInitOp, hls::ITensorInitOp>(op)) {
+        op->emitOpError("tensor/itensor init op should have been converted");
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    if (checkResult.wasInterrupted())
+      return signalPassFailure();
+
     // If the task locations are not set, apply the default task locations.
     if (!checkTaskLocations())
       applyDefaultTaskLocations();
@@ -210,6 +234,8 @@
           func.getName().str() + "_schedule_" + std::to_string(taskId++);
       wrapOpsIntoTask(ops, taskName, location, builder);
     }
+
+    applyDefaultInstanceLocations();
   }
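applyDefaultInstanceLocations runs once all tasks are wrapped: each instance op is asserted to have exactly one user and is moved directly before it, so instances end up adjacent to the task (or op) that consumes them. A sketch with assumed names and types:

    %t = hls.tensor_instance : tensor<16xf32>   // relocated to sit immediately before its single user
    %r = hls.task "main_schedule_0" inits %t { ... } : (tensor<16xf32>) -> tensor<16xf32>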

private: