2 changes: 1 addition & 1 deletion nemo/collections/tts/data/text_to_speech_dataset_lhotse.py
@@ -467,7 +467,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float)
        # Use IPA text for IPABPETokenizer (required), otherwise use regular text_str
        if isinstance(self.phoneme_tokenizer, IPABPETokenizer):
            if not cut.supervisions[0].has_custom("ipa"):
-                if self.dataset_type == 'train':
+                if (self.dataset_type == 'train') and (language not in self.ignore_phoneme_languages):
                    raise ValueError(
                        f"IPABPETokenizer requires 'ipa' field but it is not available in the cut. "
                        f"Cut ID: {cut.id}, Text: {text_str}"
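For illustration, here is a condensed sketch of the gate this one-line change introduces. The helper name is hypothetical; `dataset_type`, `language`, and `ignore_phoneme_languages` mirror the attributes used in the diff:

```python
# Hypothetical standalone version of the changed condition; not part of the PR.
def should_raise_on_missing_ipa(dataset_type: str, language: str, ignore_phoneme_languages) -> bool:
    """Missing 'ipa' is fatal only for training cuts whose language is expected
    to carry phoneme annotations."""
    return dataset_type == 'train' and language not in ignore_phoneme_languages

assert should_raise_on_missing_ipa('train', 'en', {'zh'})
assert not should_raise_on_missing_ipa('train', 'zh', {'zh'})  # exempted language
assert not should_raise_on_missing_ipa('val', 'en', {'zh'})    # non-train split
```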
64 changes: 64 additions & 0 deletions nemo/collections/tts/models/easy_magpietts_inference.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import time
from dataclasses import dataclass, fields
from functools import partial
@@ -41,6 +42,7 @@
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging
from nemo.utils.exceptions import NeMoBaseException


@dataclass
@@ -362,6 +364,46 @@
        else:
            self.audio_in_projection = nn.Identity()

        # Speaker/context encoder for context audio embeddings.
        # This enables keeping the zero-shot conditioning module private at release time.
        self.use_speaker_encoder = cfg.get('use_speaker_encoder', False)
        self.train_shuffle_context_embedding_prob = 0.0
        if self.use_speaker_encoder:
            speaker_encoder_cfg = cfg.get('speaker_encoder', None)
            if speaker_encoder_cfg is not None:
                speaker_encoder_cfg = dict(speaker_encoder_cfg)
                if speaker_encoder_cfg.get('use_moe', False):
                    raise NeMoBaseException(
                        "MoE is not recommended for the speaker encoder. "
                        "Please set speaker_encoder.use_moe to False."
                    )
                if 'router_load_balancing_loss_coeff' in speaker_encoder_cfg:
                    logging.warning(
                        "Detected `router_load_balancing_loss_coeff` in speaker encoder config. "
                        "MoE is not recommended for the speaker encoder."
                    )
                if 'router_z_loss_coeff' in speaker_encoder_cfg:
                    logging.warning(
                        "Detected `router_z_loss_coeff` in speaker encoder config. "
                        "MoE is not recommended for the speaker encoder."
                    )
            else:
                speaker_encoder_cfg = {
                    'n_layers': cfg.get('speaker_encoder_n_layers', 1),
                    'd_model': cfg.embedding_dim,
                    'd_ffn': cfg.get('speaker_encoder_d_ffn', cfg.embedding_dim * 2),
                    'sa_n_heads': cfg.get('speaker_encoder_n_heads', 12),
                    'kernel_size': cfg.get('speaker_encoder_kernel_size', 1),
                    'p_dropout': cfg.get('speaker_encoder_p_dropout', 0.0),
                    'is_causal': False,
                    'use_learnable_pos_emb': True,
                }
            self.speaker_encoder = transformer_2501.Transformer(**speaker_encoder_cfg)
            # Train-only probability to bypass speaker encoder and feed batch-shuffled
            # raw context embeddings, matching Magpie behavior.
            self.train_shuffle_context_embedding_prob = cfg.get('train_shuffle_context_embedding_prob', 0.0)

        if self.phoneme_tokenizer is not None:
            phoneme_embeddings = []
            for _ in range(self.phoneme_stacking_factor):
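As a reading aid, the fallback branch above can be viewed as a small config-resolution step: prefer an explicit `speaker_encoder` block, otherwise assemble one from scalar keys. A condensed sketch under that assumption (the helper name is hypothetical; defaults are copied from the diff):

```python
def resolve_speaker_encoder_cfg(cfg) -> dict:
    """Hypothetical helper mirroring the branch above; MoE/router validation
    is handled separately."""
    explicit = cfg.get('speaker_encoder', None)
    if explicit is not None:
        return dict(explicit)
    return {
        'n_layers': cfg.get('speaker_encoder_n_layers', 1),
        'd_model': cfg.embedding_dim,
        'd_ffn': cfg.get('speaker_encoder_d_ffn', cfg.embedding_dim * 2),
        'sa_n_heads': cfg.get('speaker_encoder_n_heads', 12),
        'kernel_size': cfg.get('speaker_encoder_kernel_size', 1),
        'p_dropout': cfg.get('speaker_encoder_p_dropout', 0.0),
        'is_causal': False,                # context frames may attend bidirectionally
        'use_learnable_pos_emb': True,
    }
```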
@@ -591,6 +633,11 @@
        audio_embedding = self.audio_in_projection(audio_embedding)
        return audio_embedding

    def encode_context_audio_embeddings(self, context_audio_embedded: torch.Tensor, context_audio_lens: torch.Tensor):
        """Encode context audio embeddings with the speaker encoder."""
        context_mask = get_mask_from_lengths(context_audio_lens)
        return self.speaker_encoder(context_audio_embedded, context_mask, cond=None, cond_mask=None)['output']

    def embed_phoneme_tokens(self, phoneme_tokens):
        # phoneme_tokens: (B, S, T')
        phoneme_embedding = None
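For readers unfamiliar with `get_mask_from_lengths`, here is a minimal sketch of what such a lengths-to-mask helper typically computes; the actual NeMo implementation may differ in shape or dtype details:

```python
import torch

def length_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Return a (B, T_max) boolean mask, True at valid (unpadded) positions."""
    positions = torch.arange(int(lengths.max()), device=lengths.device)  # (T_max,)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)                 # (B, T_max)

print(length_mask(torch.tensor([3, 5])))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```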
@@ -793,7 +840,7 @@
                Required if context_audio is provided.
            training_mode: Optional TrainingMode object specifying the mode to use.
                If None, uses the first mode from training_modes as default.
            dropout_conditional_input: If True, replace context with CFG unconditional token.

[Check warning (Code scanning / CodeQL): Variable defined multiple times. This assignment to 'batch_size' is unnecessary as it is redefined before this value is used.]
        Returns:
            Tuple of:
@@ -839,6 +886,23 @@
            self.num_audio_codebooks,
        )
        context_audio_embedded = self.embed_audio_tokens(context_audio_codes)  # (B, T', E)
        batch_size = context_audio_embedded.size(0)

        if self.use_speaker_encoder:
            if (
                self.training
                and batch_size > 1
                and self.train_shuffle_context_embedding_prob > 0
                and random.random() < self.train_shuffle_context_embedding_prob
            ):
                # Feed shuffled raw context embeddings (without speaker encoder) so
                # the decoder cannot rely on direct unencoded speaker identity cues.
                shift = random.randint(1, batch_size - 1)
                context_audio_embedded = context_audio_embedded.roll(shift, dims=0)
            else:
                context_audio_embedded = self.encode_context_audio_embeddings(
                    context_audio_embedded=context_audio_embedded, context_audio_lens=context_audio_codes_lens
                )

        # Context Text
        context_text_lens = context_text_tokens_lens
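A note on the shuffle: drawing `shift` from `[1, batch_size - 1]` makes the circular roll a derangement of the batch, so no sample is ever paired with its own context. A standalone demonstration (not part of the PR):

```python
import torch

B = 4
context = torch.arange(B)      # stand-in for (B, T', E) context embeddings
for shift in range(1, B):      # the range random.randint(1, B - 1) draws from
    shuffled = context.roll(shift, dims=0)
    assert not (shuffled == context).any()  # every sample gets another sample's context
print(context.roll(2, dims=0).tolist())  # [2, 3, 0, 1]
```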
38 changes: 32 additions & 6 deletions nemo/collections/tts/models/magpietts.py
@@ -72,6 +72,7 @@
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging
from nemo.utils.exceptions import NeMoBaseException


@dataclass
@@ -592,11 +593,23 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
            for layer in self.context_decoder_layers:
                multi_encoder_mapping[layer] = 1
            self.multi_encoder_mapping = multi_encoder_mapping
-            # Create context encoder (filter out MoE loss coefficients if present)
-            # Note: Loss coefficients are model-level config, not passed to Transformer module
            # Create context encoder.
            # Note: router_* loss coefficients are model-level config, not consumed by the Transformer module.
            context_encoder_cfg = dict(cfg.context_encoder)
-            context_encoder_cfg.pop('router_load_balancing_loss_coeff', None)
-            context_encoder_cfg.pop('router_z_loss_coeff', None)
            if context_encoder_cfg.get('use_moe', False):
                raise NeMoBaseException(
                    "MoE is not recommended for the context encoder. Please set context_encoder.use_moe to False."
                )
            if 'router_load_balancing_loss_coeff' in context_encoder_cfg:
                logging.warning(
                    "Detected `router_load_balancing_loss_coeff` in context encoder config. "
                    "MoE is not recommended for the context encoder."
                )
            if 'router_z_loss_coeff' in context_encoder_cfg:
                logging.warning(
                    "Detected `router_z_loss_coeff` in context encoder config. "
                    "MoE is not recommended for the context encoder."
                )
            self.context_encoder = transformer_2501.Transformer(**context_encoder_cfg)
        elif self.model_type == 'decoder_context_tts':
            # Context audio/text goes directly to the decoder (before the target audio codes)
@@ -606,9 +619,22 @@
        elif self.model_type == 'decoder_ce':
            # Similar to decoder_context_tts, but we use a context encoder.
            # The decoder gets the context encoder's output instead of raw context token embeddings.
            # Note: router_* loss coefficients are model-level config, not consumed by the Transformer module.
            context_encoder_cfg = dict(cfg.context_encoder)
-            context_encoder_cfg.pop('router_load_balancing_loss_coeff', None)
-            context_encoder_cfg.pop('router_z_loss_coeff', None)
            if context_encoder_cfg.get('use_moe', False):
                raise NeMoBaseException(
                    "MoE is not recommended for the context encoder. Please set context_encoder.use_moe to False."
                )
            if 'router_load_balancing_loss_coeff' in context_encoder_cfg:
                logging.warning(
                    "Detected `router_load_balancing_loss_coeff` in context encoder config. "
                    "MoE is not recommended for the context encoder."
                )
            if 'router_z_loss_coeff' in context_encoder_cfg:
                logging.warning(
                    "Detected `router_z_loss_coeff` in context encoder config. "
                    "MoE is not recommended for the context encoder."
                )
            self.context_encoder = transformer_2501.Transformer(**context_encoder_cfg)
            self.transcript_decoder_layers = [
                idx for idx in range(cfg.decoder.n_layers)
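The guard above now appears verbatim in both the `multi_encoder_context_tts` and `decoder_ce` branches. A hypothetical refactor both call sites could share (the function name and message wording are suggestions, not part of the PR):

```python
from nemo.utils import logging
from nemo.utils.exceptions import NeMoBaseException

def validate_context_encoder_cfg(encoder_cfg: dict, name: str = 'context encoder') -> None:
    """Raise on use_moe; warn if leftover router_* loss coefficients are present."""
    if encoder_cfg.get('use_moe', False):
        raise NeMoBaseException(f"MoE is not recommended for the {name}. Please set use_moe to False.")
    for key in ('router_load_balancing_loss_coeff', 'router_z_loss_coeff'):
        if key in encoder_cfg:
            logging.warning(f"Detected `{key}` in {name} config. MoE is not recommended for the {name}.")
```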