Skip to content

Commit e430795

Browse files
Fix MlasSgemmKernel: properly process more than 2 rows (#22125)
This change fixes multiple tests like QDQTransformerTests.MatMul_U8S8S8, for all architectures where architecture-specific optimized function is not available yet, like s390x. ### Description Matrix B is packed by 16 elements, thus new row starts 16 items later. Also, for next C increment index only by 1 for each increment of C. ### Motivation and Context This change fixes mlas sgemm fallback implementation for all architectures which don't have architecture-specific implementations available, like s390x.
1 parent 712bee1 commit e430795

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

onnxruntime/core/mlas/lib/scalar/SgemmKernelScalar.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ Return Value:
8383

8484
#endif
8585

86+
int countb = 0;
87+
8688
do {
8789

8890
float BElements00;
@@ -116,6 +118,7 @@ Return Value:
116118
//
117119

118120
const float* a = A;
121+
const float* b = B;
119122
size_t k = CountK;
120123

121124
while (k >= 2) {
@@ -128,10 +131,10 @@ Return Value:
128131
Row1AElements1 = a[lda + 1];
129132
}
130133

131-
BElements00 = B[0];
132-
BElements01 = B[1];
133-
BElements02 = B[2];
134-
BElements03 = B[3];
134+
BElements00 = b[0];
135+
BElements01 = b[1];
136+
BElements02 = b[2];
137+
BElements03 = b[3];
135138
Row0Block00 = Row0Block00 + BElements00 * Row0AElements0;
136139
Row0Block01 = Row0Block01 + BElements01 * Row0AElements0;
137140
Row0Block02 = Row0Block02 + BElements02 * Row0AElements0;
@@ -144,10 +147,10 @@ Return Value:
144147
Row1Block03 = Row1Block03 + BElements03 * Row1AElements0;
145148
}
146149

147-
BElements00 = B[4];
148-
BElements01 = B[5];
149-
BElements02 = B[6];
150-
BElements03 = B[7];
150+
BElements00 = b[16];
151+
BElements01 = b[17];
152+
BElements02 = b[18];
153+
BElements03 = b[19];
151154
Row0Block00 = Row0Block00 + BElements00 * Row0AElements1;
152155
Row0Block01 = Row0Block01 + BElements01 * Row0AElements1;
153156
Row0Block02 = Row0Block02 + BElements02 * Row0AElements1;
@@ -161,7 +164,7 @@ Return Value:
161164
}
162165

163166
a += 2;
164-
B += 8;
167+
b += 32;
165168
k -= 2;
166169
}
167170

@@ -173,10 +176,10 @@ Return Value:
173176
Row1AElements0 = a[lda];
174177
}
175178

176-
BElements00 = B[0];
177-
BElements01 = B[1];
178-
BElements02 = B[2];
179-
BElements03 = B[3];
179+
BElements00 = b[0];
180+
BElements01 = b[1];
181+
BElements02 = b[2];
182+
BElements03 = b[3];
180183
Row0Block00 = Row0Block00 + BElements00 * Row0AElements0;
181184
Row0Block01 = Row0Block01 + BElements01 * Row0AElements0;
182185
Row0Block02 = Row0Block02 + BElements02 * Row0AElements0;
@@ -188,8 +191,6 @@ Return Value:
188191
Row1Block02 = Row1Block02 + BElements02 * Row1AElements0;
189192
Row1Block03 = Row1Block03 + BElements03 * Row1AElements0;
190193
}
191-
192-
B += 4;
193194
}
194195

195196
//
@@ -295,9 +296,14 @@ Return Value:
295296
break;
296297
}
297298

299+
B += 4;
298300
C += 4;
299301
CountN -= 4;
300302

303+
countb = (countb + 1) % 4;
304+
if (countb == 0) {
305+
B += CountK * 16 - 16;
306+
}
301307
} while (CountN > 0);
302308

303309
return ProcessTwoRows ? 2 : 1;

0 commit comments

Comments
 (0)