[AMD][BACKEND] Disable pingpong with non-local_load input. #5718

Merged · 6 commits · Jan 29, 2025
Changes from 1 commit
[AMD][BACKEND] Disable pingpong with non-local_load input.
The pingpong pass only expects to handle local_load ops as the A/B operands.
Avoid applying the transform when a different op is detected.
jungpark-mlir committed Jan 27, 2025
commit cf15bc7b6422ffc7d4226be2fb1f598fe12b455c
60 changes: 32 additions & 28 deletions third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
@@ -151,35 +151,39 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
                                          int64_t sliceWidth) {
   SmallVector<Operation *> slices;
   SmallVector<Operation *> subviews;
-  auto memDesc = v.getDefiningOp()->getOperand(0);
-  auto type = cast<ttg::MemDescType>(memDesc.getType());
-  SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
-  Type elementType = type.getElementType();
-  int64_t kIdx = opIdx == 0 ? 1 : 0;
-  shape[kIdx] = sliceWidth;
-  // Each slice cannot be smaller than the smallest supported mfma width.
-  if (sliceWidth < 16)
-    return failure();
-  auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, kWidth);
-  auto subviewDescType = ttg::MemDescType::get(
-      shape, elementType, type.getEncoding(), type.getMemorySpace(),
-      type.getMutableMemory(), type.getAllocShape());
-  for (int i = 0; i < numSlices; i++) {
-    SmallVector<Value> offsetsVal;
-    SmallVector<int64_t> offsets = {0, 0};
-    offsets[kIdx] = i;
-    for (int64_t off : offsets) {
-      offsetsVal.push_back(constOffsets[off]);
+  // TODO: support transformed input to dot
+  if (auto maybeLocal = v.getDefiningOp<ttg::LocalLoadOp>()) {
+    auto memDesc = maybeLocal.getSrc();
+    auto type = cast<ttg::MemDescType>(memDesc.getType());
+    SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
+    Type elementType = type.getElementType();
+    int64_t kIdx = opIdx == 0 ? 1 : 0;
+    shape[kIdx] = sliceWidth;
+    // Each slice cannot be smaller than the smallest supported mfma width.
+    if (sliceWidth < 16)
+      return failure();
+    auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
+        builder.getContext(), opIdx, dotEncoding, kWidth);
+    auto subviewDescType = ttg::MemDescType::get(
+        shape, elementType, type.getEncoding(), type.getMemorySpace(),
+        type.getMutableMemory(), type.getAllocShape());
+    for (int i = 0; i < numSlices; i++) {
+      SmallVector<Value> offsetsVal;
+      SmallVector<int64_t> offsets = {0, 0};
+      offsets[kIdx] = i;
+      for (int64_t off : offsets) {
+        offsetsVal.push_back(constOffsets[off]);
+      }
+      Value newSmem = builder.create<ttg::MemDescSubviewOp>(
+          v.getLoc(), subviewDescType, memDesc, offsetsVal);
+      Value prefetchSlice = builder.create<ttg::LocalLoadOp>(
+          v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
+          newSmem);
+      subviews.push_back(newSmem.getDefiningOp());
+      slices.push_back(prefetchSlice.getDefiningOp());
     }
-    Value newSmem = builder.create<ttg::MemDescSubviewOp>(
-        v.getLoc(), subviewDescType, memDesc, offsetsVal);
-    Value prefetchSlice = builder.create<ttg::LocalLoadOp>(
-        v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
-        newSmem);
-    subviews.push_back(newSmem.getDefiningOp());
-    slices.push_back(prefetchSlice.getDefiningOp());
-  }
+  } else
+    return failure();
   subViewOps.push_back(subviews);
   loadSliceOps.push_back(slices);
   return success();
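
The guard added above follows a common MLIR pattern: inspect a value's producer with getDefiningOp<OpTy>() and bail out of the transform when it is not the expected op, before any new IR is created. The sketch below is a minimal, standalone C++ analogy of that control flow only; Op, LocalLoadOp, TransOp, and genLocalSliceSketch are hypothetical stand-ins, not the real MLIR/Triton classes.

#include <iostream>

struct Op {
  virtual ~Op() = default;
};
struct LocalLoadOp : Op {};  // stand-in for a local_load producer
struct TransOp : Op {};      // stand-in for any other producer, e.g. a transform

// Mirrors the shape of the change: only proceed when the dot operand is
// produced by a local_load-like op; otherwise report failure so the caller
// skips the pingpong transform for this dot.
bool genLocalSliceSketch(const Op &producer) {
  // dynamic_cast plays the role of getDefiningOp<ttg::LocalLoadOp>() here.
  if (auto *load = dynamic_cast<const LocalLoadOp *>(&producer)) {
    (void)load;
    // ... slice the shared-memory source into sub-views ...
    return true;   // success(): pingpong scheduling can proceed
  }
  return false;    // failure(): unsupported producer, pingpong is disabled
}

int main() {
  std::cout << genLocalSliceSketch(LocalLoadOp{}) << "\n"; // prints 1
  std::cout << genLocalSliceSketch(TransOp{}) << "\n";     // prints 0
}

Because the else branch in the real patch returns failure() without building any subviews or local_loads, declining the transform leaves the surrounding IR unchanged.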