Skip to content

Commit 14bd624

Browse files
committed
Refactoy the pack scheduling for scheduleIterAlg = 3.
* Used 3 different pack pools to store the pack instructions of A, B, and Metadata * First, only put the required pack into the code (the number of required packs may differ for each mfma iteration). Second, put another pack or SNop before the mfma instruction according to the needed latency. the combination of insertion may be 2 packs, 1 pack + snop 0, or snop 1.
1 parent 3ce620f commit 14bd624

File tree

1 file changed

+116
-47
lines changed

1 file changed

+116
-47
lines changed

tensilelite/Tensile/KernelWriter.py

Lines changed: 116 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
758758
packBIdx = 0
759759
packMIdx = 0
760760

761+
numPackedA = 0
762+
numPackedB = 0
763+
numPackedM = 0
761764
#####
762765
# Prepare localReadCode
763766
####
@@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
832835

833836
instPerPackM = 0
834837
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]:
835-
instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1
836-
837-
packItems = []
838+
instPerPackM = 1
839+
if self.states.lrvwTileMetadata > 1:
840+
if kernel["MIInputPerThreadMetadata"] == 1:
841+
instPerPackM = 1.5
842+
elif kernel["MIInputPerThreadMetadata"] == 4:
843+
instPerPackM = 3
844+
packItemsA = []
845+
packItemsB = []
846+
packItemsM = []
838847
for iui in range(kernel["InnerUnroll"]):
839-
packINtems = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
848+
packINtemsA = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
849+
packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
850+
packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
840851
packA = packCode.findNamedItem("packA_I%s"%(iui))
841852
packB = packCode.findNamedItem("packB_I%s"%(iui))
842853
packM = packCode.findNamedItem("packMetadata_I%s"%(iui))
@@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
856867
if packAItems:
857868
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
858869
for n in range(instPerPackA):
859-
packINtems[0].append(packAItems.pop(0))
870+
packINtemsA[0].append(packAItems.pop(0))
860871
else:
861872
for j in range(self.states.numReadsIterCoalescedA):
862873
for n in range(instPerPackA):
863-
packINtems[j].append(packAItems.pop(0))
874+
packINtemsA[j].append(packAItems.pop(0))
864875

865876
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
866877
for j in range(self.states.numReadsIterCoalescedMetadata):
867878
for n in range(ceil(instPerPackM)):
868879
if packMItems:
869-
packINtems[j].append(packMItems.pop(0))
880+
packINtemsM[j].append(packMItems.pop(0))
870881
else:
871882
break
872883

873884
if packBItems:
874885
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
875886
for n in range(instPerPackB):
876-
packINtems[0].append(packBItems.pop(0))
887+
packINtemsB[0].append(packBItems.pop(0))
877888
else:
878889
for j in range(self.states.numReadsIterCoalescedB):
879890
for n in range(instPerPackB):
880-
packINtems[j].append(packBItems.pop(0))
891+
packINtemsB[j].append(packBItems.pop(0))
881892

882893
while packAItems:
883894
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
884895
for n in range(instPerPackA):
885-
packINtems[0].append(packAItems.pop(0))
896+
if packAItems:
897+
packINtemsA[0].append(packAItems.pop(0))
898+
else:
899+
break
886900
else:
887901
for j in range(self.states.numReadsIterCoalescedA):
888902
for n in range(instPerPackA):
889-
packINtems[j].append(packAItems.pop(0))
903+
if packAItems:
904+
packINtemsA[j].append(packAItems.pop(0))
905+
else:
906+
break
890907

891-
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
892-
while packMItems:
908+
while packMItems:
909+
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
893910
for j in range(self.states.numReadsIterCoalescedMetadata):
894911
for n in range(ceil(instPerPackM)):
895912
if packMItems:
896-
packINtems[j].append(packMItems.pop(0))
913+
packINtemsM[j].append(packMItems.pop(0))
897914
else:
898915
break
899916

900917
while packBItems:
901918
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
902919
for n in range(instPerPackB):
903-
packINtems[0].append(packBItems.pop(0))
920+
if packBItems:
921+
packINtemsB[0].append(packBItems.pop(0))
922+
else:
923+
break
904924
else:
905925
for j in range(self.states.numReadsIterCoalescedB):
906926
for n in range(instPerPackB):
907-
packINtems[j].append(packBItems.pop(0))
927+
if packBItems:
928+
packINtemsB[j].append(packBItems.pop(0))
929+
else:
930+
break
908931

909-
for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)):
910-
packItems += packINtems.pop(0)
932+
for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)):
933+
packItemsA += packINtemsA.pop(0)
934+
packItemsB += packINtemsB.pop(0)
935+
packItemsM += packINtemsM.pop(0)
911936

912937
# remove s_nop for packing
913938
# we will add s_nop if needed
@@ -1053,7 +1078,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
10531078
numLocalReadShouldSchedule = 0
10541079
# prefetch load for next wave tile along M since we re-use B first.
10551080
tileM: int = kernel["MIWaveTileA"]
1056-
instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems
1081+
instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB
10571082
localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True)
10581083

