[TRTLLM-13120][feat] Cosmos3 Audio Output Support#14827
Conversation
📝 WalkthroughWalkthroughThis PR adds audio generation support to Cosmos3OmniMoTPipeline alongside video. Audio generation is gated by an ChangesCosmos3 Audio Generation Integration
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py (1)
1098-1118:⚠️ Potential issue | 🟠 Major | ⚡ Quick winNormalize the
model.prefix before matching audio keys.The new audio remaps run before Lines 1116-1118 strip a leading
model.prefix, so checkpoint entries likemodel.audio_proj_in.weightandmodel.audio_modality_embedfall through and get skipped. That leaves the audio path partially unloaded for checkpoints using the same prefix convention as the rest of the transformer weights.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py` around lines 1098 - 1118, The audio-related checkpoint keys (prefixes "audio_proj_in.", "audio_proj_out.", "audio_modality_embed", "time_embedder.linear") are being checked before the code strips a leading "model." prefix, so entries like "model.audio_proj_in.weight" are missed; fix by normalizing the key early (e.g., update the variable k by stripping "model." when present) before the audio remap checks in the same block (ensure the "model." handling around k = k[len("model.") :] happens before the checks for "audio_proj_in.", "audio_proj_out.", "audio_modality_embed", and "time_embedder.linear") so all audio keys with or without the "model." prefix are correctly remapped.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/cosmos3/modules.py`:
- Around line 100-112: get_norm_module currently misses the "time_layer_norm"
case and incorrectly uses nn.LayerNorm(module.out_channels) directly (LayerNorm
expects the normalized dimension to be last), causing wrong behavior for tensors
shaped [N,C,T]; update get_norm_module to (1) add a branch for norm ==
"time_layer_norm" that returns a module which applies LayerNorm over the channel
dimension by permuting [N,C,T] -> [N,T,C], applying
nn.LayerNorm(module.out_channels), then permuting back, and (2) fix the existing
"layer_norm" branch to similarly wrap nn.LayerNorm so it normalizes channels
correctly for ConvNd outputs; keep the causal/group-norm check intact and still
assert module is an instance of nn.modules.conv._ConvNd and refer to
get_norm_module and CONV_NORMALIZATIONS when making the changes.
- Around line 115-129: The pad1d function's default mode "zero" is invalid for
torch.nn.functional.pad; update pad1d (function pad1d) to use a valid default
such as mode="constant" (which preserves the existing value parameter semantics)
and ensure any callers expecting "zero" continue to work by using the value
argument; adjust the function signature default and keep the existing
reflect-handling logic intact so F.pad is always invoked with a supported mode.
In `@tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py`:
- Around line 686-695: The audio config fallback uses getattr(...,
"<new_field>", pretrained_config.<legacy_field>) which eagerly evaluates the
legacy attribute and can raise AttributeError if legacy keys are missing; change
the audio field assignments in transformer_cosmos3 (audio_dim, audio_latent_fps,
temporal_compression_factor_audio) to safely check for legacy keys (e.g., use
nested getattr or hasattr/try-except to read pretrained_config.sound_* only if
present) so the fallback is evaluated lazily. Also in load_weights(), the
remapping for audio keys (audio_proj_in.*, audio_proj_out.*,
audio_modality_embed) happens before the code strips the leading "model."
prefix, so normalize checkpoint keys by removing the "model." prefix first (or
re-run the audio remap after normalization) so keys like "model.audio_proj_in.*"
correctly match the remapping branches.
---
Outside diff comments:
In `@tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py`:
- Around line 1098-1118: The audio-related checkpoint keys (prefixes
"audio_proj_in.", "audio_proj_out.", "audio_modality_embed",
"time_embedder.linear") are being checked before the code strips a leading
"model." prefix, so entries like "model.audio_proj_in.weight" are missed; fix by
normalizing the key early (e.g., update the variable k by stripping "model."
when present) before the audio remap checks in the same block (ensure the
"model." handling around k = k[len("model.") :] happens before the checks for
"audio_proj_in.", "audio_proj_out.", "audio_modality_embed", and
"time_embedder.linear") so all audio keys with or without the "model." prefix
are correctly remapped.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ed333cb8-dd36-4030-acfd-3b1c5628f0c3
📒 Files selected for processing (7)
tensorrt_llm/_torch/visual_gen/models/cosmos3/defaults.pytensorrt_llm/_torch/visual_gen/models/cosmos3/modules.pytensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.pytensorrt_llm/_torch/visual_gen/models/cosmos3/sound_tokenizer.pytensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.pytensorrt_llm/_torch/visual_gen/pipeline.pytensorrt_llm/_torch/visual_gen/pipeline_registry.py
0c3d063 to
8aa0e91
Compare
350468f to
217f591
Compare
16bce31 to
8673740
Compare
|
/bot run |
|
PR_Github #53961 [ run ] triggered by Bot. Commit: |
|
PR_Github #53961 [ run ] completed with state
|
178450d to
ca7a790
Compare
|
/bot run |
|
PR_Github #55043 [ run ] triggered by Bot. Commit: |
|
PR_Github #55043 [ run ] completed with state
|
ca7a790 to
7d46330
Compare
|
/bot run |
|
PR_Github #55461 [ run ] triggered by Bot. Commit: |
|
PR_Github #55461 [ run ] completed with state
|
|
Hi @NVShreyas @rahul-steiger-nv, 15417 will require a small update here, but it should not block this PR. The scheduler/pipeline logic should still compare against raw 15417 only changes the transformer-forward contract:
So after 15417, this PR should update the transformer call from: timestep=timestep,
attention_timestep=timestep / self.scheduler.config.num_train_timesteps,to: timestep=timestep / self.scheduler.config.num_train_timesteps,
raw_timestep=timestep,This preserves the 15545 fix because Cosmos time embedding still uses the raw scheduler timestep. I can make the adjustment after this PR lands, or this PR can rebase after 15417 lands. Either path is fine. Thanks! |
Thanks, that sounds good. I’m happy to rebase after 15417 lands and make the small call-site update in this PR. |
|
/bot run |
|
PR_Github #55867 [ run ] triggered by Bot. Commit: |
|
PR_Github #56966 [ run ] completed with state
|
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
2f9da72 to
0601b25
Compare
…lision - test_cosmos3_example: point script_path at the reorganized examples/visual_gen/models/cosmos3/cosmos3.py (was flat cosmos3_ti2v.py) - _run_forward: accept height/width/guidance_scale overrides so test_t2i_smoke no longer passes duplicate keyword arguments to forward() Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
|
/bot run --disable-fail-fast |
|
PR_Github #57013 [ run ] triggered by Bot. Commit: |
|
PR_Github #57013 [ run ] completed with state |
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Summary by CodeRabbit
Release Notes
enable_audioparameter enables audio generation per inference request with decoded audio waveforms.Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.