Skip to content

Commit bbdc8e9

Browse files
ggojskaigcbot
authored andcommitted
[IGC OCL] SYCL Joint Matrix enable 16-bit datatypes for C and D matrices.
Enable 16-bit datatypes for accumulator and output matrices in joint matrix. Platforms: PVC, DG2 Keywords: Feature Related-to: GSD-11139 Resolves:
1 parent 33a954c commit bbdc8e9

22 files changed

+832
-347
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/Matrix/IBiF_matrix.cl

Lines changed: 239 additions & 185 deletions
Large diffs are not rendered by default.

IGC/BiFModule/Languages/OpenCL/PreRelease/Matrix/IBiF_matrix_generator.cpp

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ struct MatrixSpec
224224
if (Layout == Layout_PackedA_RowMajor && SubGroupSize == SUB_GROUP_16 &&
225225
Cols == 32 && BitWidth == BITS_16)
226226
ContribBitWidth = BITS_16;
227+
// Special case - when bitwidth is 16 and layout is accumulator, the contribBitWidth is also 16
228+
if ((Layout == Layout_Accumulator_RowMajor ||
229+
Layout == Layout_Accumulator_ColumnMajor) &&
230+
BitWidth == BITS_16)
231+
ContribBitWidth = BITS_16;
227232

228233
if (Order == Order_Vnni) assert(ContribBitWidth == BITS_32);
229234
}
@@ -473,8 +478,8 @@ static string ImplementSmallLoad2DBlock(MatrixSpec spec, bool isChecked)
473478
s += "long offset = as_long(mem);\n";
474479
s += "int pack_factor = " + to_string(blockBitWidth / spec.BitWidth) + ";\n";
475480
s += "int2 coords = (int2)(x / pack_factor, y);\n";
476-
s += "int width_bytes = " + to_string(Bytes(spec.BitWidth)) + " * width - 1;\n";
477-
s += "int pitch_bytes = " + to_string(Bytes(spec.BitWidth)) + " * stride - 1;\n";
481+
s += "int width_bytes = ElemBytes * width - 1;\n";
482+
s += "int pitch_bytes = ElemBytes * stride - 1;\n";
478483
s += "int height_minus_one = height - 1;\n";
479484
s += "ResultType BlockFunc(long, int, int, int, int2, int);\n";
480485
s += "ResultType res = BlockFunc(offset, width_bytes, height_minus_one, "
@@ -489,7 +494,7 @@ static string ImplementSmallLoad2DBlock(MatrixSpec spec, bool isChecked)
489494
s +=
490495
"long x = (offset - baseoffset) / " + to_string(Bytes(blockBitWidth)) + ";\n";
491496
s += "int2 coords = (int2)(x, 0);\n";
492-
s += "int width_bytes = " + to_string(Bytes(spec.BitWidth)) + " * stride - 1;\n";
497+
s += "int width_bytes = ElemBytes * stride - 1;\n";
493498
s += "int pitch_bytes = width_bytes;\n";
494499
s += "int height_minus_one = " + to_string(blockHeight) + " - 1;\n";
495500
s += "ResultType BlockFunc(long, int, int, int, int2, int);\n";
@@ -526,6 +531,7 @@ static string ImplementSmallLoad2DBlock(MatrixSpec spec, bool isChecked)
526531
s = Replace(s, "SubGroupSize", to_string(spec.SubGroupSize));
527532
s = Replace(s, "ResultType", resultType);
528533
s = Replace(s, "BlockFunc", blockFunc);
534+
s = Replace(s, "ElemBytes", to_string(Bytes(spec.BitWidth)));
529535
return s;
530536
}
531537

@@ -908,6 +914,7 @@ ImplementLargeLoadVectorContinuous(MatrixSpec spec, AddrSpace addr, int numLoads
908914
s = Replace(s, "AddrSpace", "__" + ToString(addr));
909915
s = Replace(s, "ElemByteWidth", to_string(Bytes(spec.BitWidth)));
910916
s = Replace(s, "ContribByteWidth", to_string(Bytes(spec.ContribBitWidth)));
917+
s = Replace(s, "ElemBytes", to_string(Bytes(spec.BitWidth)));
911918
return s;
912919
}
913920

@@ -1087,7 +1094,8 @@ ImplementLargeLoadBase(MatrixSpec spec, AddrSpace addr, int numLoads, bool isChe
10871094

10881095
s = Replace(s, "LoadFunc", loadFunc);
10891096
s = Replace(s, "AddrSpace", "__" + ToString(addr));
1090-
s = Replace(s, "ElemByteWidth", to_string(Bytes(spec.BitWidth)));
1097+
s = Replace(s, "ElemByteWidth", "ElemBytes");
1098+
s = Replace(s, "ElemBytes", to_string(Bytes(spec.BitWidth)));
10911099
s = Replace(s, "ContribByteWidth", to_string(Bytes(spec.ContribBitWidth)));
10921100
s = Replace(s, "WiRowsPerLoad", to_string(wiRowsPerLoad));
10931101
s = Replace(s, "NumLoads", to_string(numLoads));
@@ -1176,37 +1184,35 @@ static string DefineSpecialLarge1x64AddrSpace(MatrixSpec spec, AddrSpace addr)
11761184
string implBlock2D =
11771185
"if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL) {\n"
11781186
" long offset = as_long(mem);\n" // align to 64-byte
1179-
" long baseoffset = offset & (~0x3f);\n" // load 1x64 as 4x16, hence, width is 16 int in bytes
1180-
" int width_bytes = sizeof(int) * 16 - 1;\n" // load 1x64 as 4x16, hence, width is 16 int in bytes
1181-
" int height_minus_one = 4 - 1;\n" // row count
1187+
" long baseoffset = offset & (~0x3f);\n" // load 1x64 as 4x16(32bit) or 2x32(16bit), hence, width is 16 int in bytes
1188+
" int width_bytes = ElemBytes * Width_1x64 - 1;\n" // load 1x64 as 4x16(32bit) or 2x32(16bit), hence, width is 16 int in bytes
1189+
" int height_minus_one = Height_1x64 - 1;\n" // row count
11821190
" int pitch_bytes = width_bytes;\n" // JointMatrices are expected to be contiguous in memory, without padding at the end of a row
1183-
" long x = (offset - baseoffset) / sizeof(int);\n" // in elements
1191+
" long x = (offset - baseoffset) / ElemBytes;\n" // in elements
11841192
" int2 coords = (int2)(x, 0);\n"
1185-
" uint4 __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(long, int, int, "
1186-
"int, int2, int);\n"
1187-
" uint4 res = __builtin_IB_subgroup_block_read_flat_u32_wi4_m4k16v1(baseoffset, "
1188-
"width_bytes, height_minus_one, pitch_bytes, coords, cacheOpt);\n"
1189-
" *(__private uint4 *)dst = res;\n"
1193+
" ElemType4 BlockLoadFunc(long, int, int, int, int2, int);\n"
1194+
" ElemType4 res = BlockLoadFunc(baseoffset, width_bytes, height_minus_one, "
1195+
"pitch_bytes, coords, cacheOpt);\n"
1196+
" *(__private ElemType4 *)dst = res;\n"
11901197
" return;\n"
11911198
"}\n";
11921199

11931200
string implVectors =
1194-
"if(BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL) {\n"
1195-
" *(__private uint4 *)dst = intel_sub_group_block_read4((AddrSpace uint "
1196-
"*)mem);\n"
1197-
" return;\n"
1198-
"}\n"
1201+
"if(BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL) { \n"
1202+
" *(__private ElemType4 *) dst = VecFunc4((AddrSpace ElemType *)mem); \n"
1203+
" return; \n"
1204+
"} \n"
11991205
"if(BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_IMPL) {\n"
1200-
" __private uint *wi_contrib = (__private uint *)dst;\n"
1206+
" __private ElemType *wi_contrib = (__private ElemType *)dst;\n"
12011207
" for (int i = 0; i < 4; i++)\n"
1202-
" wi_contrib[i] = intel_sub_group_block_read((__global uint *)mem + i*16);\n"
1208+
" wi_contrib[i] = VecFunc((__global ElemType *)mem + i*16);\n"
12031209
" return;\n"
12041210
"}\n";
12051211

12061212
string implScalar =
1207-
"AddrSpace int *ptr = (AddrSpace uint *)mem;\n"
1213+
"AddrSpace ElemType *ptr = (AddrSpace ElemType *)mem;\n"
12081214
"int slid = get_sub_group_local_id();\n"
1209-
"__private uint *wi_contrib = (__private uint *)dst;\n"
1215+
"__private ElemType *wi_contrib = (__private ElemType *)dst;\n"
12101216
"for (int i = 0; i < 4; i++)\n"
12111217
" wi_contrib[i] = ptr[i*16 + slid];\n";
12121218

@@ -1244,8 +1250,34 @@ static string DefineSpecialLarge1x64AddrSpace(MatrixSpec spec, AddrSpace addr)
12441250
}
12451251

12461252
s += "}\n\n";
1247-
1253+
string vecFunc4 = "intel_sub_group_block_read" +
1254+
GetVectorLoadSuffix(spec.ContribBitWidth) +
1255+
ToStringAbove1(spec.WiRows);
1256+
string vecFunc =
1257+
"intel_sub_group_block_read" + GetVectorLoadSuffix(spec.ContribBitWidth);
1258+
if (spec.BitWidth == 32)
1259+
{
1260+
string blockLoadFunc =
1261+
"__builtin_IB_subgroup_block_read_flat_uElemBits_wiWiRows_m4k16v1";
1262+
s = Replace(s, "BlockLoadFunc", blockLoadFunc);
1263+
s = Replace(s, "Width_1x64", string("16"));
1264+
s = Replace(s, "Height_1x64", string("4"));
1265+
}
1266+
else
1267+
{
1268+
string blockLoadFunc =
1269+
"__builtin_IB_subgroup_block_read_flat_uElemBits_wiWiRows_m2k32v1";
1270+
s = Replace(s, "BlockLoadFunc", blockLoadFunc);
1271+
s = Replace(s, "Width_1x64", string("32"));
1272+
s = Replace(s, "Height_1x64", string("2"));
1273+
}
1274+
s = Replace(s, "ElemBits", to_string(spec.BitWidth));
1275+
s = Replace(s, "VecFunc4", vecFunc4);
1276+
s = Replace(s, "VecFunc", vecFunc);
12481277
s = Replace(s, "AddrSpace", "__" + ToString(addr));
1278+
s = Replace(s, "ElemBytes", to_string(Bytes(spec.BitWidth)));
1279+
s = Replace(s, "ElemType", GetUnsignedType(spec.BitWidth));
1280+
s = Replace(s, "WiRows", to_string(spec.WiRows));
12491281
return s;
12501282
}
12511283

@@ -1284,7 +1316,8 @@ static string DefineSpecialLarge1x64(MatrixSpec spec)
12841316
s += "}\n\n";
12851317
}
12861318
}
1287-
1319+
s = Replace(s, "ElemBits", to_string(spec.BitWidth));
1320+
s = Replace(s, "ElemBytes", to_string(Bytes(spec.BitWidth)));
12881321
return s;
12891322
}
12901323

@@ -1380,6 +1413,10 @@ static string DefineAllSmallLoads()
13801413
MatrixSpec(SUB_GROUP_32, Layout_PackedB_RowMajor, 8, 16, BITS_32));
13811414

13821415

1416+
// Acumulator, i16
1417+
s += DefineSmallLoadPermuteRows(
1418+
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 8, 16, BITS_16));
1419+
13831420
// Accumulator, i32:
13841421
/* Load accumulator is a special case of load packed A, both are row major: */
13851422
s += DefineSmallLoadPermuteRows(
@@ -1417,6 +1454,12 @@ static string DefineAllSmallLoads()
14171454
s += DefineSmallLoad(
14181455
MatrixSpec(SUB_GROUP_16, Layout_PackedB_PackedB, 16, 32, BITS_16));
14191456

1457+
// Accumulator, i16
1458+
s += DefineSmallLoad(
1459+
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 16, 16, BITS_16));
1460+
s += DefineSmallLoad(
1461+
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 32, 32, BITS_16));
1462+
14201463
// Accumulator, i32:
14211464
s += DefineSmallLoad(
14221465
MatrixSpec(SUB_GROUP_8, Layout_Accumulator_RowMajor, 32, 8, BITS_32));
@@ -1463,13 +1506,21 @@ static string DefineAllLargeLoads()
14631506
s += DefineLargeLoad(
14641507
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 32, 64, BITS_32));
14651508

1509+
// Accumulator, i16:
1510+
s += DefineLargeLoad(
1511+
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 32, 64, BITS_16));
1512+
14661513
//
14671514
// Special large loads
14681515
//
14691516

14701517
// Accumulator, i32 - 1x64
14711518
s += DefineSpecialLarge1x64(
14721519
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 1, 64, BITS_32));
1520+
1521+
// Accumulator, i16 - 1x64
1522+
s += DefineSpecialLarge1x64(
1523+
MatrixSpec(SUB_GROUP_16, Layout_Accumulator_RowMajor, 1, 64, BITS_16));
14731524
return s;
14741525
}
14751526

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20645,6 +20645,15 @@ void EmitPass::emitDpas(GenIntrinsicInst *GII, const SSource *Sources, const Dst
2064520645
if (input->GetType() == ISA_TYPE_UW || input->GetType() == ISA_TYPE_W) {
2064620646
input = m_currShader->GetNewAlias(input, ISA_TYPE_BF, 0, 0);
2064720647
}
20648+
// For matrix C shape 1x64 i16 data types data is not
20649+
// properly aligned by default so we have to do a copy.
20650+
if (RC == 1 && (input->GetType() == ISA_TYPE_BF || input->GetType() == ISA_TYPE_HF)) {
20651+
CVariable *input_tmp = m_currShader->GetNewVariable(input->GetNumberElement(), input->GetType(), EALIGN_GRF,
20652+
false /*uniform*/, "input_realign");
20653+
m_encoder->Copy(input_tmp, input);
20654+
input = input_tmp;
20655+
m_encoder->Push();
20656+
}
2064820657
}
2064920658
if (dst->GetType() == ISA_TYPE_UW || dst->GetType() == ISA_TYPE_W) {
2065020659
dst = m_currShader->GetNewAlias(dst, ISA_TYPE_BF, 0, 0);

0 commit comments

Comments
 (0)