Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
manman-ren committed Jan 10, 2025
1 parent ca0782d commit d0fc178
Showing 1 changed file with 0 additions and 143 deletions.
143 changes: 0 additions & 143 deletions lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,42 +126,6 @@ Value getAccumLoopCountArg(scf::ForOp parentForOp) {
return tmpAccumLoopCount;
}

// At this stage, part of the region is specialized. The input is the new ifOp
// for the taskId we are currently working on. The thenBlock has been
// specialized.
Value getAccumLoopCount(scf::IfOp ifOp, Value tmpAccumLoopCount) {
// FIXME: last argument of ForOp is bufferIdx which is a modulo number.
// We can replace it with the non-modulo number.
// For now, trace the use of the last argument for the outer loop.
Operation *lastOp = nullptr;
auto users = tmpAccumLoopCount.getUsers();
Operation *lastUser = nullptr;
// FIXME: Look at thenBlock only for now. Find the last Op in thenBlock
// that is is a user of tmpAccumLoopCount.
for (Operation &thenOp : ifOp.thenBlock()->getOperations()) {
if (auto lastFor = dyn_cast<scf::ForOp>(thenOp)) {
lastOp = lastFor.getOperation();
}
if (auto lastIf = dyn_cast<scf::IfOp>(thenOp)) {
lastOp = lastIf.getOperation();
}
for (auto *user : users)
if (user == &thenOp)
lastUser = user;
}
// lastOp is specialized already.
assert(lastOp != nullptr);
if (auto lastFor = dyn_cast<scf::ForOp>(lastOp)) {
// Get the last user of tmpAccumLoopCount inside this ifOp.
assert(lastUser);
return lastUser->getResult(0);
}
auto lastIf = cast<scf::IfOp>(lastOp);
assert(ifOp.getNumResults() >= 1);
auto numRes = ifOp.getNumResults();
return ifOp.getResult(numRes - 1);
}

// Return true if the IfOp contains a ForOp that is in loopWithBufferReuse.
static bool
needAccumulatedLoopCnt(scf::IfOp ifOp,
Expand Down Expand Up @@ -1189,79 +1153,6 @@ static unsigned getNumChannelsInLoop(scf::ForOp forOp,
return channelsInLoop.size();
}

#if 0
// After specializing regions, fix up the logic here.
// Return the update prevAccum.
Value updateAccumLoopCount(Operation *parentOp, Value tmpAccumLoopCount,
Value prevAccum) {
// parentOp must be IfOp or ForOp.
auto users = tmpAccumLoopCount.getUsers();
DenseSet<Operation *> userSet;
for (auto user : users)
userSet.insert(user);
SmallVector<Operation *> opList;
if (auto forOp = dyn_cast<scf::ForOp>(parentOp))
for (Operation &op : forOp.getBody()->without_terminator())
opList.push_back(&op);
if (auto ifOp = dyn_cast<scf::IfOp>(parentOp))
for (Operation &op : ifOp.thenBlock()->getOperations())
opList.push_back(&op);

// Go through body of parentOp. Update users of tmpAccumLoopCount on the way.
for (Operation *op : opList) {
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
continue;
}
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
bool needAccum = false;
ifOp.walk<WalkOrder::PreOrder>([&](Operation *subOp) {
if (auto forOp = dyn_cast<scf::ForOp>(subOp))
needAccum = true;
});
if (!needAccum)
continue;
// go inside ifOp->thenBlock to fix up.
updateAccumLoopCount(op, tmpAccumLoopCount, prevAccum);
// Update yield operand.
if (prevAccum != tmpAccumLoopCount) {
ifOp.elseYield()->replaceUsesOfWith(tmpAccumLoopCount, prevAccum);
LLVM_DEBUG({
LDBG("replace elseYield");
prevAccum.dump();
});
} else {
LDBG("no need to replace elseYield");
}
assert(ifOp.getNumResults() >= 1);
auto numRes = ifOp.getNumResults();
LDBG("update prevAccum with result from IfOp");
prevAccum = ifOp.getResult(numRes - 1); // last argument of this ifOp
continue;
}
if (userSet.count(op)) {
if (prevAccum != tmpAccumLoopCount) {
LLVM_DEBUG({
LDBG("replace use of tmpAccumLoopCount: ");
op->dump();
prevAccum.dump();
});
op->replaceUsesOfWith(tmpAccumLoopCount, prevAccum);
} else {
LDBG("no need to replace prevAccum == tmpAccumLoopCount");
}
if (isa<arith::AddIOp>(op)) {
prevAccum = op->getResult(0);
LLVM_DEBUG({
LDBG("update prevAccum: ");
prevAccum.dump();
});
}
}
}
return prevAccum;
}
#endif

bool reuseBuffers(SmallVector<Operation *> &taskTopOps,
const SmallVector<Channel *> &channels,
DenseMap<Channel *, Channel *> &mapToRepresenting,
Expand Down Expand Up @@ -2344,40 +2235,6 @@ class TritonGPUWSCodePartitionPass
LDBG("\n\nwith SpecializeRegion");
funcOp.dump();
});

#if 0
if (loopsWithAccumLoopCount >= 1) {
// Get top-level parentForOp for each taskId/region.
auto getTopForOp = [&](scf::IfOp ifOp) -> scf::ForOp {
for (Operation &op : ifOp.thenBlock()->getOperations())
if (auto forOp = dyn_cast<scf::ForOp>(&op))
return forOp;
return scf::ForOp();
};
for (auto &block : funcOp.getBody().getBlocks()) {
for (Operation &op : block.getOperations()) {
if (auto ifOp = dyn_cast<scf::IfOp>(&op)) {
scf::ForOp parentForOp = getTopForOp(ifOp);
// Get the async taskId.
auto taskIds = getAsyncTaskIds(&op);
assert(taskIds.size() == 1);
Value tmpAccumLoopCount = parentForOp
? getAccumLoopCountArg(parentForOp)
: mapForAccumLoopVar[taskIds[0]];
Operation *topOp = parentForOp ? parentForOp.getOperation() : &op;
Value prevAccum = updateAccumLoopCount(topOp, tmpAccumLoopCount,
tmpAccumLoopCount);
// handle yield for tmpAccumLoopCount.
if (parentForOp) {
auto forYield =
cast<scf::YieldOp>(parentForOp.getBody()->getTerminator());
forYield->replaceUsesOfWith(tmpAccumLoopCount, prevAccum);
}
}
}
}
}
#endif
}

void runOnOperation() override {
Expand Down

0 comments on commit d0fc178

Please sign in to comment.