Add Sound Encoder to Cosmos3 #13911
Signed-off-by: Maciej Bala <mbala@nvidia.com>
    def _disable_encoder(self):
        self.encoder = None
        self._encoder_available = False
        self.register_to_config(encoder_enabled=False)

    def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
        super()._fix_state_dict_keys_on_load(state_dict)
        if self.encoder is not None and not any(key.startswith("encoder.") for key in state_dict):
            self._disable_encoder()
Why do we need these two methods?
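For context, the check in `_fix_state_dict_keys_on_load` appears to reduce to a prefix scan over the checkpoint keys: if no key starts with `encoder.`, the encoder is disabled. A minimal sketch of that scan follows; the helper name `has_encoder_weights` and the key names are hypothetical and only illustrate the idea.

```python
from collections import OrderedDict


def has_encoder_weights(state_dict) -> bool:
    """Return True if any checkpoint key belongs to the encoder submodule."""
    return any(key.startswith("encoder.") for key in state_dict)


# Hypothetical key names, for illustration only.
decoder_only = OrderedDict({"decoder.conv_in.weight": 0, "decoder.conv_in.bias": 0})
full = OrderedDict({"encoder.conv_in.weight": 0, "decoder.conv_in.weight": 0})

print(has_encoder_weights(decoder_only))  # False
print(has_encoder_weights(full))  # True
```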
        return hidden_states


class Cosmos3AudioSnakeBeta(nn.Module):
It looks like the existing Snake1d module implements essentially the same logic as Cosmos3AudioSnakeBeta, could we use it as well for the encoder?
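For reference, the activation under discussion can be sketched in scalar form. This is only an illustration of the formula visible in the diff, `x + sin(alpha * x)^2 / (beta + 1e-9)`; the function name `snake_beta` is hypothetical, and how `alpha` and `beta` are stored (for example in log scale) in either module is not covered here.

```python
import math


def snake_beta(x: float, alpha: float, beta: float, eps: float = 1e-9) -> float:
    """Scalar form of the activation in the diff: x + sin(alpha * x)^2 / (beta + eps)."""
    return x + (1.0 / (beta + eps)) * math.sin(alpha * x) ** 2


# The periodic term vanishes wherever alpha * x is a multiple of pi,
# so those inputs pass through unchanged.
print(snake_beta(0.0, alpha=1.0, beta=1.0))
# At alpha * x = pi / 2 the periodic term reaches its maximum of 1 / (beta + eps).
print(snake_beta(math.pi / 2, alpha=1.0, beta=1.0))
```

If both modules compute this same expression, swapping one for the other would mainly be a question of matching parameter shapes and checkpoint key names.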
        return hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)


class Cosmos3AudioLayerNorm(nn.Module):
Could we potentially reuse the existing diffusers.models.normalization.FP32LayerNorm module here? Like Cosmos3AudioLayerNorm, it also upcasts the weight and bias (if available) to FP32:
diffusers/src/diffusers/models/normalization.py
Lines 87 to 93 in 0cc1cdb
        self.pwconv2 = (
            _zero_module(nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1))
            if identity_init
            else nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1)
        )
-        self.pwconv2 = (
-            _zero_module(nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1))
-            if identity_init
-            else nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1)
-        )
+        self.pwconv2 = nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1)
+        if identity_init:
+            nn.init.zeros_(self.pwconv2.weight)
+            nn.init.zeros_(self.pwconv2.bias)
I think the above suggestion would be clearer and would allow us to remove the _zero_module helper function, as we prefer not to have too many small helpers.
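To illustrate why zeroing the final projection gives an identity initialization, here is a toy sketch in plain Python. It assumes the block adds the projection output to a residual input, which the excerpt does not show; `pointwise_projection` and `residual_block` are hypothetical stand-ins, not the PR's code.

```python
def pointwise_projection(features, weight, bias):
    """Toy stand-in for a kernel_size=1 convolution: one output channel per row of `weight`."""
    return [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weight, bias)]


def residual_block(hidden, weight, bias):
    """Hypothetical block: residual input plus the final pointwise projection."""
    projected = pointwise_projection(hidden, weight, bias)
    return [h + p for h, p in zip(hidden, projected)]


hidden = [0.5, -1.25, 2.0]
zero_weight = [[0.0] * 3 for _ in range(3)]
zero_bias = [0.0] * 3

# With weight and bias zeroed, the projection contributes nothing
# and the block returns its input unchanged.
assert residual_block(hidden, zero_weight, zero_bias) == hidden
```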
def _zero_module(module: nn.Module) -> nn.Module:
    for parameter in module.parameters():
        parameter.detach().zero_()
    return module
-def _zero_module(module: nn.Module) -> nn.Module:
-    for parameter in module.parameters():
-        parameter.detach().zero_()
-    return module
Follow-up suggestion to #13911 (comment).
        if num_channels > 1:
            audio = audio.reshape(batch_size * num_channels, 1, num_samples)

        with torch.autocast(device_type=audio.device.type, enabled=False):
Would it be possible to remove the autocast region here? We generally prefer not to use autocast regions and I think mixed-precision training would still work correctly without it (since _spectrogram doesn't use any ops that could be dispatched to a lower precision dtype).
        return hidden_state


class Cosmos3AudioDiagonalGaussianDistribution:
Would it be possible to reuse the existing OobleckDiagonalGaussianDistribution module here? I believe the logic is essentially the same as in Cosmos3AudioDiagonalGaussianDistribution.
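As a reference point for comparing the two classes, the core of a diagonal Gaussian posterior can be sketched in plain Python. This sketch assumes the encoder output is a list of means followed by log-variances; that layout is an assumption, and either real class may parameterize the spread differently. The class name `DiagonalGaussian` is hypothetical.

```python
import math
import random


class DiagonalGaussian:
    """Hypothetical sketch of a diagonal Gaussian posterior over latent values.

    Assumes `moments` holds the means followed by the log-variances.
    """

    def __init__(self, moments):
        half = len(moments) // 2
        self.mean = moments[:half]
        self.logvar = moments[half:]
        self.std = [math.exp(0.5 * lv) for lv in self.logvar]

    def sample(self, rng=None):
        rng = rng or random.Random()
        # Reparameterization: mean + std * standard normal noise.
        return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(self.mean, self.std)]

    def mode(self):
        # The mode of a Gaussian is its mean.
        return list(self.mean)

    def kl(self):
        # KL divergence to a standard normal, summed over dimensions.
        return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(self.mean, self.logvar))
```

If the two classes differ only in naming and not in how the spread is derived from the encoder output, reuse should be a drop-in change; otherwise the difference would need to be reconciled first.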
        encoder_dtype = next(self.encoder.parameters()).dtype if self.encoder is not None else hidden_states.dtype
        moments = self._encode(hidden_states.to(dtype=encoder_dtype))
-        encoder_dtype = next(self.encoder.parameters()).dtype if self.encoder is not None else hidden_states.dtype
-        moments = self._encode(hidden_states.to(dtype=encoder_dtype))
+        encoder_dtype = get_parameter_dtype(self.encoder) if self.encoder is not None else hidden_states.dtype
+        moments = self._encode(hidden_states.to(dtype=encoder_dtype))
Using diffusers.models.modeling_utils.get_parameter_dtype here is more robust to things like layerwise casting, where the storage dtype (for self.encoder's weights) may differ from the compute dtype (which we want hidden_states to be in).
        if self.encoder is None or not self._encoder_available:
            raise ValueError(
                "This Cosmos3 AVAE sound tokenizer was loaded from decoder-only weights and cannot encode audio. "
                "Re-convert the AVAE checkpoint with encoder weights to use `encode()`."
            )
I think it might make more sense to move this check into encode, so that we fail earlier.
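A minimal sketch of the placement being suggested, using a hypothetical stand-in class rather than the actual tokenizer: the guard runs at the top of the public `encode` method, before any dtype casting or preprocessing, and the private `_encode` only does the work.

```python
class SoundTokenizerSketch:
    """Hypothetical stand-in for the tokenizer; only the guard placement is illustrated."""

    def __init__(self, encoder=None):
        self.encoder = encoder

    def encode(self, audio):
        # Guard first, so decoder-only checkpoints fail before any preprocessing runs.
        if self.encoder is None:
            raise ValueError("Loaded from decoder-only weights; cannot encode audio.")
        return self._encode(audio)

    def _encode(self, audio):
        # By the time this runs, the encoder is known to be present.
        return self.encoder(audio)
```

With this layout, a call on a decoder-only instance raises immediately, and the error surfaces at the public entry point rather than inside a helper.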
dg845 left a comment
Thanks for the PR! Left an initial design review :).