Skip to content

Commit

Permalink
Revert "[Layouts] Propagate layouts into conditionals (#5610)" (#5710)
Browse files · Browse the repository at this point in the history
Reverting due to regressions in internal tests
  • Loading branch information
pawelszczerbuk authored Jan 27, 2025
1 parent 19fe7cb commit 98b40d5
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 238 deletions.
116 changes: 0 additions & 116 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ class LayoutRematerialization {
void backwardRematerialization(ConvertLayoutOp convertOp);
void hoistConvertOnTopOfExtOrBroadcast();
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
void hoistConvertIntoConditionals();
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp, IRMapping &mapping);
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
Expand Down Expand Up @@ -1022,22 +1020,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
}
}

void LayoutRematerialization::hoistConvertIntoConditionals() {
// Go through each ConvertLayoutOp.
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
hoistConvertIntoConditionals(convertOp);
if (!opToDelete.contains(convertOp)) {
// If the conversion didn't get removed, consider it for reuse in future
// backward slices.
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
}
}
}

void LayoutRematerialization::backwardRematerialization(
ConvertLayoutOp convertOp) {
// we don't handle conversions to DotOperandEncodingAttr
Expand Down Expand Up @@ -1169,100 +1151,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
rewriteSlice(slice, layout, convertOp, mapping);
}

// Try to hoist `convertOp` above an scf.if in its backward slice. If the
// conversion can be rematerialized along one branch of a conditional but not
// the other, clone the conversion into the failing branch (just before its
// yield) and rewrite the remaining slice in the target encoding, so the
// original conversion can be removed. At most one conditional is hoisted
// across per invocation.
void LayoutRematerialization::hoistConvertIntoConditionals(
    ConvertLayoutOp convertOp) {
  // Take the backward slice of tensor dependencies, stopping at conditionals.
  SetVector<Value> slice;
  DenseMap<Value, Attribute> layout;
  auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
  if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
                                      convertOp.getType().getEncoding(), slice,
                                      layout, isIfOp)))
    return;

  // Find conditional edges above which the conversion can be hoisted. Each
  // entry pairs an scf.if result with the yield operand to hoist above.
  SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
  // The routine will recurse through backward slices, e.g. to handle loops and
  // conditional chains. Thus, we re-query the size of `slice`.
  for (unsigned i = 0; i < slice.size(); i++) {
    Value v = slice[i];
    auto ifOp = v.getDefiningOp<scf::IfOp>();
    if (!ifOp)
      continue;

    Attribute rootLayout = layout.at(v);
    unsigned resIdx = cast<OpResult>(v).getResultNumber();

    // Take the backward slice along each branch.
    auto thenYield =
        cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
    auto elseYield =
        cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());

    OpOperand &thenRes = thenYield.getResultsMutable()[resIdx];
    OpOperand &elseRes = elseYield.getResultsMutable()[resIdx];

    SetVector<Value> thenSlice, elseSlice;
    DenseMap<Value, Attribute> thenLayout, elseLayout;

    LogicalResult thenResult = getRematerializableSlice(
        thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
    LogicalResult elseResult = getRematerializableSlice(
        elseRes, rootLayout, elseSlice, elseLayout, isIfOp);

    // If propagation across both edges of this conditional succeeded, then we
    // don't need to hoist across it.
    if (succeeded(thenResult) && succeeded(elseResult)) {
      slice.insert(thenSlice.begin(), thenSlice.end());
      slice.insert(elseSlice.begin(), elseSlice.end());
      layout.insert(thenLayout.begin(), thenLayout.end());
      layout.insert(elseLayout.begin(), elseLayout.end());
      continue;
    }

    // If propagation across both edges failed, then there is nothing to do
    // for this one.
    if (failed(thenResult) && failed(elseResult))
      continue;

    // The layout conversion can be rematerialized along one edge but not the
    // other. We can hoist the conversion into the other branch. Normalize so
    // that `thenSlice`/`thenLayout` hold the rematerializable side and the
    // hoist target is the failing edge.
    if (succeeded(elseResult)) {
      std::swap(thenSlice, elseSlice);
      std::swap(thenLayout, elseLayout);
      hoistAbove.push_back({v, &thenRes});
    } else {
      hoistAbove.push_back({v, &elseRes});
    }
    slice.insert(thenSlice.begin(), thenSlice.end());
    layout.insert(thenLayout.begin(), thenLayout.end());
  }

  // It's hard to know if duplicating the conversion into separate branches is
  // profitable without more analysis. For now, hoist at most one.
  if (hoistAbove.size() != 1)
    return;

  IRMapping mapping;
  for (auto [result, edge] : hoistAbove) {
    // Hoist the convert into the conditional and rewrite the slice. The
    // builder inserts right before the yield that owns this edge.
    OpBuilder b(edge->getOwner());
    Value v = edge->get();
    Attribute encoding = layout.at(result);

    auto tensorType = cast<RankedTensorType>(v.getType());
    auto newType = RankedTensorType::get(tensorType.getShape(),
                                         tensorType.getElementType(), encoding);

    Value newCvt = b.create<ConvertLayoutOp>(convertOp.getLoc(), newType, v);

    mapping.map(v, newCvt);
    slice.remove(v);
  }
  rewriteSlice(slice, layout, convertOp, mapping);
}

void backwardRematerialization(ModuleOp module) {
module.walk([](FuncOp funcOp) {
LayoutRematerialization layoutRemat(funcOp);
Expand All @@ -1277,10 +1165,6 @@ void hoistConvert(ModuleOp module) {
LayoutRematerialization layoutRemat(funcOp);
layoutRemat.hoistConvertOnTopOfExtOrBroadcast();
layoutRemat.cleanup();

layoutRemat = LayoutRematerialization(funcOp);
layoutRemat.hoistConvertIntoConditionals();
layoutRemat.cleanup();
});
}
} // namespace
Expand Down
11 changes: 5 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,10 +793,11 @@ LogicalResult getConvertBackwardSlice(
auto updateLayout = [&](Value value, Attribute encoding) {
assert((isa<RankedTensorType>(value.getType())));
slice.insert(value);
Attribute &existing = layout[value];
if (existing && existing != encoding)
return failure();
existing = encoding;
if (layout.find(value) != layout.end()) {
if (layout[value] != encoding)
return failure();
}
layout[value] = encoding;
return success();
};

Expand All @@ -822,8 +823,6 @@ LogicalResult getConvertBackwardSlice(
}

if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
if (stopPropagation && stopPropagation(ifOp))
continue;
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();

OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
Expand Down
116 changes: 0 additions & 116 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2828,122 +2828,6 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {

// -----

// Layout aliases shared by the hoist-into-conditional tests below. #blocked is
// the encoding the loads produce; #mma (via its dot_op views) is the encoding
// tt.dot consumes.
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {

// The convert feeding tt.dot rematerializes through the else branch (a
// constant) but not the then branch (a load), so it should end up inside the
// then branch, directly after the load.
// CHECK-LABEL: @hoist_one_conditional
tt.func @hoist_one_conditional(
%arg0: i1,
%arg1: tensor<128x32x!tt.ptr<f32>, #blocked>,
%arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
%arg3: tensor<128x128xf32, #mma>
) -> tensor<128x128xf32, #mma> {

// CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
// CHECK: scf.if
%0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
// CHECK-NEXT: [[RES:%.*]] = tt.load
%3 = tt.load %arg1 : tensor<128x32x!tt.ptr<f32>, #blocked>
// CHECK-NEXT: ttg.convert_layout [[RES]]
// CHECK-NEXT: yield
scf.yield %3 : tensor<128x32xf32, #blocked>
} else {
scf.yield %cst : tensor<128x32xf32, #blocked>
}
// CHECK-NOT: ttg.convert_layout
%1 = ttg.convert_layout %0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%2 = tt.dot %1, %arg2, %arg3 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
tt.return %2 : tensor<128x128xf32, #mma>
}

// Two independent conditionals feed the same dot operand. Only one conversion
// should remain after the pass (it hoists at most one conditional).
// CHECK-LABEL: @hoist_multiple_conditional
tt.func @hoist_multiple_conditional(
%arg0: i1,
%arg1: i1,
%arg2: tensor<128x32x!tt.ptr<f32>, #blocked>,
%arg3: tensor<128x32x!tt.ptr<f32>, #blocked>,
%arg4: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
%arg5: tensor<128x128xf32, #mma>
) -> tensor<128x128xf32, #mma> {
// CHECK-COUNT-1: ttg.convert_layout
%cst0 = arith.constant dense<1.0> : tensor<128x32xf32, #blocked>
%cst1 = arith.constant dense<2.0> : tensor<128x32xf32, #blocked>
%0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
%3 = tt.load %arg2 : tensor<128x32x!tt.ptr<f32>, #blocked>
scf.yield %3 : tensor<128x32xf32, #blocked>
} else {
scf.yield %cst0 : tensor<128x32xf32, #blocked>
}
%1 = scf.if %arg1 -> (tensor<128x32xf32, #blocked>) {
%4 = tt.load %arg3 : tensor<128x32x!tt.ptr<f32>, #blocked>
scf.yield %4 : tensor<128x32xf32, #blocked>
} else {
scf.yield %cst1 : tensor<128x32xf32, #blocked>
}
%2 = arith.addf %0, %1 : tensor<128x32xf32, #blocked>
%3 = ttg.convert_layout %2 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%4 = tt.dot %3, %arg4, %arg5 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
tt.return %4 : tensor<128x128xf32, #mma>
}

