Skip to content

[None][fix] AutoDeploy: handle torch dist all_gather in multi_stream MLA transform#15456

Open
MrGeva wants to merge 2 commits into
NVIDIA:mainfrom
nv-auto-deploy:fix/multi-stream-attn-allgather-backend
Open

[None][fix] AutoDeploy: handle torch dist all_gather in multi_stream MLA transform#15456
MrGeva wants to merge 2 commits into
NVIDIA:mainfrom
nv-auto-deploy:fix/multi-stream-attn-allgather-backend

Conversation

@MrGeva

@MrGeva MrGeva commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator

Pattern 0 was rebuilding KV all_gather with the trtllm argument layout for every backend. Branch on the op type so torch_dist_all_gather keeps its (tensor, dim, sizes) signature while trtllm_dist_all_gather still gets workspace_id on the aux stream.

Summary by CodeRabbit

  • Refactor
    • Enhanced multi-stream attention with improved backend-aware handling of auxiliary stream operations, ensuring proper argument configuration for different backend implementations while simplifying the overall code flow.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Pattern 0 was rebuilding KV all_gather with the trtllm argument layout for
every backend. Branch on the op type so torch_dist_all_gather keeps its
(tensor, dim, sizes) signature while trtllm_dist_all_gather still gets
workspace_id on the aux stream.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@MrGeva MrGeva changed the title [None][fix] handle torch dist all_gather in multi_stream MLA pattern 0 [None][fix] AutoDeploy: handle torch dist all_gather in multi_stream MLA transform Jun 17, 2026
@MrGeva MrGeva marked this pull request as ready for review June 17, 2026 11:31
@MrGeva MrGeva requested a review from a team as a code owner June 17, 2026 11:31
@MrGeva MrGeva requested a review from QiJune June 17, 2026 11:31
@coderabbitai

coderabbitai Bot commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

A new private helper _build_aux_stream_all_gather_args is added to multi_stream_attn.py to construct backend-specific argument lists for the aux-stream KV all_gather call, branching on trtllm_dist_all_gather vs torch_dist_all_gather. Pattern 0's aux-stream kv_ag construction is updated to use this helper, removing its previous inline parameter extraction.

Changes

Aux-stream all_gather argument builder

Layer / File(s) Summary
Helper definition and Pattern 0 call site
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py
_build_aux_stream_all_gather_args is added to produce backend-specific arg lists: for trtllm_dist_all_gather it includes strategy/dim/sizes and forces workspace_id=_AUX_WORKSPACE_ID; for torch_dist_all_gather it preserves the original dim/sizes args and omits strategy/workspace. Pattern 0's aux-stream kv_ag node construction drops its manual parameter extraction and delegates to this helper.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Description check ⚠️ Warning The PR description explains the problem and solution, but the required template sections (Description, Test Coverage, PR Checklist) are incomplete with only placeholder comments and unchecked checkboxes. Complete the Description section explaining the issue/solution, add Test Coverage details, and fill out or confirm the PR Checklist items before merging.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly identifies the specific issue (torch dist all_gather handling) and the affected component (multi_stream MLA transform), accurately summarizing the main fix.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py (1)

188-189: 💤 Low value

Prefer tuple unpacking over concatenation.

The static analysis tool suggests a cleaner form. Since kv_ag.args is already a tuple, unpacking is more idiomatic and avoids the redundant tuple() call.

♻️ Suggested simplification
     # torch_dist_all_gather — preserve dim/sizes, no strategy or workspace_id.
-    return (new_input,) + tuple(kv_ag.args[1:])
+    return (new_input, *kv_ag.args[1:])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py`
around lines 188 - 189, In the return statement, the code uses tuple
concatenation with `(new_input,) + tuple(kv_ag.args[1:])` where the tuple() call
is redundant since slicing kv_ag.args already returns a tuple. Replace this with
tuple unpacking syntax using `(new_input, *kv_ag.args[1:])` which is more
idiomatic and avoids the unnecessary tuple() conversion.

Source: Linters/SAST tools

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py`:
- Around line 188-189: In the return statement, the code uses tuple
concatenation with `(new_input,) + tuple(kv_ag.args[1:])` where the tuple() call
is redundant since slicing kv_ag.args already returns a tuple. Replace this with
tuple unpacking syntax using `(new_input, *kv_ag.args[1:])` which is more
idiomatic and avoids the unnecessary tuple() conversion.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 1889c173-26d0-479c-a2aa-616179282cb9

📥 Commits

Reviewing files that changed from the base of the PR and between 0ffa09f and 79cf0ae.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py

Cover trtllm and torch dist all_gather arg rewriting for multi_stream
MLA pattern 0 via hand-built FX graphs.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@MrGeva MrGeva requested a review from suyoggupta June 17, 2026 16:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants