
Commit c76b342

[BACKEND] Add missing precondition in optimize acc init (#5184)
We need the select to have a scalar (i1) condition to be able to do this optimization; if the condition is a tensor of i1, the select is element-wise and the optimization must be skipped.
1 parent 1fc3269 commit c76b342
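For context, the rewrite this pass performs looks roughly like the sketch below (illustrative names; types taken from the test added in this commit). When the select condition is a scalar i1, zero-initializing the accumulator can be folded into the dot operation:

  // Before (sketch): the accumulator is conditionally reset to zero by a
  // select whose condition %flag is a scalar i1.
  %acc_in = arith.select %flag, %zero, %acc : tensor<128x16xf32, #mma1>
  %res = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_in : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>

  // After (sketch): the select is folded away and the dot takes an extra i1
  // operand derived from %flag that says whether to use the incoming
  // accumulator; this is the four-operand form the new test checks against.
  %res = triton_nvidia_gpu.warp_group_dot %A, %B, %acc, %use_acc : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>

The flag polarity depends on which arm of the select holds the zero constant.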

File tree

2 files changed: +18 -0 lines changed

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,8 @@ std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,
     return std::nullopt;
   }
   if (auto selOp = dyn_cast<arith::SelectOp>(defOp)) {
+    if (!selOp.getCondition().getType().isInteger(1))
+      return std::nullopt;
     if (isConstantZeroTensor(selOp.getTrueValue()) ||
         isConstantZeroTensor(selOp.getFalseValue())) {
       return std::make_pair(selOp, 0);
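Why this check is the right precondition: in MLIR, Type::isInteger(1) is true only for the scalar i1 type; a tensor<...xi1> condition is a RankedTensorType and fails the check, so findZeroInitOp now returns std::nullopt for it. A minimal sketch of the two arith.select forms (illustrative names):

  // Scalar i1 condition: the result is either all-zero or the carried
  // accumulator as a whole, so the zero-init can move into the dot.
  %a = arith.select %cond_scalar, %zero, %acc : tensor<128x16xf32, #mma1>

  // tensor<...xi1> condition: the select is element-wise, so individual
  // elements may keep their old value and the accumulator is not uniformly
  // zero; the new precondition rejects this form.
  %b = arith.select %cond_tensor, %zero, %acc : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>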

test/TritonGPU/accumulator-init.mlir

Lines changed: 16 additions & 0 deletions
@@ -348,4 +348,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     }
     tt.return %17 : tensor<128x16xf32, #mma1>
   }
+
+  // If the condition is a tensor skip the optimization.
+  // CHECK-LABEL: @negative_sel_tensor
+  // CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc
+  tt.func @negative_sel_tensor(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> {
+    %c0_i32 = arith.constant 0 : i32
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
+    %c1_i32 = arith.constant 1 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
+      %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
+      %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>
+      scf.yield %acc : tensor<128x16xf32, #mma1>
+    }
+    tt.return %17 : tensor<128x16xf32, #mma1>
+  }
 }
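A note on how the test guards the fix: the CHECK-NOT line rejects any four-operand form of triton_nvidia_gpu.warp_group_dot, i.e. it asserts that the pass did not attach an accumulator-use flag, so the tensor-condition select must be left untouched. The test is presumably driven by the file's existing RUN line, which this diff does not show.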