// The conditional lives inside a loop and its else edge comes from the loop
// iter_arg; hoisting must still place the conversion next to the load in the
// then branch.
// CHECK-LABEL: @hoist_across_loop
tt.func @hoist_across_loop(
%arg0: i1,
%arg1: tensor<128x32x!tt.ptr<f32>, #blocked>,
%arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
%arg3: tensor<128x128xf32, #mma>
) -> tensor<128x128xf32, #mma> {
// CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
%cst = arith.constant dense<1.0> : tensor<128x32xf32, #blocked>
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c32_i32 = arith.constant 32 : i32
// CHECK: scf.for
%0:2 = scf.for %i = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg4 = %cst, %acc = %arg3) -> (tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma>) : i32 {
// CHECK-NEXT: scf.if
%1 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
// CHECK-NEXT: [[RES:%.*]] = tt.load
// CHECK-NEXT: ttg.convert_layout [[RES]]
%3 = tt.load %arg1 : tensor<128x32x!tt.ptr<f32>, #blocked>
scf.yield %3 : tensor<128x32xf32, #blocked>
} else {
scf.yield %arg4 : tensor<128x32xf32, #blocked>
}
// CHECK-NOT: ttg.convert_layout
%2 = ttg.convert_layout %1 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%3 = tt.dot %2, %arg2, %acc : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
scf.yield %1, %3 : tensor<128x32xf32, #blocked>, tensor<128x128xf32, #mma>
}
tt.return %0#1 : tensor<128x128xf32, #mma>
}

// Two chained conditionals (the second yields the first's result). The pass
// should recurse through the chain and leave exactly one conversion.
// CHECK-LABEL: @chained_if
tt.func @chained_if(%arg0: i1, %arg1: i1, %arg2: tensor<32x32x!tt.ptr<f32>, #blocked>, %arg3: tensor<32x32x!tt.ptr<f32>, #blocked>) -> tensor<32x32xf32, #mma> {
// CHECK-COUNT-1: ttg.convert_layout
%cst = arith.constant dense<1.0> : tensor<32x32xf32, #blocked>
%0 = scf.if %arg0 -> tensor<32x32xf32, #blocked> {
%anchor = tt.load %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked>
scf.yield %anchor : tensor<32x32xf32, #blocked>
} else {
scf.yield %cst : tensor<32x32xf32, #blocked>
}
%1 = scf.if %arg1 -> tensor<32x32xf32, #blocked> {
%anchor = tt.load %arg3 : tensor<32x32x!tt.ptr<f32>, #blocked>
scf.yield %anchor : tensor<32x32xf32, #blocked>
} else {
scf.yield %0 : tensor<32x32xf32, #blocked>
}
%2 = ttg.convert_layout %1 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #mma>
tt.return %2 : tensor<32x32xf32, #mma>
}

}

// -----

#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}>
#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}>

Expand Down

0 comments on commit 98b40d5

Please sign in to comment.