Skip to content

Commit b882f6e

Browse files
authored
fix: correct broadcast_in_dim result size in dot_general_simplify (#1511)
1 parent 768f73e commit b882f6e

File tree

2 files changed: 48 additions and 8 deletions.

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8284,21 +8284,20 @@ struct DotGeneralSimplify
82848284
reduceSumInput = op.getRhs();
82858285

82868286
SmallVector<int64_t> broadcastDimsTmp(rhsType.getRank(), -1);
8287-
broadcastShape = SmallVector<int64_t>(rhsType.getRank(), -1);
82888287

82898288
int64_t idx = 0;
82908289
for (auto dim : rhsBatchingDims) {
82918290
broadcastDimsTmp[dim] = idx;
8292-
broadcastShape[idx] = rhsShape[dim];
8291+
broadcastShape.push_back(rhsShape[dim]);
82938292
idx++;
82948293
}
82958294
for (auto dim : lhsNonContractingDims) {
8296-
broadcastShape[idx] = lhsShape[dim];
8295+
broadcastShape.push_back(lhsShape[dim]);
82978296
idx++;
82988297
}
82998298
for (auto dim : rhsNonContractingDims) {
83008299
broadcastDimsTmp[dim] = idx;
8301-
broadcastShape[idx] = rhsShape[dim];
8300+
broadcastShape.push_back(rhsShape[dim]);
83028301
idx++;
83038302
}
83048303

@@ -8312,21 +8311,20 @@ struct DotGeneralSimplify
83128311
reduceSumInput = op.getLhs();
83138312

83148313
SmallVector<int64_t> broadcastDimsTmp(lhsType.getRank(), -1);
8315-
broadcastShape = SmallVector<int64_t>(lhsType.getRank(), -1);
83168314

83178315
int64_t idx = 0;
83188316
for (auto dim : lhsBatchingDims) {
83198317
broadcastDimsTmp[dim] = idx;
8320-
broadcastShape[idx] = lhsShape[dim];
8318+
broadcastShape.push_back(lhsShape[dim]);
83218319
idx++;
83228320
}
83238321
for (auto dim : lhsNonContractingDims) {
83248322
broadcastDimsTmp[dim] = idx;
8325-
broadcastShape[idx] = lhsShape[dim];
8323+
broadcastShape.push_back(lhsShape[dim]);
83268324
idx++;
83278325
}
83288326
for (auto dim : rhsNonContractingDims) {
8329-
broadcastShape[idx] = rhsShape[dim];
8327+
broadcastShape.push_back(rhsShape[dim]);
83308328
idx++;
83318329
}
83328330

test/lit_tests/dot_general_ones.mlir

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,3 +64,45 @@ func.func @main3(%arg0: tensor<2x2xcomplex<f64>> {enzymexla.memory_effects = []}
6464
// CHECK-NEXT: %8 = chlo.conj %7 : tensor<2x2xcomplex<f64>> -> tensor<2x2xcomplex<f64>>
6565
// CHECK-NEXT: return %8, %4, %0 : tensor<2x2xcomplex<f64>>, tensor<2x2xcomplex<f64>>, tensor<2x2xcomplex<f64>>
6666
// CHECK-NEXT: }
67+
68+
func.func @main4(%arg0: tensor<2x16xf32> {enzymexla.memory_effects = []}, %arg1: tensor<16xf32> {enzymexla.memory_effects = []}, %arg2: tensor<16x16xf32> {enzymexla.memory_effects = []}, %arg3: tensor<16xf32> {enzymexla.memory_effects = []}, %arg4: tensor<16x1xf32> {enzymexla.memory_effects = []}, %arg5: tensor<1xf32> {enzymexla.memory_effects = []}, %arg6: tensor<2xf32> {enzymexla.memory_effects = []}) -> tensor<2xf32> attributes {enzymexla.memory_effects = []} {
69+
%cst = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
70+
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
71+
%0 = stablehlo.reshape %arg4 : (tensor<16x1xf32>) -> tensor<1x16xf32>
72+
%1 = stablehlo.dot_general %arg0, %arg6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x16xf32>, tensor<2xf32>) -> tensor<16xf32>
73+
%2 = stablehlo.add %1, %arg1 : tensor<16xf32>
74+
%3 = stablehlo.tanh %2 : tensor<16xf32>
75+
%4 = stablehlo.dot_general %arg2, %3, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16xf32>) -> tensor<16xf32>
76+
%5 = stablehlo.add %4, %arg3 : tensor<16xf32>
77+
%6 = stablehlo.dot_general %cst, %0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1xf32>, tensor<1x16xf32>) -> tensor<16xf32>
78+
%7 = stablehlo.tanh %5 : tensor<16xf32>
79+
%8 = stablehlo.multiply %7, %7 : tensor<16xf32>
80+
%9 = stablehlo.subtract %cst_0, %8 : tensor<16xf32>
81+
%10 = stablehlo.multiply %6, %9 : tensor<16xf32>
82+
%11 = stablehlo.dot_general %10, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16xf32>, tensor<16x16xf32>) -> tensor<16xf32>
83+
%12 = stablehlo.multiply %3, %3 : tensor<16xf32>
84+
%13 = stablehlo.subtract %cst_0, %12 : tensor<16xf32>
85+
%14 = stablehlo.multiply %11, %13 : tensor<16xf32>
86+
%15 = stablehlo.dot_general %14, %arg0, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16xf32>, tensor<2x16xf32>) -> tensor<2xf32>
87+
return %15 : tensor<2xf32>
88+
}
89+
90+
// CHECK: func.func @main4(%arg0: tensor<2x16xf32> {enzymexla.memory_effects = []}, %arg1: tensor<16xf32> {enzymexla.memory_effects = []}, %arg2: tensor<16x16xf32> {enzymexla.memory_effects = []}, %arg3: tensor<16xf32> {enzymexla.memory_effects = []}, %arg4: tensor<16x1xf32> {enzymexla.memory_effects = []}, %arg5: tensor<1xf32> {enzymexla.memory_effects = []}, %arg6: tensor<2xf32> {enzymexla.memory_effects = []}) -> tensor<2xf32> attributes {enzymexla.memory_effects = []} {
91+
// CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
92+
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg6, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x16xf32>, tensor<2xf32>) -> tensor<16xf32>
93+
// CHECK-NEXT: %1 = stablehlo.add %0, %arg1 : tensor<16xf32>
94+
// CHECK-NEXT: %2 = stablehlo.tanh %1 : tensor<16xf32>
95+
// CHECK-NEXT: %3 = stablehlo.dot_general %arg2, %2, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16xf32>) -> tensor<16xf32>
96+
// CHECK-NEXT: %4 = stablehlo.add %3, %arg3 : tensor<16xf32>
97+
// CHECK-NEXT: %5 = stablehlo.reshape %arg4 : (tensor<16x1xf32>) -> tensor<16xf32>
98+
// CHECK-NEXT: %6 = stablehlo.tanh %4 : tensor<16xf32>
99+
// CHECK-NEXT: %7 = stablehlo.multiply %6, %6 : tensor<16xf32>
100+
// CHECK-NEXT: %8 = stablehlo.subtract %cst, %7 : tensor<16xf32>
101+
// CHECK-NEXT: %9 = stablehlo.multiply %5, %8 : tensor<16xf32>
102+
// CHECK-NEXT: %10 = stablehlo.dot_general %9, %arg2, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16xf32>, tensor<16x16xf32>) -> tensor<16xf32>
103+
// CHECK-NEXT: %11 = stablehlo.multiply %2, %2 : tensor<16xf32>
104+
// CHECK-NEXT: %12 = stablehlo.subtract %cst, %11 : tensor<16xf32>
105+
// CHECK-NEXT: %13 = stablehlo.multiply %10, %12 : tensor<16xf32>
106+
// CHECK-NEXT: %14 = stablehlo.dot_general %13, %arg0, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<16xf32>, tensor<2x16xf32>) -> tensor<2xf32>
107+
// CHECK-NEXT: return %14 : tensor<2xf32>
108+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)