
Optimize OSFT factorized linear and gradient projection kernels#75

Open
RobotSail wants to merge 4 commits into main from optimize-osft-kernels

Conversation

Collaborator

@RobotSail RobotSail commented Mar 12, 2026

Summary

  • Factorized linear forward: fuse low-rank matmul + addition into single addmm_ call, flatten to 2D for torch.mm
  • Gradient projection V path: replace (K, K) Gram matrix with factored (rank_low, rank_high) intermediate, fuse subtraction via addmm_
  • Batched V projection all-reduces: collect all V projection coefficients across 224 OSFT targets into a single NCCL all-reduce (same pattern as existing U projection batching)
  • Transformers v5 compat: handle the torch_dtype → dtype config rename, fix mixed-precision dtype validation

Benchmark

Llama-3.1-8B-Instruct, 4x H100 80GB, bf16, OSFT rank_ratio=0.5, batch_size=32:

| Metric | Baseline | Optimized | Speedup |
| --- | --- | --- | --- |
| Mean tok/s | 12,232 | 14,417 | +17.9% |
| Median tok/s | 12,766 | 14,284 | +11.9% |
| Peak VRAM | 41.7 GB | 41.7 GB | same |
| Loss @ step 20 | 0.924 | 0.925 | equivalent |

Dataset: 1,000 samples from UltraChat-200k, tokenized with Llama-3.1 chat template, median 1,118 tokens.

Details

1. Factorized linear (_factorized_linear) — +11%

Before: 4 separate matmuls + element-wise addition

```python
result_high = ((x @ V_high.T) * S_high) @ U_high.T
result_low  = ((x @ V_low.T)  * S_low)  @ U_low.T
result = result_high + result_low  # separate addition kernel
```

After: flatten to 2D, 3 torch.mm + 1 addmm_

```python
x_2d = x.reshape(-1, K)
tmp_high = torch.mm(x_2d, V_high.T) * S_high  # scaled high-rank activations
tmp_low  = torch.mm(x_2d, V_low.T)  * S_low   # scaled low-rank activations
result = torch.mm(tmp_high, U_high.T)
result.addmm_(tmp_low, U_low.T)  # fused matmul + add
result = result.reshape(*orig_shape[:-1], N)
```
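The fusion is easy to sanity-check on CPU. A numpy sketch with toy shapes (names mirror the snippet above; the in-place `+=` stands in for torch's `addmm_`, which performs the matmul and accumulation in one cuBLAS call):

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, K, N, r_hi, r_lo = 2, 4, 16, 8, 6, 4  # toy sizes, not the real model dims

x = rng.standard_normal((B, T, K))
U_high, V_high, S_high = rng.standard_normal((N, r_hi)), rng.standard_normal((r_hi, K)), rng.standard_normal(r_hi)
U_low,  V_low,  S_low  = rng.standard_normal((N, r_lo)), rng.standard_normal((r_lo, K)), rng.standard_normal(r_lo)

# Before: two branch results plus a separate element-wise addition.
naive = ((x @ V_high.T) * S_high) @ U_high.T + ((x @ V_low.T) * S_low) @ U_low.T

# After: flatten to 2D so every op is a plain GEMM, then accumulate the
# low-rank branch into the high-rank result in place.
x_2d = x.reshape(-1, K)
tmp_high = (x_2d @ V_high.T) * S_high
tmp_low  = (x_2d @ V_low.T)  * S_low
result = tmp_high @ U_high.T
result += tmp_low @ U_low.T   # addmm_ fuses this matmul + add on GPU
fused = result.reshape(B, T, N)

assert np.allclose(naive, fused)
```

The two paths agree to floating-point tolerance; on GPU the win comes from one fewer kernel launch per forward and the 2D layout letting cuBLAS dispatch plain GEMM kernels.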

2. Gradient projection (project_gradient_to_orthogonal_space) — +4%

Before: materializes (K, K) Gram matrix (e.g. 4096×4096 for Llama)

```python
G = V_high.T @ V_high          # (K, K) — large
dV -= dV @ G
```

After: factored form with (rank_low, rank_high) intermediate

```python
dV_Vt = dV @ V_high.T          # (rank_low, rank_high) — small
dV.addmm_(dV_Vt, V_high, alpha=-1.0)
```
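Associativity makes the two forms mathematically identical, but the intermediate shrinks from (K, K) to (rank_low, rank_high): at K=4096 with rank_ratio=0.5 that is 32 MiB vs 8 MiB per target in bf16. A numpy sketch with toy sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
K, r_hi, r_lo = 64, 32, 32          # toy sizes; Llama-3.1-8B has K=4096
V_high = rng.standard_normal((r_hi, K))
dV = rng.standard_normal((r_lo, K))

# Before: materialize the (K, K) Gram matrix.
G = V_high.T @ V_high               # K*K elements
out_gram = dV - dV @ G

# After: small (r_lo, r_hi) intermediate; the subtraction is fused
# on GPU via dV.addmm_(dV_Vt, V_high, alpha=-1.0).
dV_Vt = dV @ V_high.T               # r_lo*r_hi elements
out_factored = dV - dV_Vt @ V_high

assert np.allclose(out_gram, out_factored)
```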

3. Batched V projection all-reduces — +2.6%

Before: 224 separate dist.all_reduce() calls for V projection coefficients
After: single batched dist.all_reduce() of concatenated coefficients (same pattern as existing U projection batching from PR #72)
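The batching itself is just flatten-concat-reduce-split bookkeeping. A numpy sketch that simulates the ranks in-process (the `sum` over rank buffers stands in for the single `dist.all_reduce`; the shapes are toy values, not the real coefficient shapes):

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 4
shapes = [(4, 6), (3, 5), (2, 2)]   # per-target coefficient shapes (toy)

# Each rank holds partial coefficients for every OSFT target.
per_rank = [[rng.standard_normal(s) for s in shapes] for _ in range(world_size)]

# Before: one all-reduce per target (len(shapes) collectives).
reduced_per_target = [sum(per_rank[r][i] for r in range(world_size))
                      for i in range(len(shapes))]

# After: each rank concatenates all coefficients into one flat buffer,
# a single all-reduce sums it, then it is split back per target.
flat = [np.concatenate([c.ravel() for c in coeffs]) for coeffs in per_rank]
flat_sum = sum(flat)                # stands in for one dist.all_reduce(SUM)
out, offset = [], 0
for s in shapes:
    numel = s[0] * s[1]
    out.append(flat_sum[offset : offset + numel].reshape(s))
    offset += numel

for a, b in zip(reduced_per_target, out):
    assert np.allclose(a, b)
```

Because all-reduce launch overhead is roughly fixed per collective, collapsing 224 small launches into one mostly pays off in latency, not bandwidth.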

4. Transformers v5 compatibility

  • Pass both torch_dtype and dtype kwargs to from_pretrained
  • Handle renamed config attribute (torch_dtype → dtype)
  • Fix dtype validation for FSDP2 mixed precision (params stored in fp32, cast to bf16 for compute)
  • Fix optimizer state validation (always fp32 for numerical stability)
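The config-attribute rename can be absorbed with a small fallback read. A hypothetical sketch (the helper name and fallback order are illustrative, not the PR's actual code):

```python
from types import SimpleNamespace

def get_config_dtype(config, default=None):
    """Read the model dtype from a transformers config, preferring the
    v5 attribute name (`dtype`) and falling back to the pre-v5 name
    (`torch_dtype`). Illustrative helper, not the PR's actual code."""
    for attr in ("dtype", "torch_dtype"):
        value = getattr(config, attr, None)
        if value is not None:
            return value
    return default

# v5-style config
assert get_config_dtype(SimpleNamespace(dtype="bfloat16")) == "bfloat16"
# pre-v5 config
assert get_config_dtype(SimpleNamespace(torch_dtype="bfloat16")) == "bfloat16"
```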

Test plan

  • OSFT training runs to completion with bf16 mixed precision on 4x H100
  • Loss convergence matches baseline (0.924 vs 0.925)
  • No VRAM increase
  • Batched V projection produces identical results to per-module path
  • Run existing regression tests
  • Verify with different rank ratios (0.25, 0.75)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Improvements
    • Enhanced compatibility with transformers v5 through improved dtype configuration handling.
    • Optimized gradient projection operations for better memory efficiency via batched computation.
    • Increased validation flexibility to support multiple dtype configurations during training.

Two key optimizations that yield ~15% end-to-end training throughput
improvement on Llama-3.1-8B-Instruct (4x H100, bf16, rank_ratio=0.5):

1. Factorized linear (_factorized_linear):
   - Flatten input to 2D and use torch.mm instead of batched @ operator
   - Replace separate `result_high + result_low` addition with
     `result.addmm_(tmp_low, U_low.T)` which fuses the low-rank
     matmul and addition into a single cuBLAS call
   - Eliminates one kernel launch per OSFT target per forward pass

2. Gradient projection (project_gradient_to_orthogonal_space):
   - Replace Gram matrix form `G = V_high^T @ V_high; dV -= dV @ G`
     with factored form `dV -= (dV @ V_high^T) @ V_high`
   - Avoids materializing the (K, K) Gram matrix (e.g. 4096x4096 for
     Llama), replacing it with a small (rank_low, rank_high) intermediate
   - Fuse subtraction into matmul via `addmm_(alpha=-1.0)`
   - The all-reduce now operates on the smaller (rank_low, rank_high)
     tensor instead of (K, K), reducing NCCL communication volume

Also includes transformers v5 compatibility fixes:
   - Pass both `torch_dtype` and `dtype` kwargs to from_pretrained
   - Handle renamed config attribute (torch_dtype -> dtype)
   - Fix dtype validation for FSDP2 mixed precision (params stored
     in fp32, cast to bf16 for compute)
   - Fix optimizer state validation (always fp32 for stability)

Benchmark (Llama-3.1-8B-Instruct, 4x H100 80GB, bf16, OSFT r=0.5):
  Baseline:  12,232 tok/s mean, 12,766 tok/s median
  Optimized: 14,080 tok/s mean, 14,385 tok/s median
  Speedup:   +15.1% mean, +12.7% median

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

coderabbitai bot commented Mar 12, 2026

Warning

Rate limit exceeded

@claude[bot] has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 49 minutes and 4 seconds before requesting another review.

Your organization is not enrolled in usage-based pricing. Contact your admin to enable usage-based pricing to continue reviews beyond the rate limit, or try again in 49 minutes and 4 seconds.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1fddea06-8817-41cc-b4c9-d7c216df3eed

📥 Commits

Reviewing files that changed from the base of the PR and between bb967cd and fb2c230.

📒 Files selected for processing (2)
  • src/mini_trainer/setup_model_for_training.py
  • src/mini_trainer/train.py
📝 Walkthrough

Walkthrough

Replaces Gram-matrix V projection with a factored, batched all-reduce projection; consolidates per-module V projections into one flattened all-reduce and local apply; flattens/fuses factorized linear forward to 2D addmm_; adds transformers v5-compatible dtype handling in model load/setup; relaxes dtype validation to allow fp32 optimizer/gradients.

Changes

Cohort / File(s) Summary
OSFT Core Changes
src/mini_trainer/osft_utils.py
Replaced Gram-matrix V projection with factored form (compute local dV_Vt = dV @ V_high^T, batched all-reduce, slice results, then local_dV.addmm_(dV_Vt, V_high, alpha=-1)). Batched distributed V projections consolidated; factorized_linear now flattens inputs to 2D, computes high/low-rank paths with mm/addmm, handles bias, and reshapes output. Added dtype/backwards-compatible loader tweaks and docstrings.
Model setup / dtype handling
src/mini_trainer/setup_model_for_training.py
Added get_model_save_dtype fallback to original_config.dtype (string mapping supported); added dtype to base_model_args; unify save-dtype assignment to model.config.dtype when present else model.config.torch_dtype; improved error message when save dtype missing.
Training dtype validation
src/mini_trainer/train.py
validate_training_state now accepts parameter/gradient dtypes equal to expected_param_dtype or torch.float32; take_gradient_step enforces optimizer-state dtype validation against torch.float32 unconditionally.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant ModuleA as Module (each rank)
    participant ModuleB as ... (other modules)
    participant AllReduce as AllReduce / ProcessGroup
    participant Local as Local apply

    ModuleA->>ModuleA: compute local dV (per-module)
    ModuleA->>ModuleA: compute coeff dV_Vt = dV @ V_high^T
    ModuleA->>AllReduce: contribute flattened coeffs (batched tensor)
    ModuleB->>AllReduce: contribute flattened coeffs
    AllReduce->>AllReduce: reduce (sum) flattened coeffs
    AllReduce->>ModuleA: reduced coeff slice for ModuleA
    ModuleA->>Local: local_dV.addmm_(reduced_coeff, local_V_high, alpha=-1.0)
    Local->>ModuleA: updated local gradients (projected)
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • NikhilNayak-debug

Poem

🐇 A rabbit taps the matmul key,

Flattens tensors, sets them free.
Coeffs all-gather, slices land true,
Gradients lean where they must do.
Hop—efficient steps in code anew. 🎉

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main focus of the PR, which implements optimizations to OSFT factorized linear forward and gradient projection kernels through refactored computation patterns and batched distributed operations. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 90.00% which is sufficient. The required threshold is 80.00%. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
src/mini_trainer/train.py (1)

56-64: Relaxed validation logic may mask unexpected dtype issues.

The new logic only raises if dtype is not expected_param_dtype AND not torch.float32. This means:

  • If expected_param_dtype=torch.bfloat16, both bf16 and fp32 pass silently
  • If a param unexpectedly becomes fp16 when expecting bf16, it still raises (correct)

However, the nested condition structure is a bit confusing. Consider a clearer formulation:

♻️ Suggested clarification
```diff
         if param.requires_grad and param.dtype != expected_param_dtype:
-            if param.dtype != torch.float32:
-                raise ValueError(f"Parameter {name} is not in {expected_param_dtype}, got {param.dtype}")
+            # FSDP2 MixedPrecisionPolicy may store params in fp32; allow this as valid
+            allowed_dtypes = {expected_param_dtype, torch.float32}
+            if param.dtype not in allowed_dtypes:
+                raise ValueError(f"Parameter {name} has unexpected dtype {param.dtype}, expected one of {allowed_dtypes}")
         if param.grad is not None and param.grad.dtype != expected_param_dtype:
-            if param.grad.dtype != torch.float32:
-                raise ValueError(f"Gradient {name} is not in {expected_param_dtype}, got {param.grad.dtype}")
+            allowed_dtypes = {expected_param_dtype, torch.float32}
+            if param.grad.dtype not in allowed_dtypes:
+                raise ValueError(f"Gradient {name} has unexpected dtype {param.grad.dtype}, expected one of {allowed_dtypes}")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/train.py` around lines 56 - 64, The current nested checks
around param.requires_grad, param.dtype, expected_param_dtype and torch.float32
are confusing and can silently allow unintended dtypes; update the validation in
train.py so each param (and param.grad) is explicitly allowed only if its dtype
equals expected_param_dtype OR equals torch.float32 (to accommodate FSDP
storage), otherwise raise a ValueError referencing the parameter name;
specifically replace the two nested if-blocks that check param.dtype and
param.grad.dtype with clear single-condition checks using (param.dtype !=
expected_param_dtype and param.dtype != torch.float32) and similarly for
(param.grad.dtype != expected_param_dtype and param.grad.dtype != torch.float32)
for the same named parameter checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@src/mini_trainer/train.py`:
- Around line 56-64: The current nested checks around param.requires_grad,
param.dtype, expected_param_dtype and torch.float32 are confusing and can
silently allow unintended dtypes; update the validation in train.py so each
param (and param.grad) is explicitly allowed only if its dtype equals
expected_param_dtype OR equals torch.float32 (to accommodate FSDP storage),
otherwise raise a ValueError referencing the parameter name; specifically
replace the two nested if-blocks that check param.dtype and param.grad.dtype
with clear single-condition checks using (param.dtype != expected_param_dtype
and param.dtype != torch.float32) and similarly for (param.grad.dtype !=
expected_param_dtype and param.grad.dtype != torch.float32) for the same named
parameter checks.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 924be8a6-aaf0-4584-89c4-d5352a75724c

📥 Commits

Reviewing files that changed from the base of the PR and between 3300833 and 1461614.

📒 Files selected for processing (3)
  • src/mini_trainer/osft_utils.py
  • src/mini_trainer/setup_model_for_training.py
  • src/mini_trainer/train.py

Same pattern as the existing U projection batching: collect all
(dV @ V_high^T) coefficients across 224 OSFT targets, concatenate
into a single flat tensor, perform one all-reduce, then split back
and apply corrections.

This reduces V projection from 224 all-reduce launches to 1,
cutting NCCL collective launch overhead.

Benchmark improvement: +2.6% on top of factored V projection.
Total vs baseline: +17.9% mean throughput.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/mini_trainer/osft_utils.py (1)

539-544: ⚠️ Potential issue | 🟡 Minor

Docstring contradicts the new factored implementation.

The docstring at lines 540-544 states that the V projection "must use the Gram-matrix form" because the factored form "produce[s] column-blocks rather than partial sums — requiring an all-gather instead of an all-reduce." However, the implementation below (lines 591-598) now uses the factored form with all_reduce.

The factored form is correct here: when FSDP2 shards V_high on dim-0, local_dV @ local_V_high.T produces partial sums that all-reduce correctly aggregates. Please update the docstring to reflect the new approach.

📝 Suggested docstring update
```diff
-    V projection must use the Gram-matrix form
-    dV -= dV @ (V_high^T @ V_high)  because FSDP2 shards V_high on dim-0
-    (the singular-vector dimension), making the factored form
-    dV -= (dV @ V_high^T) @ V_high  produce column-blocks rather than
-    partial sums — requiring an all-gather instead of an all-reduce.
+    V projection uses the factored form  dV -= (dV @ V_high^T) @ V_high
+    with a small (rank_low, rank_high) intermediate.  When FSDP2 shards
+    V_high on dim-0, `dV @ V_high^T` produces partial sums that are
+    correctly aggregated via all-reduce, then multiplied by local V_high
+    rows to complete the projection.
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/osft_utils.py` around lines 539 - 544, Update the docstring
for the V projection to match the implemented factored form: replace the claim
that the Gram-matrix form is required and that the factored form would produce
column-blocks requiring an all-gather; instead state that because FSDP2 shards
V_high on dim-0, the local computation (local_dV @ local_V_high.T) produces
partial sums which are correctly aggregated via all_reduce, so the code uses the
factored form with all_reduce (referencing variables dV, V_high, local_dV,
local_V_high and the V projection implementation).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 2148-2172: The non-distributed else branch under the "if
v_flat_parts:" check is dead code because non-distributed paths return earlier;
remove the unreachable block (lines that iterate over v_work using
v_flat_parts.pop and dV.copy_) and either replace it with an assertion like
"assert is_distributed" or rely on the surrounding logic to guarantee
is_distributed is True; update any tests or callers if they depended on that
branch. Ensure you only modify the branch inside the v_flat_parts handling
(references: v_flat_parts, is_distributed, v_work, local_V_high, local_dV, dV,
coeff_shape) to eliminate the redundant code path.

---

Outside diff comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 539-544: Update the docstring for the V projection to match the
implemented factored form: replace the claim that the Gram-matrix form is
required and that the factored form would produce column-blocks requiring an
all-gather; instead state that because FSDP2 shards V_high on dim-0, the local
computation (local_dV @ local_V_high.T) produces partial sums which are
correctly aggregated via all_reduce, so the code uses the factored form with
all_reduce (referencing variables dV, V_high, local_dV, local_V_high and the V
projection implementation).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 175023e5-92a8-4b20-9a19-4258601a5b86

📥 Commits

Reviewing files that changed from the base of the PR and between 1461614 and 91bc1eb.

📒 Files selected for processing (1)
  • src/mini_trainer/osft_utils.py

Comment thread src/mini_trainer/osft_utils.py
@RobotSail
Collaborator Author

@claude fix these CI issues

Contributor

claude bot commented Apr 14, 2026

Claude finished @RobotSail's task in 5m 53s —— View job


Fixing CI Issues

  • Check CI status and identify failures
  • Diagnose failures:
    • Formatting (from this PR): src/mini_trainer/osft_utils.py — missing spaces around slice operator
    • Tests (pre-existing): 4 test_data_loader.py failures (KeyError: 'num_loss_counted_tokens') — not introduced by this PR (no data_loader files were changed)
  • Fix ruff formatting: `batched_v[offset:offset + numel]` → `batched_v[offset : offset + numel]`
  • Fix stale docstring: V projection docstring now correctly documents the factored form with all-reduce (was still describing the old Gram-matrix approach)
  • Remove dead code: non-distributed else branch in v_flat_parts handling was unreachable (non-distributed paths return early at line 2074)
  • Clarify dtype validation in train.py: replaced nested if conditionals with allowed_param_dtypes set for FSDP2 mixed precision checking
  • Push fixes → bb967cd

Note on remaining test failures

The 4 test_data_loader.py::TestJsonlDataset failures (KeyError: 'num_loss_counted_tokens') are pre-existing on the base branch — this PR does not modify test_data_loader.py or data_loader.py. These will need to be addressed separately.

- Fix ruff format: add spaces around slice in batched_v[offset : offset + numel]
- Update stale docstring: V projection now correctly documents the factored
  form with all-reduce instead of the old Gram-matrix claim
- Remove dead code: non-distributed else branch in v_flat_parts handling is
  unreachable since non-distributed paths return early at line 2074
- Clarify dtype validation: use allowed_dtypes set instead of nested
  conditionals for FSDP2 mixed precision param/grad checking

Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@RobotSail
Collaborator Author

@claude-code it looks like there are still some merge conflicts with the model setup script. Maybe you need to rebase this

Contributor

claude bot commented Apr 14, 2026

Claude finished @RobotSail's task in 2m 52s —— View job


Resolving Merge Conflicts

  • Investigate merge conflicts with origin/main
    • Found conflict in src/mini_trainer/setup_model_for_training.py
    • src/mini_trainer/train.py auto-merged cleanly
    • Main had 15 new commits since this branch diverged (including #81 removing osft_memory_efficient_init, #78 fixing trust_remote_code, #77 fixing Gemma VLM detection)
  • Resolve conflicts in setup_model_for_training.py
  • Push resolved changes

Conflict Resolution Details

All three conflicts in setup_model_for_training.py were resolved by combining both sets of changes:

  1. get_model_save_dtype (line 693): Added trust_remote_code from main + kept dtype fallback from this PR
  2. base_model_args (line 946-948): Kept dtype key for transformers v5 + added trust_remote_code from main
  3. Save dtype assignment (line 1173-1180): Passed trust_remote_code to get_model_save_dtype + kept v5 dtype config handling

Merged main into the branch via fb2c230.



Resolves merge conflicts in setup_model_for_training.py by combining:
- trust_remote_code support from main (#78, #81)
- transformers v5 dtype compatibility from this branch

All three conflicts resolved by keeping both sets of changes:
1. get_model_save_dtype: trust_remote_code param + dtype fallback
2. base_model_args: dtype key + trust_remote_code key
3. save_dtype assignment: trust_remote_code passthrough + v5 dtype handling

Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com>

codecov bot commented Apr 14, 2026

Codecov Report

❌ Patch coverage is 83.33333% with 8 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/mini_trainer/train.py | 20.00% | 4 Missing ⚠️ |
| src/mini_trainer/osft_utils.py | 94.44% | 2 Missing ⚠️ |
| src/mini_trainer/setup_model_for_training.py | 71.42% | 2 Missing ⚠️ |


