-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix assertion in ScanLowering for num_ctas>1 #5680
base: main
Are you sure you want to change the base?
Conversation
|
Added the test cases, they pass if the scan is within a CTA and fail if it is across multiple CTAs. I didn't find any logic which performs accumulation across CTAs in the code. I could try to look into the logic and test coverage for cross-CTA scans in a different PR, but I think fixing the assertion for within-CTA scans can be independent from that. How should I proceed? |
For now it should happen only within a CTA. Please do not work on anything across CTAs |
This reverts commit 33ca052.
This reverts commit 44ddc3e.
Assume the clusterCTAId along the scan axis (cctaIdAxis) is ==0, raise runtime assertion otherwise. Combine the clusterCTAId across the scan axis (cctaIdParallel) into the flatIdParallel and compute numParallelLane per CGA instead of per CTA.
Test BlockedLayout for - thread_size=4, num_warps=4, num_ctas=1 - thread_size=4, num_warps=1, num_ctas=4: CTASplitNum=[1,1] - thread_size=4, num_warps=1, num_ctas=4: CTASplitNum=CTAsPerCGA - thread_size=1, num_warps=4, num_ctas=4: CTASplitNum=[1,1] - thread_size=1, num_warps=4, num_ctas=4: CTASplitNum=CTAsPerCGA
…fix-scanlowering-cga
The initial fix was only superficial, so I implemented within-CTA scan from scratch. I changed testing from end-to-end tests to layout tests to make sure all edge cases are caught. Pytest output for I'm not sure if the tests are too extensive now, maybe we could just test the combinations with Is there anything else I can do to get this merged? |
Enable within-CTA scans if
num_ctas>1
andcluster_dims[axis]==1
on Hopper or later.Assume the
clusterCTAId
along the scan axis iscctaIdAxis==0
, raise runtime assertion otherwise. Combine theclusterCTAId
across the scan axis (cctaIdParallel
) into theflatIdParallel
and computenumParallelLane
per CGA instead of per CTA.This fixes assertions when
num_ctas > 1
:triton/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Line 150 in cea35da
triton/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Line 254 in cea35da
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these
rules.
I have run
pre-commit run --from-ref origin/main --to-ref HEAD
.Select one of the following.
/test
forlit
tests/unittest
for C++ tests/python/test
for end-to-end testsSelect one of the following.
lit
tests.lit
tests I have added follow these best practices,including the "tests should be minimal" section. (Usually running Python code
and using the instructions it generates is not minimal.)