prioritize using smem only when cached inputs are persistent #5764
base: main
|
!test |
|
Review updated until commit a0797d3 Description
|
| Relevant files | |||
|---|---|---|---|
| Enhancement |
| ||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| 🔒 No security concerns identified |
| ⚡ Recommended focus areas for review |
Logic correctness
cached_inputs_are_persistent uses std::all_of to check if all buffers are fusion inputs. This might be too restrictive - there could be cases where some buffers are fusion inputs and others are intermediate tensors that should still be treated as persistent. Consider if this condition accurately captures all scenarios where shared memory should be preferred. |
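To make the reviewer's concern concrete, here is a minimal sketch of an all-or-nothing check built on `std::all_of`. The `BufferInfo` struct and the `allAreFusionInputs` function are hypothetical stand-ins for illustration, not nvFuser types or the PR's actual code.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical stand-in for a persistent buffer; not an nvFuser type.
struct BufferInfo {
  bool is_fusion_input;
};

// True only if every buffer is a fusion input.
// Note: std::all_of returns true for an empty range.
bool allAreFusionInputs(const std::vector<BufferInfo>& buffers) {
  return std::all_of(
      buffers.begin(), buffers.end(), [](const BufferInfo& b) {
        return b.is_fusion_input;
      });
}
```

Under this sketch, a single intermediate tensor among the buffers makes the whole check false, which is the mixed case the reviewer asks about.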
Test failures
- (High, 5) CUDA misaligned address in multidevice Transformer backward test (test_transformer_backward[SEQUENCE_PARALLEL])

| Test Name | A100 | GB200 | GB200 (dist.) | H100 | H100 (dist.) | Source |
|---|---|---|---|---|---|---|
| tests.python.multidevice.test_transformer.test_transformer_backward[SEQUENCE_PARALLEL] | ❌ | ❌ | ❌ | ❌ | ❌ | |

- (High, 4) CUBLAS_STATUS_EXECUTION_FAILED in nvFuser transformer_backward (test_transformer_backward[TENSOR_PARALLEL]) across multiple runners

| Test Name | GB200 | GB200 (dist.) | H100 | H100 (dist.) | Source |
|---|---|---|---|---|---|
| tests.python.multidevice.test_transformer.test_transformer_backward[TENSOR_PARALLEL] | ❌ | ❌ | ❌ | ❌ | |

- (High, 1) CUDA misaligned address in Flash Attention backward during multidevice transformer (A100, test_transformer_backward)

| Test Name | A100 | Source |
|---|---|---|
| tests.python.multidevice.test_transformer.test_transformer_backward[TENSOR_PARALLEL] | ❌ | |
1 file reviewed, 1 comment
1631b53 to 8cfa421
|
@greptileai review this PR |
Greptile Summary
Optimizes persistent buffer allocation by preferring registers over shared memory when cached inputs aren't persistent buffers, avoiding the inefficient regs → smem → regs path.
Key Changes:
Critical Issue Found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Scheduler as Inner Outer Scheduler
participant Utils as getPersistentBufferStorageParams
participant Buffers as Persistent Buffers
participant Memory as Memory Allocation
Scheduler->>Utils: Request buffer storage params
Utils->>Utils: Determine if cached_inputs_are_persistent
alt cached_inputs_are_persistent=true (buffers are fusion inputs or project_to_input)
Note over Utils: OLD & NEW: Use smem path
Utils->>Buffers: Allocate to smem_buffer_size_bit
Buffers-->>Memory: gmem → (CpAsync/TMA) → smem → regs
else cached_inputs_are_persistent=false (NEW BEHAVIOR)
Note over Utils: NEW: Use registers path
Utils->>Buffers: Allocate to regs_buffer_size_bit
Buffers-->>Memory: gmem → regs (cached) → regs (computation)
Note over Memory: Avoids redundant regs→smem→regs path
end
alt Resources fit
Utils->>Scheduler: Return with has_enough_regs_and_smem=true
else Resources don't fit
Note over Utils: ISSUE: Fallback assumes buffers in smem
Utils->>Utils: Move buffers from smem to regs
Note over Utils: BUG: When cached_inputs_are_persistent=false,<br/>smem_buffer_size_bit=0, subtraction produces negatives
Utils->>Scheduler: Return allocation result
end
|
|
!test |
|
!test |
Additional Comments (1)
-
csrc/scheduler/normalization_inner_outer_utils.cpp, lines 269-284
logic: Fallback logic assumes all buffers start in shared memory (line 275 comment), but when cached_inputs_are_persistent=false (line 246), buffers are placed in registers and smem_buffer_size_bit=0. Line 284 subtracts from smem_buffer_size_bit, which will produce negative values. Add an early return or skip this fallback path when !cached_inputs_are_persistent.
Is there a hybrid allocation strategy needed when cached_inputs_are_persistent=false but resources are insufficient?
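A guard of the kind the comment above asks for might look like the following sketch. It is a simplified model, not the code in normalization_inner_outer_utils.cpp: `StorageSketch`, `moveToRegs`, and `guardedFallback` are hypothetical names, and only `smem_buffer_size_bit`, `regs_buffer_size_bit`, and `cached_inputs_are_persistent` come from the review text.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical bookkeeping struct; field names follow the review comment.
struct StorageSketch {
  std::int64_t smem_buffer_size_bit;
  std::int64_t regs_buffer_size_bit;
};

// Unconditional move of one buffer from smem to regs. With nothing in
// smem, the subtraction drives smem_buffer_size_bit negative.
StorageSketch moveToRegs(StorageSketch s, std::int64_t buffer_size_bit) {
  s.smem_buffer_size_bit -= buffer_size_bit;
  s.regs_buffer_size_bit += buffer_size_bit;
  return s;
}

// Guarded fallback (assumption): skip the relocation entirely when the
// buffers were placed in registers from the start.
StorageSketch guardedFallback(
    StorageSketch s,
    const std::vector<std::int64_t>& buffer_sizes_bit,
    bool cached_inputs_are_persistent) {
  if (!cached_inputs_are_persistent) {
    return s; // nothing lives in smem, so there is nothing to move
  }
  for (std::int64_t size : buffer_sizes_bit) {
    s = moveToRegs(s, size);
  }
  return s;
}
```

In this model, calling `moveToRegs` on a state with `smem_buffer_size_bit == 0` reproduces the negative value the comment warns about, while `guardedFallback` leaves that state untouched. Whether an early return or a hybrid strategy is the right fix in the real scheduler is left open, as in the comment.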
2 files reviewed, 1 comment
Behavior before this PR:
The inner-outer persistent scheduler always prioritizes shared memory for storing persistent buffers.
Issue:
When cached inputs are not persistent buffers, the data flow is
gmem (input) --> regs (cached input) --> smem (persistent buffer) --> regs (for computations).
The regs --> smem --> regs path is redundant.
Using smem for persistent buffers only makes sense when the cached inputs are themselves persistent buffers, in which case the data flow is:
gmem -- (TMA or CpAsync) --> smem (persistent buffer) --> regs (for computations)
After this PR:
Shared memory is prioritized for persistent buffers only when the cached inputs are persistent buffers.
The newly added test improves from 50.5% SOL to 61.6% SOL on GB200.
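The placement rule described above can be summarized in a small sketch. The `Storage` enum and the `preferredStorage` and `dataFlow` functions are hypothetical and only model the data-flow paths listed in the description; the persistent-plus-registers branch is an assumption, since the description does not spell that case out.

```cpp
#include <string>
#include <vector>

enum class Storage { Registers, SharedMemory };

// Initial placement after this PR, as described in the text above.
Storage preferredStorage(bool cached_inputs_are_persistent) {
  return cached_inputs_are_persistent ? Storage::SharedMemory
                                      : Storage::Registers;
}

// Memory spaces a value passes through before it is used in computation.
std::vector<std::string> dataFlow(
    bool cached_inputs_are_persistent, Storage storage) {
  if (cached_inputs_are_persistent) {
    if (storage == Storage::SharedMemory) {
      return {"gmem", "smem", "regs"}; // loaded via TMA or CpAsync
    }
    return {"gmem", "regs"}; // assumption: not stated in the description
  }
  if (storage == Storage::SharedMemory) {
    return {"gmem", "regs", "smem", "regs"}; // the redundant round trip
  }
  return {"gmem", "regs", "regs"}; // cached input, then computation
}
```

For non-persistent cached inputs, `dataFlow(false, Storage::SharedMemory)` has four hops, while `dataFlow(false, preferredStorage(false))` has three, which is the round trip the PR removes.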