diff --git a/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp b/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp index 02295d5a..547ef849 100644 --- a/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp +++ b/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp @@ -20,20 +20,20 @@ hls::TaskOp wrapOpIntoTask(Operation *op, StringRef taskName, builder.setInsertionPoint(op); auto task = builder.create(op->getLoc(), destTypes, destOperands); + op->replaceAllUsesWith(task.getResults()); task->setAttr(taskName, builder.getUnitAttr()); auto taskBlock = builder.createBlock( &task.getBody(), task.getBody().end(), destTypes, llvm::map_to_vector(destOperands, [&](Value v) { return v.getLoc(); })); - IRMapping mapper; + + builder.setInsertionPointToEnd(taskBlock); + auto yieldOp = builder.create(op->getLoc(), op->getResults()); + + op->moveBefore(yieldOp); for (auto [destOperand, taskBlockArg] : llvm::zip(destOperands, taskBlock->getArguments())) - mapper.map(destOperand, taskBlockArg); - - builder.setInsertionPointToStart(taskBlock); - auto newOp = builder.clone(*op, mapper); - builder.create(op->getLoc(), newOp->getResults()); - op->replaceAllUsesWith(task.getResults()); - op->erase(); + destOperand.replaceUsesWithIf( + taskBlockArg, [&](OpOperand &use) { return use.getOwner() == op; }); return task; } @@ -54,6 +54,8 @@ static LogicalResult scheduleBlock(StringRef prefix, Block *block, destOperands = loop.getInitArgs(); } else if (isa(op)) { + // TODO: For now, tensor insert ops are not scheduled into separate tasks + // as they will be handled in the bufferization passes. continue; } else if (auto destStyleOp = dyn_cast(op)) { destOperands = destStyleOp.getDpsInits();