Skip to content

Add Sound Encoder to Cosmos3#13911

Draft
MaciejBalaNV wants to merge 1 commit into
huggingface:mainfrom
MaciejBalaNV:cosmos3_sound_encoder
Draft

Add Sound Encoder to Cosmos3#13911
MaciejBalaNV wants to merge 1 commit into
huggingface:mainfrom
MaciejBalaNV:cosmos3_sound_encoder

Conversation

@MaciejBalaNV

Copy link
Copy Markdown
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Maciej Bala <mbala@nvidia.com>
@github-actions github-actions Bot added models tests size/L PR with diff > 200 LOC labels Jun 10, 2026
Comment on lines +617 to +626
def _disable_encoder(self):
self.encoder = None
self._encoder_available = False
self.register_to_config(encoder_enabled=False)

def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
super()._fix_state_dict_keys_on_load(state_dict)
if self.encoder is not None and not any(key.startswith("encoder.") for key in state_dict):
self._disable_encoder()

@yiyixuxu yiyixuxu Jun 10, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need these two methods?

@yiyixuxu yiyixuxu requested a review from dg845 June 10, 2026 20:37
return hidden_states


class Cosmos3AudioSnakeBeta(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the existing Snake1d module implements essentially the same logic as Cosmos3AudioSnakeBeta, could we use it as well for the encoder?

return hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)


class Cosmos3AudioLayerNorm(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we potentially reuse the existing diffusers.models.normalization.FP32LayerNorm module here? Like Cosmos3AudioLayerNorm, it also upcasts the weight and bias (if available) to FP32:

return F.layer_norm(
inputs.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
).to(origin_dtype)

Comment on lines +135 to +139
self.pwconv2 = (
_zero_module(nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1))
if identity_init
else nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1)
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.pwconv2 = (
_zero_module(nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1))
if identity_init
else nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1)
)
self.pwconv2 = nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1)
if identity_init:
nn.init.zeros_(self.pwconv2.weight)
nn.init.zeros_(self.pwconv2.bias)

I think the above suggestion would be more clear and would allow us to remove the _zero_module helper method, as we prefer not to have too many small methods.

Comment on lines +40 to +45
def _zero_module(module: nn.Module) -> nn.Module:
for parameter in module.parameters():
parameter.detach().zero_()
return module


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _zero_module(module: nn.Module) -> nn.Module:
for parameter in module.parameters():
parameter.detach().zero_()
return module

Follow up suggestion to #13911 (comment).

if num_channels > 1:
audio = audio.reshape(batch_size * num_channels, 1, num_samples)

with torch.autocast(device_type=audio.device.type, enabled=False):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to remove the autocast region here? We generally prefer not to use autocast regions and I think mixed-precision training would still work correctly without it (since _spectrogram doesn't use any ops that could be dispatched to a lower precision dtype).

return hidden_state


class Cosmos3AudioDiagonalGaussianDistribution:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to reuse the existing OobleckDiagonalGaussianDistribution module here? I believe the logic is essentially the same as in Cosmos3AudioDiagonalGaussianDistribution.

Comment on lines +662 to +663
encoder_dtype = next(self.encoder.parameters()).dtype if self.encoder is not None else hidden_states.dtype
moments = self._encode(hidden_states.to(dtype=encoder_dtype))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
encoder_dtype = next(self.encoder.parameters()).dtype if self.encoder is not None else hidden_states.dtype
moments = self._encode(hidden_states.to(dtype=encoder_dtype))
encoder_dtype = get_parameter_dtype(self.encoder) if self.encoder is not None else hidden_states.dtype
moments = self._encode(hidden_states.to(dtype=encoder_dtype))

Using diffusers.models.modeling_utils.get_parameter_dtype here is more robust to things like layerwise casting, where the storage dtype (for self.encoder's weights) may differ from the compute dtype (which we want hidden_states to be in).

Comment on lines +628 to +632
if self.encoder is None or not self._encoder_available:
raise ValueError(
"This Cosmos3 AVAE sound tokenizer was loaded from decoder-only weights and cannot encode audio. "
"Re-convert the AVAE checkpoint with encoder weights to use `encode()`."
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might make more sense to move this check into encode, so that we fail earlier.

@dg845 dg845 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Left an initial design review :).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models size/L PR with diff > 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants