@@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
758
758
packBIdx = 0
759
759
packMIdx = 0
760
760
761
+ numPackedA = 0
762
+ numPackedB = 0
763
+ numPackedM = 0
761
764
#####
762
765
# Prepare localReadCode
763
766
####
@@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
832
835
833
836
instPerPackM = 0
834
837
if kernel ["ProblemType" ]["Sparse" ] and not kernel ["DirectToVgprSparseMetadata" ] and not kernel ["UnrollMajorLDSMetadata" ]:
835
- instPerPackM = 1.5 if self .states .lrvwTileMetadata > 1 and kernel ["MIInputPerThreadMetadata" ] == 1 else 1
836
-
837
- packItems = []
838
+ instPerPackM = 1
839
+ if self .states .lrvwTileMetadata > 1 :
840
+ if kernel ["MIInputPerThreadMetadata" ] == 1 :
841
+ instPerPackM = 1.5
842
+ elif kernel ["MIInputPerThreadMetadata" ] == 4 :
843
+ instPerPackM = 3
844
+ packItemsA = []
845
+ packItemsB = []
846
+ packItemsM = []
838
847
for iui in range (kernel ["InnerUnroll" ]):
839
- packINtems = [ [] for j in range (max (self .states .numReadsIterCoalescedA ,self .states .numReadsIterCoalescedB ,self .states .numReadsIterCoalescedMetadata )) ]
848
+ packINtemsA = [ [] for j in range (max (self .states .numReadsIterCoalescedA ,self .states .numReadsIterCoalescedB ,self .states .numReadsIterCoalescedMetadata )) ]
849
+ packINtemsB = [ [] for j in range (max (self .states .numReadsIterCoalescedA ,self .states .numReadsIterCoalescedB ,self .states .numReadsIterCoalescedMetadata )) ]
850
+ packINtemsM = [ [] for j in range (max (self .states .numReadsIterCoalescedA ,self .states .numReadsIterCoalescedB ,self .states .numReadsIterCoalescedMetadata )) ]
840
851
packA = packCode .findNamedItem ("packA_I%s" % (iui ))
841
852
packB = packCode .findNamedItem ("packB_I%s" % (iui ))
842
853
packM = packCode .findNamedItem ("packMetadata_I%s" % (iui ))
@@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
856
867
if packAItems :
857
868
if kernel ["ConvertAfterDS" ] and kernel ["ProblemType" ]["DataTypeA" ].isFloat8 ():
858
869
for n in range (instPerPackA ):
859
- packINtems [0 ].append (packAItems .pop (0 ))
870
+ packINtemsA [0 ].append (packAItems .pop (0 ))
860
871
else :
861
872
for j in range (self .states .numReadsIterCoalescedA ):
862
873
for n in range (instPerPackA ):
863
- packINtems [j ].append (packAItems .pop (0 ))
874
+ packINtemsA [j ].append (packAItems .pop (0 ))
864
875
865
876
if kernel ["ProblemType" ]["Sparse" ] and not kernel ["DirectToVgprSparseMetadata" ]:
866
877
for j in range (self .states .numReadsIterCoalescedMetadata ):
867
878
for n in range (ceil (instPerPackM )):
868
879
if packMItems :
869
- packINtems [j ].append (packMItems .pop (0 ))
880
+ packINtemsM [j ].append (packMItems .pop (0 ))
870
881
else :
871
882
break
872
883
873
884
if packBItems :
874
885
if kernel ["ConvertAfterDS" ] and kernel ["ProblemType" ]["DataTypeB" ].isFloat8 ():
875
886
for n in range (instPerPackB ):
876
- packINtems [0 ].append (packBItems .pop (0 ))
887
+ packINtemsB [0 ].append (packBItems .pop (0 ))
877
888
else :
878
889
for j in range (self .states .numReadsIterCoalescedB ):
879
890
for n in range (instPerPackB ):
880
- packINtems [j ].append (packBItems .pop (0 ))
891
+ packINtemsB [j ].append (packBItems .pop (0 ))
881
892
882
893
while packAItems :
883
894
if kernel ["ConvertAfterDS" ] and kernel ["ProblemType" ]["DataTypeA" ].isFloat8 ():
884
895
for n in range (instPerPackA ):
885
- packINtems [0 ].append (packAItems .pop (0 ))
896
+ if packAItems :
897
+ packINtemsA [0 ].append (packAItems .pop (0 ))
898
+ else :
899
+ break
886
900
else :
887
901
for j in range (self .states .numReadsIterCoalescedA ):
888
902
for n in range (instPerPackA ):
889
- packINtems [j ].append (packAItems .pop (0 ))
903
+ if packAItems :
904
+ packINtemsA [j ].append (packAItems .pop (0 ))
905
+ else :
906
+ break
890
907
891
- if kernel [ "ProblemType" ][ "Sparse" ] and not kernel [ "DirectToVgprSparseMetadata" ] :
892
- while packMItems :
908
+ while packMItems :
909
+ if kernel [ "ProblemType" ][ "Sparse" ] and not kernel [ "DirectToVgprSparseMetadata" ] :
893
910
for j in range (self .states .numReadsIterCoalescedMetadata ):
894
911
for n in range (ceil (instPerPackM )):
895
912
if packMItems :
896
- packINtems [j ].append (packMItems .pop (0 ))
913
+ packINtemsM [j ].append (packMItems .pop (0 ))
897
914
else :
898
915
break
899
916
900
917
while packBItems :
901
918
if kernel ["ConvertAfterDS" ] and kernel ["ProblemType" ]["DataTypeB" ].isFloat8 ():
902
919
for n in range (instPerPackB ):
903
- packINtems [0 ].append (packBItems .pop (0 ))
920
+ if packBItems :
921
+ packINtemsB [0 ].append (packBItems .pop (0 ))
922
+ else :
923
+ break
904
924
else :
905
925
for j in range (self .states .numReadsIterCoalescedB ):
906
926
for n in range (instPerPackB ):
907
- packINtems [j ].append (packBItems .pop (0 ))
927
+ if packBItems :
928
+ packINtemsB [j ].append (packBItems .pop (0 ))
929
+ else :
930
+ break
908
931
909
- for j in range (max (self .states .numReadsIterCoalescedA ,self .states .numReadsIterCoalescedB )):
910
- packItems += packINtems .pop (0 )
932
+ for j in range (max (self .states .numReadsIterCoalescedA ,self .states .numReadsIterCoalescedB ,self .states .numReadsIterCoalescedMetadata )):
933
+ packItemsA += packINtemsA .pop (0 )
934
+ packItemsB += packINtemsB .pop (0 )
935
+ packItemsM += packINtemsM .pop (0 )
911
936
912
937
# remove s_nop for packing
913
938
# we will add s_nop if needed
@@ -1053,7 +1078,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
1053
1078
numLocalReadShouldSchedule = 0
1054
1079
# prefetch load for next wave tile along M since we re-use B first.
1055
1080
tileM : int = kernel ["MIWaveTileA" ]
1056
- instsToCheck = mfmas [i :min (i + tileM + 1 , numMfmaPerIter )] + packItems
1081
+ instsToCheck = mfmas [i :min (i + tileM + 1 , numMfmaPerIter )] + packItemsA + packItemsM + packItemsB
1057
1082
localReadItemsThisLoop = sorted (localReadItemsThisLoop , key = lambda o : hasAnyDependency (o , instsToCheck ), reverse = True )
1058
1083
1059
1084
for lr in localReadItemsThisLoop :
@@ -1241,7 +1266,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
1241
1266
mfmas = [mfma for mfma in macIterCode .flatitems () if isinstance (mfma , (MFMAInstruction , SMFMAInstruction ,))]
1242
1267
## To support do["MAC"] is False
1243
1268
mfma = [mfmas [i ],] if len (mfmas ) > 0 else []
1244
- instsToCheck = mfma + packItems
1269
+ instsToCheck = mfma + packItemsA + packItemsM + packItemsB
1245
1270
numDsInsts = 0
1246
1271
lastLgkmCnt = - 1
1247
1272
for ds in filter (lambda j : isinstance (j , (DSLoadInstruction , DSStoreInstruction , SWaitCnt )), reversed (prevIterCode .flatitems () + iterCode .flatitems ())):
@@ -1271,18 +1296,25 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
1271
1296
####
1272
1297
# scheduled pack
1273
1298
####
1274
- if packItems :
1299
+ _instPerPackA = 0
1300
+ _instPerPackB = 0
1301
+ _instPerPackM = 0
1302
+ if packItemsA or packItemsB or packItemsM :
1275
1303
# how many pack have to be done
1276
1304
# calculate the data index of this mfma used for A and B
1277
1305
# if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account)
1278
1306
# if i % kernel["MIWaveTile"][0]==0, mfma will use new B
1279
- packAIdx += instPerPackA if i // (kernel ["MIWaveTileA" ]+ kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]* (i // (kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]))) == 0 else 0
1280
- packBIdx += instPerPackB if i % kernel ["MIWaveTileA" ] == 0 else 0
1307
+ _instPerPackA = instPerPackA if i // (kernel ["MIWaveTileA" ]+ kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]* (i // (kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]))) == 0 else 0
1308
+ packAIdx += _instPerPackA
1309
+ _instPerPackB = instPerPackB if i % kernel ["MIWaveTileA" ] == 0 else 0
1310
+ packBIdx += _instPerPackB
1281
1311
if kernel ["ProblemType" ]["Sparse" ] and not kernel ["DirectToVgprSparseMetadata" ]:
1282
1312
if kernel ["ProblemType" ]["Sparse" ] == 2 :
1283
- packMIdx + = instPerPackM if i % kernel ["MIWaveTileA" ] == 0 else 0
1313
+ _instPerPackM = instPerPackM if i % kernel ["MIWaveTileA" ] == 0 else 0
1284
1314
else :
1285
- packMIdx += instPerPackM if i // (kernel ["MIWaveTileA" ]+ kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]* (i // (kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]))) == 0 else 0
1315
+ _instPerPackM = instPerPackM if i // (kernel ["MIWaveTileA" ]+ kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]* (i // (kernel ["MIWaveTileA" ]* kernel ["MIWaveTileB" ]))) == 0 else 0
1316
+
1317
+ packMIdx += _instPerPackM
1286
1318
# blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8)
1287
1319
if self .states .archCaps ["HasEccHalf" ] or not self .states .asmCaps ["HasWMMA_V1" ]:
1288
1320
packAIdx = packAIdx if tPA ["bpe" ] < 4 and (not kernel ["UnrollMajorLDSA" ] or kernel ["ConvertAfterDS" ]) else 0
@@ -1298,33 +1330,70 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
1298
1330
iterCode .addComment0 ("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" % (packAIdx ,packBIdx ,packMIdx ))
1299
1331
else :
1300
1332
iterCode .addComment0 ("pack scheduling: packAIdx:%u, packBIdx:%u" % (packAIdx ,packBIdx ))
1301
- # we put 2 pack in each mfma
1302
- for j in range (instPerPackA ):
1303
- if packItems :
1304
- iterCode .add (packItems .pop (0 ))
1333
+
1334
+ # put the required pack into mfma
1335
+ for j in range (_instPerPackA ):
1336
+ if packItemsA :
1337
+ # Skip if the required pack has already been placed in the previous mfma iter.
1338
+ if numPackedA >= packAIdx :
1339
+ break
1340
+ iterCode .add (packItemsA .pop (0 ))
1305
1341
curPackIdx += 1
1342
+ numPackedA += 1
1343
+
1306
1344
if kernel ["ProblemType" ]["Sparse" ] and not kernel ["DirectToVgprSparseMetadata" ]:
1307
- for j in range (ceil (instPerPackM )):
1308
- if packItems :
1309
- iterCode .add (packItems .pop (0 ))
1345
+ for j in range (ceil (_instPerPackM )):
1346
+ if packItemsM :
1347
+ # Skip if the required pack has already been placed in the previous mfma iter.
1348
+ if numPackedM >= packMIdx :
1349
+ break
1350
+ iterCode .add (packItemsM .pop (0 ))
1310
1351
curPackIdx += 1
1311
- for j in range (instPerPackB ):
1312
- if packItems :
1313
- iterCode .add (packItems .pop (0 ))
1314
- curPackIdx += 1
1315
- # since packed register need to wait 2 quad cycle to finish packing
1316
- # we insert pack instruction if we can, or s_nop
1317
- while curPackIdx < numPack + 2 :
1318
- if packItems :
1319
- iterCode .add (packItems .pop (0 ))
1320
- curPackIdx += 1
1321
- else :
1322
- iterCode .add (SNop (waitState = 1 , comment = "VALU packing writes to be consumed by matrix instruction" ))
1352
+ numPackedM += 1
1353
+
1354
+ for j in range (_instPerPackB ):
1355
+ if packItemsB :
1356
+ # Skip if the required pack has already been placed in the previous mfma iter.
1357
+ if numPackedB >= packBIdx :
1358
+ break
1359
+ iterCode .add (packItemsB .pop (0 ))
1323
1360
curPackIdx += 1
1324
- break
1361
+ numPackedB += 1
1362
+
1363
+ # put unnecessary pack into mfma to fulfill the latency
1364
+ remainLatency = 2
1365
+ if curPackIdx < numPack + 2 :
1366
+ # since packed register need to wait 2 quad cycle to finish packing
1367
+ # we insert pack instruction if we can, or s_nop
1368
+ while remainLatency :
1369
+ if packItemsA :
1370
+ iterCode .add (packItemsA .pop (0 ))
1371
+ curPackIdx += 1
1372
+ numPackedA += 1
1373
+ remainLatency -= 1
1374
+ elif packItemsM :
1375
+ iterCode .add (packItemsM .pop (0 ))
1376
+ curPackIdx += 1
1377
+ numPackedM += 1
1378
+ remainLatency -= 1
1379
+ elif packItemsB :
1380
+ iterCode .add (packItemsB .pop (0 ))
1381
+ curPackIdx += 1
1382
+ numPackedB += 1
1383
+ remainLatency -= 1
1384
+ else :
1385
+ latency = remainLatency - 1
1386
+ iterCode .add (SNop (waitState = latency , comment = "VALU packing writes to be consumed by matrix instruction" ))
1387
+ curPackIdx += 1
1388
+ remainLatency -= (latency + 1 )
1389
+
1325
1390
if i == numMfmaPerIter - 1 :
1326
- while packItems :
1327
- iterCode .add (packItems .pop (0 ))
1391
+ while packItemsA :
1392
+ iterCode .add (packItemsA .pop (0 ))
1393
+ while packItemsM :
1394
+ iterCode .add (packItemsM .pop (0 ))
1395
+ while packItemsB :
1396
+ iterCode .add (packItemsB .pop (0 ))
1328
1397
1329
1398
####
1330
1399
# scheduled mfma dependency
0 commit comments