Support LoRA for MOE Megatron SequentialMLP #979

Open

jenchen13 wants to merge 7 commits into main from jennifchen/moe_lora

Conversation


@jenchen13 jenchen13 commented Mar 5, 2026

What does this PR do?

Type of change: New Feature

  • Add LoRA support for Megatron SequentialMLP in MOE local experts
  • LoRA adapters are applied at the per-expert level: lora_down is shared across the local experts in a layer, while each local expert has its own lora_up. This accommodates the SVDQuant kernel, which fuses lora_down and quantization into a single kernel.

Usage

# Add a code snippet demonstrating how to use this
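Pending the author's snippet, a hypothetical sketch of what usage might look like, inferred from the walkthrough and review comments (the config keys, wildcard patterns, and the `update_model` entry point shown here are assumptions, not verified against the code):

```python
# Hypothetical sketch of an MoE LoRA config in the dict format the review
# comments describe; actual keys and patterns in MOE_LORA_CFG may differ.
MOE_LORA_CFG_SKETCH = {
    "adapter_type": "lora",
    "adapter_cfg": {
        # Wildcard patterns select the target modules
        "*linear_qkv*": {"rank": 32, "enable": True},
        "*local_experts*": {"rank": 32, "enable": True},
    },
}

# Intended entry point, per the review thread:
#   import modelopt.torch.peft as mtpf
#   mtpf.update_model(model, MOE_LORA_CFG)
```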

Testing

Tested adding adapters to an MoE layer, and verified gradient flow in MoE.
TODO: test the sharded state dict.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, using torch.load(..., weights_only=True), avoiding pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other source, did you follow IP policy in CONTRIBUTING.md?: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • MoE SequentialMLP LoRA: shared down-projection with per-expert up-projections.
    • LoRA support for Transformer Engine parallel linear layers, including quantized TE variants and registration for discovery.
    • Predefined LoRA configuration presets for dense and MoE setups.
  • Tests

    • Expanded tests for MoE SequentialMLP LoRA structure and gradients, TE LoRA types, forward/save-restore, and quantization interactions.
  • Behavior

    • Sequence-parallel now enables automatically for MoE when tensor-model-parallelism is active.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
@jenchen13 jenchen13 requested a review from a team as a code owner March 5, 2026 17:03

coderabbitai bot commented Mar 5, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ca40c7b8-ee59-45c0-9af3-e7acaf894323

📥 Commits

Reviewing files that changed from the base of the PR and between b80c82d and 962df9b.

📒 Files selected for processing (1)
  • modelopt/torch/peft/lora/config.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/peft/lora/config.py

📝 Walkthrough

Adds Megatron MoE LoRA adapters (SequentialMLP with shared down / per-expert up), TE Column/Row Parallel LoRA and quantized variants, sharded-state handling, new Megatron LoRA configs, and tests integrating MoE, TE, and quantization.

Changes

  • Megatron LoRA & TE adapters (modelopt/torch/peft/lora/plugins/megatron.py): Adds _LoRAMegatronSequentialMLP (shared lora_down, per-expert lora_up), TE adapters (_LoRATEColumnParallelLinear, _LoRATERowParallelLinear), quantized TE adapter variants, TE availability detection, sharded_state_dict extensions, and registry registrations.
  • LoRA config definitions (modelopt/torch/peft/lora/config.py): Introduces DENSE_LORA_CFG, MOE_LORA_CFG, MOE_LORA_RANDOM_INIT_CFG, LORA_CFG_CHOICES, pattern-based targeting, and exports via __all__.
  • Model builder, test utilities (tests/_test_utils/torch/megatron/models.py): Threads num_moe_experts, use_te, and moe_grouped_gemm into model/spec construction and computes use_sp (sequence_parallel) when MoE is enabled and TP > 1.
  • MoE / TE / quantization tests (tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py): Adds MoE SequentialMLP LoRA configs (shared down / per-expert up) and selective patterns; extends the model provider signature; adds tests for MoE SequentialMLP LoRA structure and gradients, TE LoRA module types, forward/enable-disable/save-restore, and quantization interactions.

Sequence Diagram(s)

sequenceDiagram
    participant Input as Input\n(permuted_local_hidden_states,\n tokens_per_expert, permuted_probs)
    participant Base as Base\nSequentialMLP
    participant SharedDown as Shared\nlora_down
    participant PerExpert as Per-Expert\nlora_up (ModuleList)
    participant Aggregator as Aggregator

    Input->>Base: compute base output
    Input->>SharedDown: produce shared down-projection inputs
    SharedDown->>SharedDown: apply shared down-projection
    SharedDown->>PerExpert: split into per-expert activations
    PerExpert->>PerExpert: each lora_up[i] produces expert outputs
    PerExpert->>Aggregator: emit per-expert LoRA outputs
    Aggregator->>Aggregator: weight/aggregate using tokens_per_expert\nand permuted_probs
    Aggregator->>Base: add aggregated LoRA output to base output
    Base->>Input: return final output
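The flow in the diagram can be sketched in a few lines of PyTorch (illustrative only; the class and argument names here are assumptions, not the PR's actual implementation):

```python
import torch
import torch.nn as nn


class MoELoRASketch(nn.Module):
    """Illustrative sketch: shared lora_down, per-expert lora_up, prob-weighted output."""

    def __init__(self, hidden_size, ffn_size, rank, num_local_experts):
        super().__init__()
        # One down-projection shared by all local experts in the layer
        self.lora_down = nn.Linear(hidden_size, rank, bias=False)
        # One up-projection per local expert
        self.lora_up = nn.ModuleList(
            nn.Linear(rank, ffn_size, bias=False) for _ in range(num_local_experts)
        )

    def forward(self, permuted_hidden, tokens_per_expert, permuted_probs):
        # Shared down-projection over all permuted tokens at once
        low_rank = self.lora_down(permuted_hidden)
        # Split the low-rank activations into per-expert chunks
        chunks = torch.split(low_rank, tokens_per_expert.tolist())
        # Each expert applies its own up-projection to its chunk
        outs = [up(chunk) for up, chunk in zip(self.lora_up, chunks)]
        lora_out = torch.cat(outs, dim=0)
        # Weight by routing probabilities, as the diagram's aggregator does;
        # the result would then be added to the base SequentialMLP output
        return lora_out * permuted_probs.unsqueeze(-1)
```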

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 46.43%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title "Support LoRA for MOE Megatron SequentialMLP" directly and accurately describes the main change: adding LoRA support for Megatron's SequentialMLP component in Mixture-of-Experts architectures.
  • Security Anti-Patterns: ✅ Passed. The pull request introduces no security anti-patterns; all new code follows secure practices without unsafe deserialization, hardcoded trust flags, eval/exec calls, or prohibited bypass comments.


@jenchen13 jenchen13 requested review from jingyu-ml and sychen52 March 5, 2026 17:04
             expert_model_parallel_size=expert_model_parallel_size,
             expert_tensor_parallel_size=expert_tensor_parallel_size,
-            sequence_parallel=False,
+            sequence_parallel=num_moe_experts > 0 and tensor_model_parallel_size > 1,
jenchen13 (Contributor, Author) commented:
mcore throws an error if you use MOE + TP without SP

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Line 336: Replace the direct Megatron-Core import of
ensure_metadata_has_dp_cp_group (currently imported from
megatron.core.transformer.utils) with the version-stable compatibility helper
used by other plugins: import ensure_metadata_has_dp_cp_group from
modelopt.torch.opt.plugins.megatron instead; update the import statement in the
file so all usages of ensure_metadata_has_dp_cp_group in this module use the
compatibility wrapper rather than the unstable megatron.core.transformer.utils
symbol.

In `@tests/_test_utils/torch/megatron/models.py`:
- Line 172: The comparison on line using "sequence_parallel=num_moe_experts > 0
and tensor_model_parallel_size > 1" can raise TypeError when num_moe_experts is
None; change the condition to guard against None by evaluating num_moe_experts
is not None and num_moe_experts > 0 (or equivalently coerce with a default like
(num_moe_experts or 0) > 0) before combining with tensor_model_parallel_size > 1
so sequence_parallel is always a bool; update the call/site where
sequence_parallel is set (the argument named sequence_parallel in the model
construction) to use this guarded condition.
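The guard the comment asks for is plain Python truthiness; a minimal sketch (the function name is hypothetical, the condition matches the one the review proposes):

```python
def compute_use_sp(num_moe_experts, tensor_model_parallel_size):
    # `None > 0` raises TypeError in Python 3, so coerce None to 0 first;
    # the result is then always a bool.
    return (num_moe_experts or 0) > 0 and tensor_model_parallel_size > 1
```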

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7cdca600-b524-488d-a5cb-940d5ede1314

📥 Commits

Reviewing files that changed from the base of the PR and between 31f0783 and 03b5140.

📒 Files selected for processing (3)
  • modelopt/torch/peft/lora/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py


codecov bot commented Mar 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.09%. Comparing base (e8f9687) to head (962df9b).
⚠️ Report is 29 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #979      +/-   ##
==========================================
- Coverage   72.12%   70.09%   -2.04%     
==========================================
  Files         209      221      +12     
  Lines       23628    25459    +1831     
==========================================
+ Hits        17042    17845     +803     
- Misses       6586     7614    +1028     

☔ View full report in Codecov by Sentry.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/peft/lora/plugins/megatron.py (1)

336-336: ⚠️ Potential issue | 🟠 Major

Use the ModelOpt compatibility helper for metadata normalization.

Line 336 still imports ensure_metadata_has_dp_cp_group directly from Megatron-Core. This is version-fragile; please switch to the compatibility helper used in other ModelOpt plugins.

🔧 Proposed import change
-        from megatron.core.transformer import ensure_metadata_has_dp_cp_group
+        from modelopt.torch.opt.plugins.megatron import ensure_metadata_has_dp_cp_group
#!/bin/bash
set -euo pipefail

# Verify compatibility helper exists and inspect current usages.
rg -n "def ensure_metadata_has_dp_cp_group" modelopt/torch/opt/plugins/megatron.py
rg -n "ensure_metadata_has_dp_cp_group" modelopt/torch -g '*.py'

Expected: helper definition is present in modelopt/torch/opt/plugins/megatron.py, and this file should align with that import pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` at line 336, The file imports
ensure_metadata_has_dp_cp_group directly from megatron.core.transformer which is
version-fragile; replace that direct import with the ModelOpt compatibility
helper used elsewhere (the helper defined in
modelopt/torch/opt/plugins/megatron.py) so metadata normalization uses the
centralized compatibility wrapper. Locate the import line referencing
ensure_metadata_has_dp_cp_group in modelopt/torch/peft/lora/plugins/megatron.py
and change it to import the helper from the ModelOpt plugin module (matching
other plugins' import pattern), ensuring all usages of
ensure_metadata_has_dp_cp_group in this file call the compatibility helper
instead of the Megatron-Core symbol.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Around line 373-383: The lora_up_key is reused for every local expert when
singleton_local_shards is False, causing earlier shards to be overwritten in
sharded_state_dict; update the key generation so it is unique per expert (e.g.,
include expert_global_idx or another expert identifier in the f-string) when
building lora_up_key in the block that sets up_offsets/up_offsets and calls
ShardedTensor.from_rank_offsets (refer to variables lora_up_key,
singleton_local_shards, up_offsets, sharded_offsets, expert_global_idx,
adapter_name, and the call to ShardedTensor.from_rank_offsets) so each expert
writes to a distinct dict key and no shards are dropped.

---

Duplicate comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Line 336: The file imports ensure_metadata_has_dp_cp_group directly from
megatron.core.transformer which is version-fragile; replace that direct import
with the ModelOpt compatibility helper used elsewhere (the helper defined in
modelopt/torch/opt/plugins/megatron.py) so metadata normalization uses the
centralized compatibility wrapper. Locate the import line referencing
ensure_metadata_has_dp_cp_group in modelopt/torch/peft/lora/plugins/megatron.py
and change it to import the helper from the ModelOpt plugin module (matching
other plugins' import pattern), ensuring all usages of
ensure_metadata_has_dp_cp_group in this file call the compatibility helper
instead of the Megatron-Core symbol.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 5222a6e7-5d85-4fc3-b56d-4af312e8fe36

📥 Commits

Reviewing files that changed from the base of the PR and between 03b5140 and d3827e0.

📒 Files selected for processing (2)
  • modelopt/torch/peft/lora/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/_test_utils/torch/megatron/models.py

Comment on lines +373 to +383
                if singleton_local_shards:
                    lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                    up_offsets = sharded_offsets
                else:
                    lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
                    up_offsets = (
                        *sharded_offsets,
                        (len(sharded_offsets), expert_global_idx, num_global_experts),
                    )

                sharded_state_dict[lora_up_key] = ShardedTensor.from_rank_offsets(

⚠️ Potential issue | 🔴 Critical

Prevent silent shard loss in lora_up checkpoint mapping.

At Line 377, lora_up_key is identical for every local expert when singleton_local_shards is False, so Line 383 overwrites earlier experts in sharded_state_dict. This drops shards for all but the last expert.

💡 Minimal fix to avoid key overwrite
                 if singleton_local_shards:
                     lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = sharded_offsets
                 else:
-                    lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
+                    lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = (
                         *sharded_offsets,
                         (len(sharded_offsets), expert_global_idx, num_global_experts),
                     )
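The overwrite the reviewer describes is ordinary dict semantics; a minimal standalone illustration (key names shortened, values stand in for shards):

```python
# Reusing one key in a loop keeps only the last value written.
sharded_state_dict = {}
num_local_experts = 4
for expert_idx in range(num_local_experts):
    key = "lora_b_default.weight"  # identical key every iteration
    sharded_state_dict[key] = f"shard_{expert_idx}"

# Only the last expert's shard survives:
# len(sharded_state_dict) == 1

# Including the expert index makes each key unique, so no shard is dropped:
fixed = {}
for expert_idx in range(num_local_experts):
    fixed[f"lora_b_default.{expert_idx}.weight"] = f"shard_{expert_idx}"
```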
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` around lines 373 - 383, The
lora_up_key is reused for every local expert when singleton_local_shards is
False, causing earlier shards to be overwritten in sharded_state_dict; update
the key generation so it is unique per expert (e.g., include expert_global_idx
or another expert identifier in the f-string) when building lora_up_key in the
block that sets up_offsets/up_offsets and calls ShardedTensor.from_rank_offsets
(refer to variables lora_up_key, singleton_local_shards, up_offsets,
sharded_offsets, expert_global_idx, adapter_name, and the call to
ShardedTensor.from_rank_offsets) so each expert writes to a distinct dict key
and no shards are dropped.

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py (1)

179-218: ⚠️ Potential issue | 🟡 Minor

Inconsistent parameter passing between meta_device and CUDA branches.

The meta_device=True branch (lines 190-204) does not pass moe_grouped_gemm to get_mcore_gpt_model, while the CUDA branch (lines 206-218) does pass it. This could lead to inconsistent behavior when testing with meta device initialization.

🔧 Suggested fix
     if meta_device:
         with torch.device("meta"):
             gpt_model = get_mcore_gpt_model(
                 tensor_model_parallel_size=tp_size,
                 num_layers=2,
                 ffn_hidden_size=None,
                 num_attention_heads=4,
                 activation_func="squared_relu",
                 transformer_impl="local",
                 use_te=use_te,
                 hidden_size=hidden_size,
                 vocab_size=vocab_size,
                 use_cpu_initialization=meta_device,
                 num_moe_experts=num_moe_experts,
+                moe_grouped_gemm=moe_grouped_gemm,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py` around lines 179
- 218, _gpt_model_provider creates models inconsistently: when meta_device=True
it calls get_mcore_gpt_model without passing the moe_grouped_gemm argument
whereas the non-meta (CUDA) branch passes moe_grouped_gemm; update the
meta_device branch call to include moe_grouped_gemm=moe_grouped_gemm so both
branches call get_mcore_gpt_model with the same set of parameters (reference
get_mcore_gpt_model, meta_device, moe_grouped_gemm, and _gpt_model_provider).
♻️ Duplicate comments (1)
modelopt/torch/peft/lora/plugins/megatron.py (1)

388-404: ⚠️ Potential issue | 🔴 Critical

Dict key collision causes silent shard loss for lora_up weights.

When singleton_local_shards is False, line 392 generates the same lora_up_key for every local expert (f"{prefix}lora_b_{adapter_name}.weight"). Since sharded_state_dict is a Python dict, line 398 overwrites entries from previous experts, causing all but the last expert's lora_up weights to be lost during checkpointing.

🐛 Suggested fix
                 if singleton_local_shards:
                     lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = sharded_offsets
                 else:
-                    lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
+                    lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_local_idx}.weight"
                     up_offsets = (
                         *sharded_offsets,
                         (len(sharded_offsets), expert_global_idx, num_global_experts),
                     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` around lines 388 - 404, The
dict key for lora_up is colliding when singleton_local_shards is False because
lora_up_key is built as f"{prefix}lora_b_{adapter_name}.weight" for every local
expert; change that branch to include the expert identifier (e.g., use
f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight" like the singleton
branch) so each expert gets a unique key before calling
ShardedTensor.from_rank_offsets (ensure this adjustment is made alongside the
existing up_offsets logic involving expert_global_idx and num_global_experts).
🧹 Nitpick comments (1)
modelopt/torch/peft/lora/config.py (1)

39-48: Consider using dataclasses or Pydantic for configuration objects.

The configs are defined as plain dictionaries. As per coding guidelines for modelopt/torch/**/config.py: "Use dataclasses or Pydantic for mode configuration objects." While dict-based configs work and align with existing usage patterns in the codebase (e.g., adapter_cfg dict format), consider whether a typed config class would improve validation and IDE support.

Given that these configs are passed to update_model which already expects dict-like structures, the current approach may be intentional for compatibility.

Also applies to: 67-77
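A sketch of the typed-config alternative the nitpick suggests (the class name `LoraConfigSketch` and its fields are hypothetical; the real config keys may differ):

```python
from dataclasses import asdict, dataclass, field


@dataclass
class LoraConfigSketch:
    """Hypothetical typed wrapper around the dict-based LoRA config."""

    adapter_type: str = "lora"
    adapter_cfg: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        # Preserve the existing dict shape expected by update_model
        return asdict(self)


# Backward-compatible module-level constant in the original dict shape
DENSE_LORA_CFG_SKETCH = LoraConfigSketch(
    adapter_cfg={"*linear_qkv*": {"rank": 32}}
).to_dict()
```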

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/config.py` around lines 39 - 48, The DENSE_LORA_CFG
(and the similar config at lines 67-77) are plain dicts; convert them into a
typed configuration class (use a dataclass or Pydantic model named e.g.,
LoraConfig with an adapter_cfg field) to provide validation and IDE type hints
while preserving the existing dict shape passed to update_model; implement a
to_dict/asdict/.dict() method and update call sites that pass DENSE_LORA_CFG to
call that conversion (or keep a module-level constant DENSE_LORA_CFG =
LoraConfig(...).to_dict() if you need immediate backward compatibility), and
ensure fields like adapter_cfg and nested patterns (e.g., "*linear_qkv*") remain
unchanged in the resulting dict.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py`:
- Around line 179-218: _gpt_model_provider creates models inconsistently: when
meta_device=True it calls get_mcore_gpt_model without passing the
moe_grouped_gemm argument whereas the non-meta (CUDA) branch passes
moe_grouped_gemm; update the meta_device branch call to include
moe_grouped_gemm=moe_grouped_gemm so both branches call get_mcore_gpt_model with
the same set of parameters (reference get_mcore_gpt_model, meta_device,
moe_grouped_gemm, and _gpt_model_provider).

---

Duplicate comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Around line 388-404: The dict key for lora_up is colliding when
singleton_local_shards is False because lora_up_key is built as
f"{prefix}lora_b_{adapter_name}.weight" for every local expert; change that
branch to include the expert identifier (e.g., use
f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight" like the singleton
branch) so each expert gets a unique key before calling
ShardedTensor.from_rank_offsets (ensure this adjustment is made alongside the
existing up_offsets logic involving expert_global_idx and num_global_experts).

---

Nitpick comments:
In `@modelopt/torch/peft/lora/config.py`:
- Around line 39-48: The DENSE_LORA_CFG (and the similar config at lines 67-77)
are plain dicts; convert them into a typed configuration class (use a dataclass
or Pydantic model named e.g., LoraConfig with an adapter_cfg field) to provide
validation and IDE type hints while preserving the existing dict shape passed to
update_model; implement a to_dict/asdict/.dict() method and update call sites
that pass DENSE_LORA_CFG to call that conversion (or keep a module-level
constant DENSE_LORA_CFG = LoraConfig(...).to_dict() if you need immediate
backward compatibility), and ensure fields like adapter_cfg and nested patterns
(e.g., "*linear_qkv*") remain unchanged in the resulting dict.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b3f605d6-8b34-49b5-b97a-6e7a981044d7

📥 Commits

Reviewing files that changed from the base of the PR and between d3827e0 and 7583222.

📒 Files selected for processing (4)
  • modelopt/torch/peft/lora/config.py
  • modelopt/torch/peft/lora/plugins/megatron.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/_test_utils/torch/megatron/models.py

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
modelopt/torch/peft/lora/plugins/megatron.py (1)

386-402: ⚠️ Potential issue | 🔴 Critical

lora_up checkpoint mapping still has key collision when singleton_local_shards=False.

When singleton_local_shards is False and num_local_experts > 1, Line 390 produces an identical lora_up_key for every local expert (f"{prefix}lora_b_{adapter_name}.weight"). Since Line 396 assigns to sharded_state_dict[lora_up_key] inside the loop, only the last expert's weights survive—all earlier experts are silently overwritten.

The key must be unique per expert to avoid losing shards:

Proposed fix to include expert index in key
                 if singleton_local_shards:
                     lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight"
                     up_offsets = sharded_offsets
                 else:
-                    lora_up_key = f"{prefix}lora_b_{adapter_name}.weight"
+                    lora_up_key = f"{prefix}lora_b_{adapter_name}.{expert_local_idx}.weight"
                     up_offsets = (
                         *sharded_offsets,
                         (len(sharded_offsets), expert_global_idx, num_global_experts),
                     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/peft/lora/plugins/megatron.py` around lines 386 - 402, The
lora_up_key built when singleton_local_shards is False is identical for all
local experts and causes sharded_state_dict entries to be overwritten; update
the key construction in the else branch to include the expert index (e.g.,
expert_global_idx or expert_local_idx) so each expert gets a unique key (e.g.,
f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight") and use that same
key when creating the ShardedTensor.from_rank_offsets call so every expert's
up.weight is stored under a distinct sharded_state_dict entry.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/peft/lora/plugins/megatron.py`:
- Around line 386-402: The lora_up_key built when singleton_local_shards is
False is identical for all local experts and causes sharded_state_dict entries
to be overwritten; update the key construction in the else branch to include the
expert index (e.g., expert_global_idx or expert_local_idx) so each expert gets a
unique key (e.g., f"{prefix}lora_b_{adapter_name}.{expert_global_idx}.weight")
and use that same key when creating the ShardedTensor.from_rank_offsets call so
every expert's up.weight is stored under a distinct sharded_state_dict entry.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 814e762a-5a11-499d-b51a-8a69e258a039

📥 Commits

Reviewing files that changed from the base of the PR and between 7583222 and b80c82d.

📒 Files selected for processing (2)
  • modelopt/torch/peft/lora/config.py
  • modelopt/torch/peft/lora/plugins/megatron.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/peft/lora/config.py

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>