Optimize OSFT factorized linear and gradient projection kernels #75
Conversation
Two key optimizations that yield ~15% end-to-end training throughput
improvement on Llama-3.1-8B-Instruct (4x H100, bf16, rank_ratio=0.5):
1. Factorized linear (_factorized_linear):
- Flatten input to 2D and use torch.mm instead of batched @ operator
- Replace separate `result_high + result_low` addition with
`result.addmm_(tmp_low, U_low.T)` which fuses the low-rank
matmul and addition into a single cuBLAS call
- Eliminates one kernel launch per OSFT target per forward pass
2. Gradient projection (project_gradient_to_orthogonal_space):
- Replace Gram matrix form `G = V_high^T @ V_high; dV -= dV @ G`
with factored form `dV -= (dV @ V_high^T) @ V_high`
- Avoids materializing the (K, K) Gram matrix (e.g. 4096x4096 for
Llama), replacing it with a small (rank_low, rank_high) intermediate
- Fuse subtraction into matmul via `addmm_(alpha=-1.0)`
- The all-reduce now operates on the smaller (rank_low, rank_high)
tensor instead of (K, K), reducing NCCL communication volume
Also includes transformers v5 compatibility fixes:
- Pass both `torch_dtype` and `dtype` kwargs to from_pretrained
- Handle renamed config attribute (torch_dtype -> dtype)
- Fix dtype validation for FSDP2 mixed precision (params stored
in fp32, cast to bf16 for compute)
- Fix optimizer state validation (always fp32 for stability)
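A hedged sketch of what such dual-path dtype handling can look like; `dtype_kwargs` and `config_dtype` are hypothetical helpers, not the PR's actual code (the PR passes both kwargs rather than branching on version):

```python
def dtype_kwargs(dtype, transformers_version):
    # transformers v5 renamed the `torch_dtype` kwarg of from_pretrained to
    # `dtype`; pick the keyword matching the installed major version.
    major = int(transformers_version.split(".")[0])
    return {"dtype": dtype} if major >= 5 else {"torch_dtype": dtype}

def config_dtype(config):
    # The config attribute was renamed the same way: prefer `dtype`,
    # fall back to the legacy `torch_dtype`.
    return getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
```

Either helper keeps the call site version-agnostic, e.g. `AutoModel.from_pretrained(path, **dtype_kwargs(torch.bfloat16, transformers.__version__))`.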
Benchmark (Llama-3.1-8B-Instruct, 4x H100 80GB, bf16, OSFT r=0.5):
Baseline: 12,232 tok/s mean, 12,766 tok/s median
Optimized: 14,080 tok/s mean, 14,385 tok/s median
Speedup: +15.1% mean, +12.7% median
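As a quick arithmetic check, the reported percentages follow directly from the throughput numbers:

```python
# Sanity-check the speedups quoted in the benchmark above.
baseline_mean, optimized_mean = 12_232, 14_080
baseline_median, optimized_median = 12_766, 14_385

mean_gain = (optimized_mean / baseline_mean - 1) * 100
median_gain = (optimized_median / baseline_median - 1) * 100

print(f"+{mean_gain:.1f}% mean, +{median_gain:.1f}% median")
# prints: +15.1% mean, +12.7% median
```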
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
📝 Walkthrough

Replaces Gram-matrix V projection with a factored, batched all-reduce projection; consolidates per-module V projections into one flattened all-reduce and local apply; flattens/fuses the factorized linear forward into 2D addmm_; adds transformers v5-compatible dtype handling in model load/setup; relaxes dtype validation to allow fp32 optimizer state and gradients.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant ModuleA as Module (each rank)
    participant ModuleB as ... (other modules)
    participant AllReduce as AllReduce / ProcessGroup
    participant Local as Local apply
    ModuleA->>ModuleA: compute local dV (per-module)
    ModuleA->>ModuleA: compute coeff dV_Vt = dV @ V_high^T
    ModuleA->>AllReduce: contribute flattened coeffs (batched tensor)
    ModuleB->>AllReduce: contribute flattened coeffs
    AllReduce->>AllReduce: reduce (sum) flattened coeffs
    AllReduce->>ModuleA: reduced coeff slice for ModuleA
    ModuleA->>Local: local_dV.addmm_(reduced_coeff, local_V_high, alpha=-1.0)
    Local->>ModuleA: updated local gradients (projected)
```
🚥 Pre-merge checks: 3 passed
🧹 Nitpick comments (1)
src/mini_trainer/train.py (1)
56-64: Relaxed validation logic may mask unexpected dtype issues.

The new logic only raises if the dtype is neither `expected_param_dtype` nor `torch.float32`. This means:

- If `expected_param_dtype=torch.bfloat16`, both bf16 and fp32 pass silently
- If a param unexpectedly becomes fp16 when expecting bf16, it still raises (correct)

However, the nested condition structure is a bit confusing. Consider a clearer formulation:

♻️ Suggested clarification

```diff
 if param.requires_grad and param.dtype != expected_param_dtype:
-    if param.dtype != torch.float32:
-        raise ValueError(f"Parameter {name} is not in {expected_param_dtype}, got {param.dtype}")
+    # FSDP2 MixedPrecisionPolicy may store params in fp32; allow this as valid
+    allowed_dtypes = {expected_param_dtype, torch.float32}
+    if param.dtype not in allowed_dtypes:
+        raise ValueError(f"Parameter {name} has unexpected dtype {param.dtype}, expected one of {allowed_dtypes}")
 if param.grad is not None and param.grad.dtype != expected_param_dtype:
-    if param.grad.dtype != torch.float32:
-        raise ValueError(f"Gradient {name} is not in {expected_param_dtype}, got {param.grad.dtype}")
+    allowed_dtypes = {expected_param_dtype, torch.float32}
+    if param.grad.dtype not in allowed_dtypes:
+        raise ValueError(f"Gradient {name} has unexpected dtype {param.grad.dtype}, expected one of {allowed_dtypes}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/train.py` around lines 56 - 64, The current nested checks around param.requires_grad, param.dtype, expected_param_dtype and torch.float32 are confusing and can silently allow unintended dtypes; update the validation in train.py so each param (and param.grad) is explicitly allowed only if its dtype equals expected_param_dtype OR equals torch.float32 (to accommodate FSDP storage), otherwise raise a ValueError referencing the parameter name; specifically replace the two nested if-blocks that check param.dtype and param.grad.dtype with clear single-condition checks using (param.dtype != expected_param_dtype and param.dtype != torch.float32) and similarly for (param.grad.dtype != expected_param_dtype and param.grad.dtype != torch.float32) for the same named parameter checks.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 924be8a6-aaf0-4584-89c4-d5352a75724c
📒 Files selected for processing (3)
- src/mini_trainer/osft_utils.py
- src/mini_trainer/setup_model_for_training.py
- src/mini_trainer/train.py
Same pattern as the existing U projection batching: collect all (dV @ V_high^T) coefficients across 224 OSFT targets, concatenate into a single flat tensor, perform one all-reduce, then split back and apply corrections. This reduces V projection from 224 all-reduce launches to 1, cutting NCCL collective launch overhead.

Benchmark improvement: +2.6% on top of factored V projection. Total vs baseline: +17.9% mean throughput.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
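The flatten/concatenate/reduce/split pattern can be sketched as follows; `batch_coefficients` and `unbatch_coefficients` are hypothetical names, and a plain sum stands in for the single `dist.all_reduce` so the sketch runs in one process:

```python
import math

import torch

def batch_coefficients(coeffs):
    # Flatten each per-module coefficient tensor and concatenate, so one
    # collective replaces one all_reduce launch per OSFT target.
    flat = torch.cat([c.reshape(-1) for c in coeffs])
    shapes = [c.shape for c in coeffs]
    return flat, shapes

def unbatch_coefficients(flat, shapes):
    # Split the reduced buffer back into per-module tensors.
    out, offset = [], 0
    for shape in shapes:
        numel = math.prod(shape)
        out.append(flat[offset : offset + numel].reshape(shape))
        offset += numel
    return out
```

In the distributed path, `dist.all_reduce(flat)` would replace the local sum, after which each rank splits the buffer and applies its corrections via `addmm_(..., alpha=-1.0)`.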
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/mini_trainer/osft_utils.py (1)
539-544: ⚠️ Potential issue | 🟡 Minor

Docstring contradicts the new factored implementation.

The docstring at lines 540-544 states that the V projection "must use the Gram-matrix form" because the factored form "produce[s] column-blocks rather than partial sums — requiring an all-gather instead of an all-reduce." However, the implementation below (lines 591-598) now uses the factored form with `all_reduce`.

The factored form is correct here: when FSDP2 shards V_high on dim-0, `local_dV @ local_V_high.T` produces partial sums that all-reduce correctly aggregates. Please update the docstring to reflect the new approach.

📝 Suggested docstring update

```diff
-    V projection must use the Gram-matrix form
-    dV -= dV @ (V_high^T @ V_high) because FSDP2 shards V_high on dim-0
-    (the singular-vector dimension), making the factored form
-    dV -= (dV @ V_high^T) @ V_high produce column-blocks rather than
-    partial sums — requiring an all-gather instead of an all-reduce.
+    V projection uses the factored form dV -= (dV @ V_high^T) @ V_high
+    with a small (rank_low, rank_high) intermediate. When FSDP2 shards
+    V_high on dim-0, dV @ V_high^T produces partial sums that are
+    correctly aggregated via all-reduce, then multiplied by local V_high
+    rows to complete the projection.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/osft_utils.py` around lines 539 - 544, Update the docstring for the V projection to match the implemented factored form: replace the claim that the Gram-matrix form is required and that the factored form would produce column-blocks requiring an all-gather; instead state that because FSDP2 shards V_high on dim-0, the local computation (local_dV @ local_V_high.T) produces partial sums which are correctly aggregated via all_reduce, so the code uses the factored form with all_reduce (referencing variables dV, V_high, local_dV, local_V_high and the V projection implementation).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 2148-2172: The non-distributed else branch under the "if
v_flat_parts:" check is dead code because non-distributed paths return earlier;
remove the unreachable block (lines that iterate over v_work using
v_flat_parts.pop and dV.copy_) and either replace it with an assertion like
"assert is_distributed" or rely on the surrounding logic to guarantee
is_distributed is True; update any tests or callers if they depended on that
branch. Ensure you only modify the branch inside the v_flat_parts handling
(references: v_flat_parts, is_distributed, v_work, local_V_high, local_dV, dV,
coeff_shape) to eliminate the redundant code path.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 175023e5-92a8-4b20-9a19-4258601a5b86
📒 Files selected for processing (1)
src/mini_trainer/osft_utils.py
@claude fix these CI issues
Claude finished @RobotSail's task in 5m 53s: Fixing CI Issues

Note on remaining test failures: The 4 …
- Fix ruff format: add spaces around slice in `batched_v[offset : offset + numel]`
- Update stale docstring: V projection now correctly documents the factored form with all-reduce instead of the old Gram-matrix claim
- Remove dead code: non-distributed else branch in v_flat_parts handling is unreachable since non-distributed paths return early at line 2074
- Clarify dtype validation: use allowed_dtypes set instead of nested conditionals for FSDP2 mixed precision param/grad checking

Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@claude-code it looks like there are still some merge conflicts with the model setup script. Maybe you need to rebase this
Claude finished @RobotSail's task in 2m 52s: Resolving Merge Conflicts

Conflict Resolution Details: All three conflicts in setup_model_for_training.py were resolved by merging main into the branch.
Resolves merge conflicts in setup_model_for_training.py by combining:

- trust_remote_code support from main (#78, #81)
- transformers v5 dtype compatibility from this branch

All three conflicts resolved by keeping both sets of changes:

1. get_model_save_dtype: trust_remote_code param + dtype fallback
2. base_model_args: dtype key + trust_remote_code key
3. save_dtype assignment: trust_remote_code passthrough + v5 dtype handling

Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com>
Codecov Report: ❌ Patch coverage is …
Summary
- Factorized linear: fuse the low-rank matmul and addition into a single `addmm_` call, flatten input to 2D for `torch.mm`
- Gradient projection: replace the `(K, K)` Gram matrix with a factored `(rank_low, rank_high)` intermediate, fuse subtraction via `addmm_`
- Transformers v5: handle the `torch_dtype` → `dtype` rename, fix mixed-precision dtype validation

Benchmark
Llama-3.1-8B-Instruct, 4x H100 80GB, bf16, OSFT `rank_ratio=0.5`, `batch_size=32`.

Dataset: 1,000 samples from UltraChat-200k, tokenized with the Llama-3.1 chat template, median 1,118 tokens.
Details

1. Factorized linear (`_factorized_linear`) — +11%
   - Before: 4 separate matmuls + element-wise addition
   - After: flatten to 2D, 3 `torch.mm` + 1 `addmm_`
2. Gradient projection (`project_gradient_to_orthogonal_space`) — +4%
   - Before: materializes `(K, K)` Gram matrix (e.g. 4096×4096 for Llama)
   - After: factored form with `(rank_low, rank_high)` intermediate
3. Batched V projection all-reduces — +2.6%
   - Before: 224 separate `dist.all_reduce()` calls for V projection coefficients
   - After: single batched `dist.all_reduce()` of concatenated coefficients (same pattern as the existing U projection batching from PR #72)
4. Transformers v5 compatibility
   - Pass both `torch_dtype` and `dtype` kwargs to `from_pretrained`
   - Handle renamed config attribute (`torch_dtype` → `dtype`)
torch_dtypeanddtypekwargs tofrom_pretrainedtorch_dtype→dtype)Test plan
🤖 Generated with Claude Code