diff --git a/clients/gtest/matmul_gtest.yaml b/clients/gtest/matmul_gtest.yaml index db6d04697c..2c41e070ac 100755 --- a/clients/gtest/matmul_gtest.yaml +++ b/clients/gtest/matmul_gtest.yaml @@ -1593,7 +1593,7 @@ Tests: algo_method: [0,1] transA_transB: *transA_transB_range alpha: 1 - beta: 0 + beta: [0,1] requested_solution_num: -1 unit_check: 1 @@ -1607,7 +1607,7 @@ Tests: algo_method: [0,1] transA_transB: *transA_transB_range alpha: 1 - beta: 0 + beta: [0,1] requested_solution_num: -1 unit_check: 1 gpu_arch: '94[0-2]' diff --git a/tensilelite/Tensile/Components/GlobalWriteBatch.py b/tensilelite/Tensile/Components/GlobalWriteBatch.py index 55be6fc528..748a827393 100644 --- a/tensilelite/Tensile/Components/GlobalWriteBatch.py +++ b/tensilelite/Tensile/Components/GlobalWriteBatch.py @@ -2011,7 +2011,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV # Generate single f32 code if edge is detected. isPK = False if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1): - if self.parentWriter.states.archCaps["NoSDWA"]: #cm review + if self.parentWriter.states.archCaps["NoSDWA"]: sb = 0 if self.gwvw == 1 else 1 module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb]))) else: @@ -2022,11 +2022,13 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV continue else: isPK = True - if self.parentWriter.states.archCaps["NoSDWA"]: #cm review - sb = 0 if vi ==0 else 1 + if self.parentWriter.states.archCaps["NoSDWA"]: + # Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8 + sb = 0 if vi%4 == 0 else 1 module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb]))) else: - sb = SelectBit.WORD_0 if vi == 0 else SelectBit.WORD_1 + # Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8 + sb = SelectBit.WORD_0 if vi%4 == 0 else SelectBit.WORD_1 module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb))) module.add(SNop(waitState=0)) if kernel["ProblemType"]["ComputeDataType"].isSingle(): @@ -2040,7 +2042,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV # Generate single f32 code if edge is detected. isPK = False if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1): - if self.parentWriter.states.archCaps["NoSDWA"]: #cm review + if self.parentWriter.states.archCaps["NoSDWA"]: sb = 0 if self.gwvw == 1 else 1 module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb]))) else: @@ -2051,11 +2053,13 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV continue else: isPK = True - if self.parentWriter.states.archCaps["NoSDWA"]: #cm review - sb = 0 if vi ==0 else 1 + if self.parentWriter.states.archCaps["NoSDWA"]: + # Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8 + sb = 0 if vi%4 == 0 else 1 module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb]))) else: - sb = SelectBit.WORD_0 if vi == 0 else SelectBit.WORD_1 + # Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8 + sb = SelectBit.WORD_0 if vi%4 == 0 else SelectBit.WORD_1 module.add(VCvtPkBF8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb))) module.add(SNop(waitState=0)) if kernel["ProblemType"]["ComputeDataType"].isSingle():