10591084
for lr in localReadItemsThisLoop:
@@ -1241,7 +1266,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
12411266
mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))]
12421267
## To support do["MAC"] is False
12431268
mfma = [mfmas[i],] if len(mfmas) > 0 else []
1244-
instsToCheck = mfma + packItems
1269+
instsToCheck = mfma + packItemsA + packItemsM + packItemsB
12451270
numDsInsts = 0
12461271
lastLgkmCnt = -1
12471272
for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())):
@@ -1271,18 +1296,25 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
12711296
####
12721297
# scheduled pack
12731298
####
1274-
if packItems:
1299+
_instPerPackA = 0
1300+
_instPerPackB = 0
1301+
_instPerPackM = 0
1302+
if packItemsA or packItemsB or packItemsM:
12751303
# how many pack have to be done
12761304
# calculate the data index of this mfma used for A and B
12771305
# if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account)
12781306
# if i % kernel["MIWaveTile"][0]==0, mfma will use new B
1279-
packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
1280-
packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
1307+
_instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
1308+
packAIdx += _instPerPackA
1309+
_instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
1310+
packBIdx += _instPerPackB
12811311
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
12821312
if kernel["ProblemType"]["Sparse"] == 2:
1283-
packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
1313+
_instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
12841314
else:
1285-
packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
1315+
_instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
1316+
1317+
packMIdx += _instPerPackM
12861318
# blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8)
12871319
if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]:
12881320
packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0
@@ -1298,33 +1330,70 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
12981330
iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx))
12991331
else:
13001332
iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx))
1301-
# we put 2 pack in each mfma
1302-
for j in range(instPerPackA):
1303-
if packItems:
1304-
iterCode.add(packItems.pop(0))
1333+
1334+
# put the required pack into mfma
1335+
for j in range(_instPerPackA):
1336+
if packItemsA:
1337+
# Skip if the required pack has already been placed in the previous mfma iter.
1338+
if numPackedA >= packAIdx:
1339+
break
1340+
iterCode.add(packItemsA.pop(0))
13051341
curPackIdx += 1
1342+
numPackedA += 1
1343+
13061344
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
1307-
for j in range(ceil(instPerPackM)):
1308-
if packItems:
1309-
iterCode.add(packItems.pop(0))
1345+
for j in range(ceil(_instPerPackM)):
1346+
if packItemsM:
1347+
# Skip if the required pack has already been placed in the previous mfma iter.
1348+
if numPackedM >= packMIdx:
1349+
break
1350+
iterCode.add(packItemsM.pop(0))
13101351
curPackIdx += 1
1311-
for j in range(instPerPackB):
1312-
if packItems:
1313-
iterCode.add(packItems.pop(0))
1314-
curPackIdx += 1
1315-
# since packed register need to wait 2 quad cycle to finish packing
1316-
# we insert pack instruction if we can, or s_nop
1317-
while curPackIdx < numPack+2:
1318-
if packItems:
1319-
iterCode.add(packItems.pop(0))
1320-
curPackIdx += 1
1321-
else:
1322-
iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction"))
1352+
numPackedM += 1
1353+
1354+
for j in range(_instPerPackB):
1355+
if packItemsB:
1356+
# Skip if the required pack has already been placed in the previous mfma iter.
1357+
if numPackedB >= packBIdx:
1358+
break
1359+
iterCode.add(packItemsB.pop(0))
13231360
curPackIdx += 1
1324-
break
1361+
numPackedB += 1
1362+
1363+
# put unnecessary pack into mfma to fulfill the latency
1364+
remainLatency = 2
1365+
if curPackIdx < numPack + 2 :
1366+
# since packed register need to wait 2 quad cycle to finish packing
1367+
# we insert pack instruction if we can, or s_nop
1368+
while remainLatency:
1369+
if packItemsA:
1370+
iterCode.add(packItemsA.pop(0))
1371+
curPackIdx += 1
1372+
numPackedA += 1
1373+
remainLatency -= 1
1374+
elif packItemsM:
1375+
iterCode.add(packItemsM.pop(0))
1376+
curPackIdx += 1
1377+
numPackedM += 1
1378+
remainLatency -= 1
1379+
elif packItemsB:
1380+
iterCode.add(packItemsB.pop(0))
1381+
curPackIdx += 1
1382+
numPackedB += 1
1383+
remainLatency -= 1
1384+
else:
1385+
latency = remainLatency - 1
1386+
iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction"))
1387+
curPackIdx += 1
1388+
remainLatency -= (latency+1)
1389+
13251390
if i == numMfmaPerIter - 1:
1326-
while packItems:
1327-
iterCode.add(packItems.pop(0))
1391+
while packItemsA:
1392+
iterCode.add(packItemsA.pop(0))
1393+
while packItemsM:
1394+
iterCode.add(packItemsM.pop(0))
1395+
while packItemsB:
1396+
iterCode.add(packItemsB.pop(0))
13281397

13291398
####
13301399
# scheduled mfma dependency

0 commit comments

Comments
 (0)