-
Notifications
You must be signed in to change notification settings - Fork 19
Replace hardcoded CU/SM values with dynamic detection using PyTorch #198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace hardcoded CU/SM values with dynamic detection using PyTorch #198
Conversation
Co-authored-by: mawad-amd <[email protected]>
…ction Co-authored-by: mawad-amd <[email protected]>
… comm Co-authored-by: mawad-amd <[email protected]>
… to top Co-authored-by: mawad-amd <[email protected]>
Co-authored-by: mawad-amd <[email protected]>
Co-authored-by: mawad-amd <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR replaces hardcoded CU (Compute Unit) and SM (Streaming Multiprocessor) values with dynamic detection using PyTorch's CUDA API. Instead of using fixed values like 304 CUs or 288 SMs, the code now auto-detects the GPU's compute capabilities at runtime.
Key changes:
- Updated all benchmark scripts to use
torch.cuda.get_device_properties().multi_processor_count
for dynamic CU detection - Changed default argument values from hardcoded numbers to
None
with auto-detection fallbacks - Added proper math calculations for SM allocation (e.g., next smaller power of 2 for workgroup specialization)
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
scripts/link_bandwidth.py | Added dynamic CU detection with fallback to hardcoded MI300 default |
examples/benchmark/bench_all_shapes.py | Enhanced CU detection with PyTorch API and partition-based fallbacks |
examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py | Updated to use auto-detected SM values with math-based calculations |
examples/11_gemm_all_scatter_producer_consumer/benchmark.py | Added dynamic SM detection with leftover calculation for comm_sms |
examples/10_gemm_all_scatter_wg_specialization/benchmark.py | Updated to auto-detect total and GEMM SMs using power-of-2 logic |
examples/09_gemm_one_shot_all_reduce/benchmark.py | Replaced hardcoded SM values with dynamic detection |
examples/08_gemm_atomics_all_reduce/benchmark.py | Updated to use auto-detected SM values for all-reduce operations |
examples/07_gemm_all_scatter/benchmark.py | Added dynamic CU detection for persistent GEMM algorithm |
benchmark/examples/benchmark_all_gather_gemm_push.py | Updated to auto-detect SM count with proper args dict updating |
benchmark/examples/benchmark_all_gather_gemm_pull.py | Added dynamic SM detection with fallback to provided values |
Co-authored-by: mawad-amd <[email protected]>
Replaced hardcoded CU values with dynamic detection:
Summary
Eliminated code duplication by hoisting the power-of-2 calculation:
next_pow2
once and reused for bothgemm_sms
andcomm_sms
next_pow2
once and reused for bothgemm_sms
andcomm_sms
Before:
After:
All command-line defaults remain
None
(auto-detected), with manual override still available via CLI arguments.Original prompt
💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.