@@ -224,6 +224,11 @@ struct MatrixSpec
224
224
if (Layout == Layout_PackedA_RowMajor && SubGroupSize == SUB_GROUP_16 &&
225
225
Cols == 32 && BitWidth == BITS_16)
226
226
ContribBitWidth = BITS_16;
227
+ // Special case - when bitwidth is 16 and layout is accumulator, the contribBitWidth is also 16
228
+ if ((Layout == Layout_Accumulator_RowMajor ||
229
+ Layout == Layout_Accumulator_ColumnMajor) &&
230
+ BitWidth == BITS_16)
231
+ ContribBitWidth = BITS_16;
227
232
228
233
if (Order == Order_Vnni) assert (ContribBitWidth == BITS_32);
229
234
}
@@ -473,8 +478,8 @@ static string ImplementSmallLoad2DBlock(MatrixSpec spec, bool isChecked)
473
478
s += " long offset = as_long(mem);\n " ;
474
479
s += " int pack_factor = " + to_string (blockBitWidth / spec.BitWidth ) + " ;\n " ;
475
480
s += " int2 coords = (int2)(x / pack_factor, y);\n " ;
476
- s += " int width_bytes = " + to_string ( Bytes (spec. BitWidth )) + " * width - 1;\n " ;
477
- s += " int pitch_bytes = " + to_string ( Bytes (spec. BitWidth )) + " * stride - 1;\n " ;
481
+ s += " int width_bytes = ElemBytes * width - 1;\n " ;
482
+ s += " int pitch_bytes = ElemBytes * stride - 1;\n " ;
478
483
s += " int height_minus_one = height - 1;\n " ;
479
484
s += " ResultType BlockFunc(long, int, int, int, int2, int);\n " ;
480
485
s += " ResultType res = BlockFunc(offset, width_bytes, height_minus_one, "
@@ -489,7 +494,7 @@ static string ImplementSmallLoad2DBlock(MatrixSpec spec, bool isChecked)
489
494
s +=
490
495
" long x = (offset - baseoffset) / " + to_string (Bytes (blockBitWidth)) + " ;\n " ;
491
496
s += " int2 coords = (int2)(x, 0);\n " ;
492
- s += " int width_bytes = " + to_string ( Bytes (spec. BitWidth )) + " * stride - 1;\n " ;
497
+ s += " int width_bytes = ElemBytes * stride - 1;\n " ;
493
498
s += " int pitch_bytes = width_bytes;\n " ;
494
499
s += " int height_minus_one = " + to_string (blockHeight) + " - 1;\n " ;
495
500
s += " ResultType BlockFunc(long, int, int, int, int2, int);\n " ;
@@ -526,6 +531,7 @@ static string ImplementSmallLoad2DBlock(MatrixSpec spec, bool isChecked)
526
531
s = Replace (s, " SubGroupSize" , to_string (spec.SubGroupSize ));
527
532
s = Replace (s, " ResultType" , resultType);
528
533
s = Replace (s, " BlockFunc" , blockFunc);
534
+ s = Replace (s, " ElemBytes" , to_string (Bytes (spec.BitWidth )));
529
535
return s;
530
536
}
531
537
@@ -908,6 +914,7 @@ ImplementLargeLoadVectorContinuous(MatrixSpec spec, AddrSpace addr, int numLoads
908
914
s = Replace (s, " AddrSpace" , " __" + ToString (addr));
909
915
s = Replace (s, " ElemByteWidth" , to_string (Bytes (spec.BitWidth )));
910
916
s = Replace (s, " ContribByteWidth" , to_string (Bytes (spec.ContribBitWidth )));
917
+ s = Replace (s, " ElemBytes" , to_string (Bytes (spec.BitWidth )));
911
918
return s;
912
919
}
913
920
@@ -1087,7 +1094,8 @@ ImplementLargeLoadBase(MatrixSpec spec, AddrSpace addr, int numLoads, bool isChe
1087
1094
1088
1095
s = Replace (s, " LoadFunc" , loadFunc);
1089
1096
s = Replace (s, " AddrSpace" , " __" + ToString (addr));
1090
- s = Replace (s, " ElemByteWidth" , to_string (Bytes (spec.BitWidth )));
1097
+ s = Replace (s, " ElemByteWidth" , " ElemBytes" );
1098
+ s = Replace (s, " ElemBytes" , to_string (Bytes (spec.BitWidth )));
1091
1099
s = Replace (s, " ContribByteWidth" , to_string (Bytes (spec.ContribBitWidth )));
1092
1100
s = Replace (s, " WiRowsPerLoad" , to_string (wiRowsPerLoad));
1093
1101
s = Replace (s, " NumLoads" , to_string (numLoads));
@@ -1176,37 +1184,35 @@ static string DefineSpecialLarge1x64AddrSpace(MatrixSpec spec, AddrSpace addr)
1176
1184
string implBlock2D =
1177
1185
" if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL) {\n "
1178
1186
" long offset = as_long(mem);\n " // align to 64-byte
1179
- " long baseoffset = offset & (~0x3f);\n " // load 1x64 as 4x16, hence, width is 16 int in bytes
1180
- " int width_bytes = sizeof(int) * 16 - 1;\n " // load 1x64 as 4x16, hence, width is 16 int in bytes
1181
- " int height_minus_one = 4 - 1;\n " // row count
1187
+ " long baseoffset = offset & (~0x3f);\n " // load 1x64 as 4x16(32bit) or 2x32(16bit) , hence, width is 16 int in bytes
1188
+ " int width_bytes = ElemBytes * Width_1x64 - 1;\n " // load 1x64 as 4x16(32bit) or 2x32(16bit) , hence, width is 16 int in bytes
1189
+ " int height_minus_one = Height_1x64 - 1;\n " // row count
1182
1190
" int pitch_bytes = width_bytes;\n " // JointMatrices are expected to be contiguous in memory, without padding at the end of a row
1183
- " long x = (offset - baseoffset) / sizeof(int) ;\n " // in elements
1191
+ " long x = (offset - baseoffset) / ElemBytes ;\n " // in elements
1184
1192
" int2 coords = (int2)(x, 0);\n "
1185
- " uint4 __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(long, int, int, "
1186
- " int, int2, int);\n "
1187
- " uint4 res = __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(baseoffset, "
1188
- " width_bytes, height_minus_one, pitch_bytes, coords, cacheOpt);\n "
1189
- " *(__private uint4 *)dst = res;\n "
1193
+ " ElemType4 BlockLoadFunc(long, int, int, int, int2, int);\n "
1194
+ " ElemType4 res = BlockLoadFunc(baseoffset, width_bytes, height_minus_one, "
1195
+ " pitch_bytes, coords, cacheOpt);\n "
1196
+ " *(__private ElemType4 *)dst = res;\n "
1190
1197
" return;\n "
1191
1198
" }\n " ;
1192
1199
1193
1200
string implVectors =
1194
- " if(BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL) {\n "
1195
- " *(__private uint4 *)dst = intel_sub_group_block_read4((AddrSpace uint "
1196
- " *)mem);\n "
1197
- " return;\n "
1198
- " }\n "
1201
+ " if(BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL) { \n "
1202
+ " *(__private ElemType4 *) dst = VecFunc4((AddrSpace ElemType *)mem); \n "
1203
+ " return; \n "
1204
+ " } \n "
1199
1205
" if(BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_IMPL) {\n "
1200
- " __private uint *wi_contrib = (__private uint *)dst;\n "
1206
+ " __private ElemType *wi_contrib = (__private ElemType *)dst;\n "
1201
1207
" for (int i = 0; i < 4; i++)\n "
1202
- " wi_contrib[i] = intel_sub_group_block_read ((__global uint *)mem + i*16);\n "
1208
+ " wi_contrib[i] = VecFunc ((__global ElemType *)mem + i*16);\n "
1203
1209
" return;\n "
1204
1210
" }\n " ;
1205
1211
1206
1212
string implScalar =
1207
- " AddrSpace int *ptr = (AddrSpace uint *)mem;\n "
1213
+ " AddrSpace ElemType *ptr = (AddrSpace ElemType *)mem;\n "
1208
1214
" int slid = get_sub_group_local_id();\n "
1209
- " __private uint *wi_contrib = (__private uint *)dst;\n "
1215
+ " __private ElemType *wi_contrib = (__private ElemType *)dst;\n "
1210
1216
" for (int i = 0; i < 4; i++)\n "
1211
1217
" wi_contrib[i] = ptr[i*16 + slid];\n " ;
1212
1218
@@ -1244,8 +1250,34 @@ static string DefineSpecialLarge1x64AddrSpace(MatrixSpec spec, AddrSpace addr)
1244
1250
}
1245
1251
1246
1252
s += " }\n\n " ;
1247
-
1253
+ string vecFunc4 = " intel_sub_group_block_read" +
1254
+ GetVectorLoadSuffix (spec.ContribBitWidth ) +
1255
+ ToStringAbove1 (spec.WiRows );
1256
+ string vecFunc =
1257
+ " intel_sub_group_block_read" + GetVectorLoadSuffix (spec.ContribBitWidth );
1258
+ if (spec.BitWidth == 32 )
1259
+ {
1260
+ string blockLoadFunc =
1261
+ " __builtin_IB_subgroup_block_read_flat_uElemBits_wiWiRows_m4k16v1" ;
1262
+ s = Replace (s, " BlockLoadFunc" , blockLoadFunc);
1263
+ s = Replace (s, " Width_1x64" , string (" 16" ));
1264
+ s = Replace (s, " Height_1x64" , string (" 4" ));
1265
+ }
1266
+ else
1267
+ {
1268
+ string blockLoadFunc =
1269
+ " __builtin_IB_subgroup_block_read_flat_uElemBits_wiWiRows_m2k32v1" ;
1270
+ s = Replace (s, " BlockLoadFunc" , blockLoadFunc);
1271
+ s = Replace (s, " Width_1x64" , string (" 32" ));
1272
+ s = Replace (s, " Height_1x64" , string (" 2" ));
1273
+ }
1274
+ s = Replace (s, " ElemBits" , to_string (spec.BitWidth ));
1275
+ s = Replace (s, " VecFunc4" , vecFunc4);
1276
+ s = Replace (s, " VecFunc" , vecFunc);
1248
1277
s = Replace (s, " AddrSpace" , " __" + ToString (addr));
1278
+ s = Replace (s, " ElemBytes" , to_string (Bytes (spec.BitWidth )));
1279
+ s = Replace (s, " ElemType" , GetUnsignedType (spec.BitWidth ));
1280
+ s = Replace (s, " WiRows" , to_string (spec.WiRows ));
1249
1281
return s;
1250
1282
}
1251
1283
@@ -1284,7 +1316,8 @@ static string DefineSpecialLarge1x64(MatrixSpec spec)
1284
1316
s += " }\n\n " ;
1285
1317
}
1286
1318
}
1287
-
1319
+ s = Replace (s, " ElemBits" , to_string (spec.BitWidth ));
1320
+ s = Replace (s, " ElemBytes" , to_string (Bytes (spec.BitWidth )));
1288
1321
return s;
1289
1322
}
1290
1323
@@ -1380,6 +1413,10 @@ static string DefineAllSmallLoads()
1380
1413
MatrixSpec (SUB_GROUP_32, Layout_PackedB_RowMajor, 8 , 16 , BITS_32));
1381
1414
1382
1415
1416
+ // Accumulator, i16
1417
+ s += DefineSmallLoadPermuteRows (
1418
+ MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 8 , 16 , BITS_16));
1419
+
1383
1420
// Accumulator, i32:
1384
1421
/* Load accumulator is a special case of load packed A, both are row major: */
1385
1422
s += DefineSmallLoadPermuteRows (
@@ -1417,6 +1454,12 @@ static string DefineAllSmallLoads()
1417
1454
s += DefineSmallLoad (
1418
1455
MatrixSpec (SUB_GROUP_16, Layout_PackedB_PackedB, 16 , 32 , BITS_16));
1419
1456
1457
+ // Accumulator, i16
1458
+ s += DefineSmallLoad (
1459
+ MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 16 , 16 , BITS_16));
1460
+ s += DefineSmallLoad (
1461
+ MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 32 , 32 , BITS_16));
1462
+
1420
1463
// Accumulator, i32:
1421
1464
s += DefineSmallLoad (
1422
1465
MatrixSpec (SUB_GROUP_8, Layout_Accumulator_RowMajor, 32 , 8 , BITS_32));
@@ -1463,13 +1506,21 @@ static string DefineAllLargeLoads()
1463
1506
s += DefineLargeLoad (
1464
1507
MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 32 , 64 , BITS_32));
1465
1508
1509
+ // Accumulator, i16:
1510
+ s += DefineLargeLoad (
1511
+ MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 32 , 64 , BITS_16));
1512
+
1466
1513
//
1467
1514
// Special large loads
1468
1515
//
1469
1516
1470
1517
// Accumulator, i32 - 1x64
1471
1518
s += DefineSpecialLarge1x64 (
1472
1519
MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 1 , 64 , BITS_32));
1520
+
1521
+ // Accumulator, i16 - 1x64
1522
+ s += DefineSpecialLarge1x64 (
1523
+ MatrixSpec (SUB_GROUP_16, Layout_Accumulator_RowMajor, 1 , 64 , BITS_16));
1473
1524
return s;
1474
1525
}
1475
1526
0 commit comments