diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
index 858feda00fa0..8b777312e650 100644
--- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
+++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
@@ -186,6 +186,14 @@ class TranscriptionConfig:
     #
     use_cer: bool = False
     debug_mode: bool = False  # Whether to print more detail in the output.
+    # Language-ID prompt for prompt-conditioned models (e.g. EncDecRNNTBPEModelWithPrompt).
+    # Set to a language key from the model's prompt_dictionary (e.g. "en-US", "auto").
+    # Ignored for models without prompt support.
+    target_lang: Optional[str] = None
+    # Whether to strip language-ID tags (e.g. <en-US>) from the transcriptions.
+    # Ignored for models without prompt support.
+    strip_lang_tags: bool = False
+
 
 def extract_transcriptions(hyps):
     """
@@ -363,6 +371,12 @@ def main(cfg: TranscriptionConfig):
         else:
             asr_model.change_decoding_strategy(cfg.ctc_decoding)
 
+    # Set language-ID prompt for prompt-conditioned models
+    if hasattr(asr_model, 'set_inference_prompt'):
+        lang = cfg.target_lang if cfg.target_lang is not None else "auto"
+        asr_model.set_inference_prompt(lang)
+        asr_model.decoding.set_strip_lang_tags(cfg.strip_lang_tags)
+
     asr_model = asr_model.to(device=device, dtype=compute_dtype)
     asr_model.eval()
diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py
new file mode 100644
index 000000000000..ae6895b60412
--- /dev/null
+++ b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+# Preparing the Tokenizer for the dataset
+Use the `process_asr_text_tokenizer.py` script under <NEMO_ROOT>/scripts/tokenizers/ in order to prepare the tokenizer.
+
+Each training manifest entry must carry the prompt field (``target_lang`` by
+default), set to a key from the model's prompt_dictionary, as in this example.
+
+# Manifest file example:
+{"audio_filepath":"/data/audio.wav","duration":12.12,"text":"The transcript.","target_lang":"en-US"}
+
+```sh
+python <NEMO_ROOT>/scripts/tokenizers/process_asr_text_tokenizer.py \
+        --manifest=<path to train manifest files, separated by commas>
+        OR
+        --data_file=<path to text data, separated by commas> \
+        --data_root="<output directory>" \
+        --vocab_size=<number of tokens in vocabulary> \
+        --tokenizer=<"spe" or "wpe"> \
+        --no_lower_case \
+        --spe_type=<"unigram", "bpe", "char" or "word"> \
+        --spe_character_coverage=1.0 \
+        --log
+```
+
+# Training the model
+```sh
+python speech_to_text_rnnt_bpe_prompt.py \
+    # (Optional: --config-path=<config path> --config-name=<config name>) \
+    model.train_ds.manifest_filepath=<path to train manifest> \
+    model.validation_ds.manifest_filepath=<path to val/test manifest> \
+    model.tokenizer.dir=<path to directory of tokenizer> \
+    model.tokenizer.type=<either bpe or wpe> \
+    trainer.devices=-1 \
+    trainer.max_epochs=100 \
+    model.optim.name="adamw" \
+    model.optim.lr=0.001 \
+    model.optim.betas=[0.9,0.999] \
+    model.optim.weight_decay=0.0001 \
+    model.optim.sched.warmup_steps=2000 \
+    exp_manager.create_wandb_logger=True \
+    exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \
+    exp_manager.wandb_logger_kwargs.project="<Name of project>"
+```
+
+# Fine-tune a model
+
+For documentation on fine-tuning this model, please visit:
+https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
+
+"""
+
+import lightning.pytorch as pl
+from omegaconf import OmegaConf
+
+from nemo.collections.asr.models import EncDecRNNTBPEModelWithPrompt
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+from nemo.utils.exp_manager import exp_manager
+from nemo.utils.trainer_utils import resolve_trainer_cfg
+
+
+@hydra_runner(
+    config_path="../conf/fastconformer/cache_aware_streaming/",
+    config_name="fastconformer_transducer_bpe_streaming_prompt.yaml",
+)
+def main(cfg):
+    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+
+    trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer))
+    exp_manager(trainer, cfg.get("exp_manager", None))
+    asr_model = EncDecRNNTBPEModelWithPrompt(cfg=cfg.model, trainer=trainer)
+
+    # Initialize the weights of the model from another model, if provided via config
+    asr_model.maybe_init_from_pretrained_checkpoint(cfg)
+
+    trainer.fit(asr_model)
+
+    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
+        if asr_model.prepare_test(trainer):
+            trainer.test(asr_model)
+
+
+if __name__ == '__main__':
+    main()  # noqa pylint: disable=no-value-for-parameter
diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml
new file mode 100644
index 000000000000..4fc45a047c2e
--- /dev/null
+++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml
@@ -0,0 +1,409 @@
+# Cache-aware streaming FastConformer-Transducer (RNNT-only) ASR model with prompt support
+# Combines a cache-aware streaming encoder with prompt-based multilingual capability.
+# This is the RNNT-only variant (no auxiliary CTC head).
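+#
+# Example (illustrative): a manifest entry carrying "target_lang": "de-DE" maps, via
+# model_defaults.prompt_dictionary below, to prompt index 9; the model expands that
+# index into a one-hot vector of size num_prompts, concatenates it with the encoder
+# output at every frame, and projects back to enc_hidden.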
+ +# You may find more detail: +# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer +# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer + +name: "FastConformer-Transducer-BPE-Prompt-Streaming" + +model: + sample_rate: 16000 + compute_eval_loss: false + log_prediction: true + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + initialize_prompt_feature: true + num_prompts: 128 + norm: None + prompt_dictionary: { + 'en-US': 0, + 'en': 0, + 'en-GB': 1, + 'enGB': 1, + 'es-ES': 2, + 'esES': 2, + 'es-US': 3, + 'es': 3, + 'zh-CN': 4, + 'zh-ZH': 4, + 'zh-TW': 5, + 'hi-IN': 6, + 'hi': 6, + 'hi-HI': 6, + 'ar-AR': 7, + 'ar': 7, + 'fr-FR': 8, + 'fr': 8, + 'de-DE': 9, + 'de': 9, + 'ja-JP': 10, + 'ja-JA': 10, + 'ru-RU': 11, + 'ru': 11, + 'pt-BR': 12, + 'pt-PT': 13, + 'pt': 13, + 'ko-KR': 14, + 'ko': 14, + 'ko-KO': 14, + 'it-IT': 15, + 'it': 15, + 'nl-NL': 16, + 'nl': 16, + 'pl-PL': 17, + 'pl': 17, + 'tr-TR': 18, + 'tr': 18, + 'uk-UA': 19, + 'uk': 19, + 'ro-RO': 20, + 'ro': 20, + 'el-GR': 21, + 'el': 21, + 'cs-CZ': 22, + 'cs': 22, + 'hu-HU': 23, + 'hu': 23, + 'sv-SE': 24, + 'sv': 24, + 'da-DK': 25, + 'da': 25, + 'fi-FI': 26, + 'fi': 26, + 'no-NO': 27, + 'no': 27, + 'nb-NO': 103, + 'nb': 103, + 'sk-SK': 28, + 'sk': 28, + 'hr-HR': 29, + 'hr': 29, + 'bg-BG': 30, + 'bg': 30, + 'lt-LT': 31, + 'lt': 31, + 'et-EE': 60, + 'et': 60, + 'lv-LV': 61, + 'lv': 61, + 'sl-SI': 62, + 'sl': 62, + 'th-TH': 32, + 'vi-VN': 33, + 'id-ID': 34, + 'ms-MY': 35, + 'bn-IN': 36, + 'ur-PK': 37, + 'fa-IR': 38, + 'ta-IN': 39, + 'te-IN': 40, + 'mr-IN': 41, + 'gu-IN': 42, + 'kn-IN': 43, + 'ml-IN': 44, + 'si-LK': 45, + 'ne-NP': 46, + 'km-KH': 47, + 'sw-KE': 48, + 'am-ET': 49, + 'ha-NG': 50, + 'zu-ZA': 51, + 'yo-NG': 52, + 'ig-NG': 53, + 'af-ZA': 54, + 'rw-RW': 55, + 'so-SO': 56, + 'ny-MW': 57, + 'ln-CD': 58, + 'or-KE': 59, + 'he-IL': 64, + 'ku-TR': 65, + 'az-AZ': 66, + 'ka-GE': 67, + 'hy-AM': 68, + 'uz-UZ': 69, + 'tg-TJ': 70, + 'ky-KG': 71, + 'qu-PE': 80, + 'ay-BO': 81, + 'gn-PY': 82, + 'nah-MX': 83, + 'mi-NZ': 96, + 'haw-US': 97, + 'sm-WS': 98, + 'to-TO': 99, + 'fr-CA': 100, + 'mt-MT': 102, + 'auto': 101 + } + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + use_lhotse: true + batch_duration: 400 + quadratic_duration: 15 + num_buckets: 30 + shuffle: true + num_workers: 8 + pin_memory: true + max_duration: 39.99 + min_duration: 0.1 + is_tarred: true + tarred_audio_filepaths: null + slice_length: 100 + bucketing_batch_size: null + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + lang_field: target_lang + training_mode: true + + # Per-dataset prompt mode — controls how language prompts are selected during training. + # The mode is set per data source via lhotse input_cfg tags: + # tags: { prompt_mode: unified } + # + # Supported prompt_mode values: + # "langID" — always pass the real language ID (use for AST / language-forced tasks) + # "auto" — always pass auto/language-agnostic prompt (use for code-switching) + # "unified" — randomly choose auto vs lang ID (use for multilingual ASR, default) + # + # unified_auto_ratio controls the probability of selecting auto in "unified" mode. 
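+    #
+    # Illustrative lhotse input_cfg entries (paths and dataset names are hypothetical),
+    # showing how each data source is tagged with a prompt_mode:
+    #
+    #   input_cfg:
+    #     - type: nemo_tarred
+    #       manifest_filepath: /data/en_asr/sharded_manifests/manifest__OP_0..511_CL_.json
+    #       tarred_audio_filepaths: /data/en_asr/audio__OP_0..511_CL_.tar
+    #       tags: { prompt_mode: unified }
+    #     - type: nemo_tarred
+    #       manifest_filepath: /data/codeswitch/sharded_manifests/manifest__OP_0..63_CL_.json
+    #       tarred_audio_filepaths: /data/codeswitch/audio__OP_0..63_CL_.tar
+    #       tags: { prompt_mode: auto }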
+ prompt_mode_field: prompt_mode + default_prompt_mode: unified + unified_auto_ratio: 0.5 + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 2 + shuffle: false + use_start_end_token: false + num_workers: 2 + pin_memory: true + batch_duration: null + use_lhotse: true + use_bucketing: false + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + training_mode: true + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + use_start_end_token: false + num_workers: 8 + pin_memory: true + use_lhotse: true + use_bucketing: false + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + training_mode: false + + tokenizer: + dir: ??? + type: bpe + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 42 + d_model: 1024 + use_bias: false + + subsampling: dw_striding + subsampling_factor: 8 + subsampling_conv_channels: 256 + causal_downsampling: true + + reduction: null + reduction_position: null + reduction_factor: 1 + + ff_expansion_factor: 4 + + self_attention_model: rel_pos + n_heads: 8 + att_context_size: [70, 6] + att_context_probs: null + att_context_style: chunked_limited + xscaling: false + untie_biases: true + pos_emb_max_len: 5000 + + conv_kernel_size: 9 + conv_norm_type: 'layer_norm' + conv_context_size: causal + + dropout: 0.1 + dropout_pre_encoder: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 + + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null + random_state_sampling: false + blank_as_pad: true + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 2 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null + preserve_memory: false + + fuse_loss_wer: true + fused_batch_size: 2 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" + # Strip language-ID tags (e.g. ) from decoded output during inference. 
+ strip_lang_tags: true + + greedy: + max_symbols: 10 + + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 + alsd_max_target_len: 2.0 + + interctc: + loss_weights: [] + apply_at_layers: [] + + loss: + loss_name: "default" + + warprnnt_numba_kwargs: + fastemit_lambda: 5e-3 + clamp: -1.0 + + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 2.0 + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 + num_nodes: 1 + max_epochs: -1 + max_steps: 500000 + val_check_interval: 0.5 + accelerator: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + accumulate_grad_batches: 1 + gradient_clip_val: 0.5 + precision: bf16 + log_every_n_steps: 100 + enable_progress_bar: True + num_sanity_val_steps: 0 + sync_batchnorm: true + enable_checkpointing: False + logger: false + benchmark: false + use_distributed_sampler: false + limit_train_batches: 1000 + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py new file mode 100644 index 000000000000..249248a75f4c --- /dev/null +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simplified Lhotse dataset that returns language ID indices instead of full prompt tensors. +The model creates the prompt tensor using the actual encoded length. +""" + +import random +from typing import Dict, Optional, Tuple + +import torch +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors + +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.utils import logging + + +class LhotseSpeechToTextBpeDatasetWithPromptIndex(torch.utils.data.Dataset): + """ + Simplified dataset class for speech-to-text with prompt support. + + Instead of computing full prompt tensors, this dataset returns just the + language ID index per sample. The model creates the prompt tensor using + the actual encoder output length, guaranteeing no size mismatch. 
+ + Returns: + audio_signal: Audio waveform [B, T] + audio_signal_length: Audio lengths [B] + transcripts: Token IDs [B, T] + transcript_length: Token lengths [B] + prompt_indices: Language ID indices [B] (NOT full tensors) + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'audio_signal_length': NeuralType(tuple('B'), LengthsType()), + 'transcripts': NeuralType(('B', 'T'), LabelsType()), + 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'prompt_indices': NeuralType(tuple('B'), LabelsType()), # Just indices, not full tensors + } + + def __init__(self, tokenizer: TokenizerSpec, cfg: Dict) -> None: + super().__init__() + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + + # Load prompt dictionary from config + self.prompt_dict = cfg.get('prompt_dictionary') + if not self.prompt_dict: + raise ValueError("prompt_dictionary is required in config") + + self.num_prompts = cfg.get('num_prompts', 128) + + # Field to use for prompt key (default to 'target_lang') + self.prompt_field = cfg.get('prompt_field', 'target_lang') + + self.training_mode = cfg.get('training_mode', True) + + # Per-dataset prompt mode is read from cut.custom["prompt_mode"] at runtime. + # Supported values: + # "langID" — always pass the real language ID + # "auto" — always pass auto (101) + # "unified" — randomize: auto with probability unified_auto_ratio, else lang ID + # Set via lhotse input_cfg tags, e.g. tags: { prompt_mode: langID } + self.prompt_mode_field = cfg.get('prompt_mode_field', 'prompt_mode') + self.default_prompt_mode = cfg.get('default_prompt_mode', 'unified') + self.unified_auto_ratio = cfg.get('unified_auto_ratio', 0.5) + + # Index used for the language-agnostic / auto prompt + self.auto_index = self.prompt_dict.get('auto', 101) + + logging.info( + f"LhotseSpeechToTextBpeDatasetWithPromptIndex: " + f"default_prompt_mode={self.default_prompt_mode}, " + f"unified_auto_ratio={self.unified_auto_ratio}" + ) + + def _get_prompt_index(self, prompt_key: str) -> int: + """Maps prompt keys to indices using the prompt dictionary.""" + if prompt_key not in self.prompt_dict: + available_keys = list(self.prompt_dict.keys()) + raise ValueError( + f"Unknown prompt key: '{prompt_key}'. Available: {available_keys[:10]}{'...' if len(available_keys) > 10 else ''}" + ) + return self.prompt_dict[prompt_key] + + def _get_prompt_mode(self, cut) -> str: + """Resolve the prompt_mode for a cut from its custom tags.""" + if cut.custom is not None: + mode = cut.custom.get(self.prompt_mode_field) + if mode is not None: + return mode + return self.default_prompt_mode + + def _get_prompt_index_for_cut(self, cut) -> int: + """ + Determine the prompt index for a cut based on its prompt_mode tag. + + During inference (training_mode=False): always returns the real lang ID + regardless of prompt_mode. 
+ + During training, behaviour depends on prompt_mode (set per-dataset via + lhotse input_cfg tags): + "langID" — always return the real language ID + "auto" — always return auto index (language-agnostic) + "unified" — return auto with probability unified_auto_ratio, + otherwise the real language ID + """ + if not self.training_mode: + return self._get_prompt_index(cut.supervisions[0].language) + + mode = self._get_prompt_mode(cut) + + if mode == 'langID': + return self._get_prompt_index(cut.supervisions[0].language) + elif mode == 'auto': + return self.auto_index + elif mode == 'unified': + if random.random() < self.unified_auto_ratio: + return self.auto_index + return self._get_prompt_index(cut.supervisions[0].language) + else: + logging.warning(f"Unknown prompt_mode '{mode}', falling back to unified") + if random.random() < self.unified_auto_ratio: + return self.auto_index + return self._get_prompt_index(cut.supervisions[0].language) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts] + + # Get prompt indices (just the language ID per sample, NOT full tensors) + prompt_indices = torch.tensor([self._get_prompt_index_for_cut(c) for c in cuts], dtype=torch.long) + + # Create final tensors + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=0) + + return ( + audio, # Audio signal [B, T] + audio_lens, # Audio lengths [B] + tokens, # Text tokens [B, T] + token_lens, # Token lengths [B] + prompt_indices, # Language ID indices [B] - model creates full tensor + ) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index cc9b3a74e1ea..67a99268175b 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -30,6 +30,7 @@ from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel # noqa: F401 from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel # noqa: F401 from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel # noqa: F401 +from nemo.collections.asr.models.rnnt_bpe_models_prompt import EncDecRNNTBPEModelWithPrompt # noqa: F401 from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel # noqa: F401 from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel # noqa: F401 from nemo.collections.asr.models.ssl_models import ( # noqa: F401 @@ -55,6 +56,7 @@ 'EncDecMultiTaskModel', 'EncDecMultiTalkerRNNTBPEModel', 'EncDecRNNTBPEModel', + 'EncDecRNNTBPEModelWithPrompt', 'EncDecRNNTModel', 'EncDecSpeakerLabelModel', 'EncDecTransfModelBPE', diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 992de84fc7a8..07f2b96b95f3 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math import os from dataclasses import dataclass from math import ceil @@ -25,11 +24,12 @@ from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset -from nemo.collections.asr.data.audio_to_text_lhotse_prompt import LhotseSpeechToTextBpeDatasetWithPrompt +from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex from nemo.collections.asr.metrics.bleu import BLEU from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig @@ -55,8 +55,8 @@ class HybridRNNTCTCPromptTranscribeConfig(TranscribeConfig): Configuration for Hybrid RNNT-CTC BPE Model with Prompt Transcription """ - target_lang: str = "en-US" - prompt_field: str = "lang" + target_lang: str = "auto" + prompt_field: str = "target_lang" class EncDecHybridRNNTCTCBPEModelWithPrompt(EncDecHybridRNNTCTCBPEModel, ASRTranscriptionMixin): @@ -108,9 +108,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup prompt settings - default to 128 prompts if not specified cfg.num_prompts = cfg.model_defaults.get('num_prompts', 128) - # Make sure prompt_dictionary exists if 'prompt_dictionary' not in cfg.model_defaults: - raise ValueError("No prompt_dictionary found in config.") + logging.warning( + "No prompt_dictionary in config; using empty dict " "(expected during checkpoint restoration)." + ) + cfg.model_defaults.prompt_dictionary = {} # Set subsampling_factor in a place accessible to the class self.subsampling_factor = cfg.get('subsampling_factor', 8) @@ -193,14 +195,80 @@ def initialize_prompt_feature(self): # setting the RNNT decoder as the default one self.cur_decoder = "rnnt" + # Streaming inference with language-ID prompt + + def set_inference_prompt(self, target_lang: str): + """ + Set the language prompt for streaming inference. + + Call this before ``conformer_stream_step`` to condition decoding on + a specific language, following the same pattern as + ``change_decoding_strategy``. + + Args: + target_lang: A key from the model's ``prompt_dictionary`` + (e.g. ``"en-US"``, ``"auto"``). + """ + prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) + if target_lang not in prompt_dict: + available = list(prompt_dict.keys()) + raise ValueError( + f"Unknown target language '{target_lang}'. " + f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" + ) + self._inference_prompt_index = prompt_dict[target_lang] + logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") + + def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: + """ + Inject the language-ID prompt into encoder output during streaming. + + ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. + Returns the same shape after prompt concatenation + projection. 
+ """ + if not self.concat or not hasattr(self, '_inference_prompt_index'): + return encoded + + encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) + + batch_size, time_steps, _ = encoded.shape + prompt = torch.zeros( + batch_size, + time_steps, + self.num_prompts, + dtype=encoded.dtype, + device=encoded.device, + ) + idx = torch.full( + (batch_size,), + self._inference_prompt_index, + dtype=torch.long, + device=encoded.device, + ) + prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) + return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) + def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): - dataset = LhotseSpeechToTextBpeDatasetWithPrompt(tokenizer=self.tokenizer, cfg=config) - logging.info("Setting up Lhotse dataset with prompt support") + # Use index-based dataset - returns prompt indices instead of full tensors + # The model creates prompt tensors after encoding, guaranteeing no size mismatch + dataset_config = ( + OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config) + ) + if hasattr(self, 'cfg') and 'encoder' in self.cfg: + dataset_config['encoder'] = ( + OmegaConf.to_container(self.cfg.encoder, resolve=True) + if isinstance(self.cfg.encoder, DictConfig) + else dict(self.cfg.encoder) + ) + dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex(tokenizer=self.tokenizer, cfg=dataset_config) + logging.info("Setting up Lhotse dataset with prompt index support (model creates prompt tensors)") else: dataset = LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer) - logging.info("Setting up Lhotse dataset without prompt support") return get_lhotse_dataloader_from_config( config, global_rank=self.global_rank, @@ -317,32 +385,25 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT Returns: The model's outputs that are processed by `_transcribe_output_processing()`. """ - # Handling DataLoader batch - should be a tuple of tensors - # Expected structure: (audio, audio_lens, tokens, token_lens, prompt_targets) - # For transcription, we may only have (audio, audio_lens) or (audio, audio_lens, ..., prompt_targets) audio, audio_lens = batch[0], batch[1] + prompt, prompt_indices = None, None + if len(batch) >= 5: - # Prompt provided by the dataloader (one-hot vectors) - prompt = batch[4] # This should be the prompt_targets from dataset - else: - # Prompt to be built dynamically. - prompt = None + prompt_or_indices = batch[4] + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices batch_size = audio.shape[0] - if prompt is None: - # The dataloader provided only audio + audio_lens, so we need to construct - # the prompt as one-hot vectors dynamically using TranscribeConfig. + if prompt is None and prompt_indices is None: target_lang = trcfg.target_lang - - # Get prompt dictionary and num_prompts from model config prompt_dict = self.cfg.model_defaults.get('prompt_dictionary') - num_prompts = self.cfg.model_defaults.get('num_prompts', 128) if not prompt_dict: raise ValueError("Prompt dictionary is empty. 
Cannot create dynamic prompts.") - # Get the prompt index for the target language if target_lang not in prompt_dict: available_keys = list(prompt_dict.keys()) raise ValueError( @@ -350,26 +411,11 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT ) prompt_id = prompt_dict[target_lang] + prompt_indices = torch.full((batch_size,), prompt_id, dtype=torch.long, device=audio.device) - # Preprocess audio to get the actual feature dimensions (like streaming does) - processed_signal, processed_signal_length = self.preprocessor(input_signal=audio, length=audio_lens) - - # Calculate exact hidden length using the same approach as streaming - time_length = processed_signal.shape[2] # Feature time dimension - subsampling_factor = self.cfg.get('subsampling_factor', 8) - hidden_length = math.ceil(time_length / subsampling_factor) - - # Create one-hot prompt tensor: (batch_size, time_steps, num_prompts) - prompt = torch.zeros(batch_size, hidden_length, num_prompts, dtype=torch.float32, device=audio.device) - prompt[:, :, prompt_id] = 1.0 # Set the target language prompt to 1 - - # Now call forward with preprocessed signal and prompt - encoded, encoded_len = self.forward( - processed_signal=processed_signal, processed_signal_length=processed_signal_length, prompt=prompt - ) - else: - # Prompt was provided, use normal forward path - encoded, encoded_len = self.forward(input_signal=audio, input_signal_length=audio_lens, prompt=prompt) + encoded, encoded_len = self.forward( + input_signal=audio, input_signal_length=audio_lens, prompt=prompt, prompt_indices=prompt_indices + ) # Prepare output dictionary based on decoder type if self.cur_decoder == "rnnt": @@ -458,7 +504,7 @@ def transcribe( # Create transcription config if not provided if override_config is None: # Extract target_lang from prompt or use default - target_lang = prompt.get('target_lang', 'en-US') + target_lang = prompt.get('target_lang', 'auto') prompt_field = prompt.get('prompt_field', 'target_lang') trcfg = HybridRNNTCTCPromptTranscribeConfig( @@ -507,7 +553,8 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), - "prompt": NeuralType(('B', 'T', 'D'), LabelsType()), + "prompt": NeuralType(('B', 'T', 'D'), LabelsType(), optional=True), + "prompt_indices": NeuralType(tuple('B'), LabelsType(), optional=True), } @property @@ -525,19 +572,12 @@ def forward( processed_signal=None, processed_signal_length=None, prompt=None, + prompt_indices=None, ): """ Forward pass of the model. Note that for RNNT Models, the forward pass of the model is a 3 step process, and this method only performs the first step - forward of the acoustic model. - Please refer to the `training_step` in order to see the full `forward` step for training - which - performs the forward of the acoustic model, the prediction network and then the joint network. - Finally, it computes the loss and possibly compute the detokenized text via the `decoding` step. - - Please refer to the `validation_step` in order to see the full `forward` step for inference - which - performs the forward of the acoustic model, the prediction network and then the joint network. - Finally, it computes the decoded tokens via the `decoding` step and possibly compute the batch metrics. 
- Args: input_signal: Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps, with 1 second of audio represented as @@ -548,13 +588,16 @@ def forward( of shape (B, D, T) that has undergone processing via some DALI preprocessor. processed_signal_length: Vector of length B, that contains the individual lengths of the processed audio sequences. - prompt: Tensor that represents the prompt embeddings, - of shape (B, T, D) where D is the number of supported prompts. - Used for prompt-conditioned encoding via concatenation with acoustic features. + prompt: (backward-compatible) Pre-built one-hot prompt tensor of shape [B, T, D]. + If provided, used directly (trimmed to encoder length if needed). + prompt_indices: Tensor of shape [B] containing language ID indices per sample. + The model creates the prompt tensor after encoding using the actual + encoder output length, guaranteeing no size mismatch. + Ignored if ``prompt`` is also provided. Returns: A tuple of 2 elements - - 1) The log probabilities tensor of shape [B, T, D]. + 1) The encoded tensor of shape [B, D, T]. 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. """ has_input_signal = input_signal is not None and input_signal_length is not None @@ -579,17 +622,23 @@ def forward( encoded = torch.transpose(encoded, 1, 2) # B * D * T -> B * T * D if self.concat: - if prompt.shape[1] > encoded.shape[1]: - prompt = prompt[:, : encoded.shape[1], :] - out_dtype = encoded.dtype # this is dtype, which the decoder previously got from encoder + if prompt is not None: + # Backward-compatible path: caller provided a pre-built [B, T, D] one-hot tensor + if prompt.shape[1] > encoded.shape[1]: + prompt = prompt[:, : encoded.shape[1], :] + elif prompt_indices is not None: + # New path: build one-hot from per-sample language ID indices + batch_size = encoded.shape[0] + time_steps = encoded.shape[1] + num_prompts = self.num_prompts + prompt = torch.zeros(batch_size, time_steps, num_prompts, dtype=encoded.dtype, device=encoded.device) + prompt.scatter_(2, prompt_indices.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + else: + raise ValueError("Either prompt or prompt_indices must be provided when concat mode is enabled.") - # Concatenate encoded states with prompt + out_dtype = encoded.dtype concat_enc_states = torch.cat([encoded, prompt], dim=-1) - - # Apply joint projection - encoded = self.prompt_kernel(concat_enc_states).to( - out_dtype - ) # cast: unexpectedly without cast dtype is different from out_dtype + encoded = self.prompt_kernel(concat_enc_states).to(out_dtype) encoded = torch.transpose(encoded, 1, 2) # B * T * D -> B * D * T return encoded, encoded_len @@ -602,13 +651,21 @@ def training_step(self, batch, batch_nb): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len, prompt = batch + signal, signal_len, transcript, transcript_len, prompt_or_indices = batch + + # Detect whether batch[4] is old-style prompt [B, T, D] or new-style indices [B] + prompt, prompt_indices = None, None + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices - # forward() only performs encoder forward if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, 
input_signal_length=signal_len, prompt=prompt) + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_indices=prompt_indices + ) del signal # During training, loss must be computed, so decoder forward is necessary @@ -714,13 +771,20 @@ def training_step(self, batch, batch_nb): return {'loss': loss_value} def predict_step(self, batch, batch_idx, dataloader_idx=0): - signal, signal_len, transcript, transcript_len, prompt = batch + signal, signal_len, transcript, transcript_len, prompt_or_indices = batch + + prompt, prompt_indices = None, None + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices - # forward() only performs encoder forward if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt) + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_indices=prompt_indices + ) del signal if self.cur_decoder == 'rnnt': @@ -744,13 +808,20 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len, prompt = batch + signal, signal_len, transcript, transcript_len, prompt_or_indices = batch + + prompt, prompt_indices = None, None + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices - # forward() only performs encoder forward if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt) + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_indices=prompt_indices + ) del signal tensorboard_logs = {} diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py new file mode 100644 index 000000000000..237fe8ba30b2 --- /dev/null +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -0,0 +1,719 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
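+
+"""
+RNNT-only (no auxiliary CTC head) variant of the prompt-conditioned, cache-aware
+streaming transducer model: a language-ID one-hot vector is concatenated to the
+encoder output and projected back to the encoder dimension. See
+``EncDecRNNTBPEModelWithPrompt`` below.
+"""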
+ +import os +from dataclasses import dataclass +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset +from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex +from nemo.collections.asr.metrics.wer import WER +from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel +from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder +from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.classes.mixins import AccessMixin +from nemo.core.neural_types import ( + AcousticEncodedRepresentation, + AudioSignal, + LabelsType, + LengthsType, + NeuralType, + SpectrogramType, +) +from nemo.utils import logging, model_utils + + +@dataclass +class RNNTPromptTranscribeConfig(TranscribeConfig): + """Transcription configuration for RNNT BPE Model with Prompt conditioning.""" + + target_lang: str = "auto" + prompt_field: str = "target_lang" + + +class EncDecRNNTBPEModelWithPrompt(EncDecRNNTBPEModel, ASRTranscriptionMixin): + """Encoder-decoder RNNT model with subword tokenization and prompt conditioning. + + This is the RNNT-only variant (no auxiliary CTC head) of the prompt-aware + cache-aware streaming model. The prompt mechanism concatenates a language-ID + one-hot vector to the encoder output and projects back to the original + dimension, allowing the decoder to condition on the target language. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + if not isinstance(cfg, DictConfig): + cfg = OmegaConf.create(cfg) + + self._setup_tokenizer(cfg.tokenizer) + + vocabulary = self.tokenizer.tokenizer.get_vocab() + + with open_dict(cfg): + cfg.labels = ListConfig(list(vocabulary)) + + with open_dict(cfg.decoder): + cfg.decoder.vocab_size = len(vocabulary) + + with open_dict(cfg.joint): + cfg.joint.num_classes = len(vocabulary) + cfg.joint.vocabulary = ListConfig(list(vocabulary)) + cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden + cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden + + with open_dict(cfg): + cfg.num_prompts = cfg.model_defaults.get('num_prompts', 128) + + if 'prompt_dictionary' not in cfg.model_defaults: + logging.warning( + "No prompt_dictionary in config; using empty dict " "(expected during checkpoint restoration)." 
+ ) + cfg.model_defaults.prompt_dictionary = {} + + self.subsampling_factor = cfg.get('subsampling_factor', 8) + + super().__init__(cfg=cfg, trainer=trainer) + + self.concat = False + + if self.cfg.model_defaults.get('initialize_prompt_feature', False): + self.initialize_prompt_feature() + + @classmethod + def restore_from( + cls, + restore_path, + override_config_path=None, + map_location=None, + strict=True, + return_config=False, + save_restore_connector=None, + trainer=None, + validate_access_integrity=True, + ): + """Delegate to base EncDecRNNTBPEModel to avoid subclass substitution. + + NeMo's from_config_dict checks issubclass(cls, checkpoint_target_cls) + and, when True, replaces the checkpoint class with cls. Because this + class is a direct subclass of the checkpoint's target class + (EncDecRNNTBPEModel), the substitution would try to fully instantiate + EncDecRNNTBPEModelWithPrompt with the checkpoint config — which lacks + prompt_dictionary and hangs. Delegating to the parent class keeps + cls == EncDecRNNTBPEModel so the checkpoint is loaded with its own + class, matching the behaviour that naturally occurs for hybrid models. + """ + return EncDecRNNTBPEModel.restore_from( + restore_path=restore_path, + override_config_path=override_config_path, + map_location=map_location, + strict=strict, + return_config=return_config, + save_restore_connector=save_restore_connector, + trainer=trainer, + validate_access_integrity=validate_access_integrity, + ) + + def initialize_prompt_feature(self): + """Initialize model components for prompt feature via concatenation.""" + logging.info("Model with prompt feature has been initialized (RNNT-only)") + + self.concat = True + self.num_prompts = self.cfg.get('num_prompts', 128) + + proj_in_size = self.num_prompts + self._cfg.model_defaults.enc_hidden + proj_out_size = self._cfg.model_defaults.enc_hidden + + self.prompt_kernel = torch.nn.Sequential( + torch.nn.Linear(proj_in_size, proj_out_size * 2), + torch.nn.ReLU(), + torch.nn.Linear(proj_out_size * 2, proj_out_size), + ) + + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, + ) + + self.wer = WER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self.cfg.get('use_cer', False), + log_prediction=self.cfg.get('log_prediction', True), + dist_sync_on_step=True, + ) + + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # Streaming inference with language-ID prompt + def set_inference_prompt(self, target_lang: str): + """ + Set the language prompt for streaming inference. + + Call this before ``conformer_stream_step`` to condition decoding on + a specific language, following the same pattern as + ``change_decoding_strategy``. + + Args: + target_lang: A key from the model's ``prompt_dictionary`` + (e.g. ``"en-US"``, ``"auto"``). + """ + prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) + if target_lang not in prompt_dict: + available = list(prompt_dict.keys()) + raise ValueError( + f"Unknown target language '{target_lang}'. " + f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" + ) + self._inference_prompt_index = prompt_dict[target_lang] + logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") + + def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: + """ + Inject the language-ID prompt into encoder output during streaming. 
+ + ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. + Returns the same shape after prompt concatenation + projection. + """ + if not self.concat or not hasattr(self, '_inference_prompt_index'): + return encoded + + encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) + + batch_size, time_steps, _ = encoded.shape + prompt = torch.zeros( + batch_size, + time_steps, + self.num_prompts, + dtype=encoded.dtype, + device=encoded.device, + ) + idx = torch.full( + (batch_size,), + self._inference_prompt_index, + dtype=torch.long, + device=encoded.device, + ) + prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) + return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) + + # Data loading + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if config.get("use_lhotse"): + if config.get('initialize_prompt_feature', True): + dataset_config = ( + OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config) + ) + if hasattr(self, 'cfg') and 'encoder' in self.cfg: + dataset_config['encoder'] = ( + OmegaConf.to_container(self.cfg.encoder, resolve=True) + if isinstance(self.cfg.encoder, DictConfig) + else dict(self.cfg.encoder) + ) + dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex(tokenizer=self.tokenizer, cfg=dataset_config) + logging.info( + "Setting up Lhotse dataset with prompt index support (RNNT-only model creates prompt tensors)" + ) + else: + dataset = LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer) + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=dataset, + tokenizer=self.tokenizer, + ) + + dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( + config=config, + local_rank=self.local_rank, + global_rank=self.global_rank, + world_size=self.world_size, + tokenizer=self.tokenizer, + preprocessor_cfg=self.cfg.get("preprocessor", None), + ) + + if dataset is None: + return None + + if isinstance(dataset, AudioToBPEDALIDataset): + return dataset + + shuffle = config['shuffle'] + if isinstance(dataset, torch.utils.data.IterableDataset): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + collate_fn = dataset.datasets[0].collate_fn + else: + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + if 'manifest_filepath' in config: + manifest_filepath = config['manifest_filepath'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + target_lang = config.get('target_lang', 'en-US') + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.joint.vocabulary, + 'batch_size': batch_size, + 'trim_silence': False, + 'shuffle': False, + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, 
+ 'use_lhotse': config.get('use_lhotse', True), + 'use_bucketing': False, + 'drop_last': False, + 'prompt_field': config.get('prompt_field', 'target_lang'), + 'initialize_prompt_feature': True, + 'prompt_dictionary': self.cfg.model_defaults.get('prompt_dictionary'), + 'num_prompts': self.cfg.model_defaults.get('num_prompts', 128), + 'subsampling_factor': self.cfg.get('subsampling_factor', 8), + 'default_lang': target_lang, + 'window_stride': self.cfg.preprocessor.get('window_stride', 0.01), + } + + if config.get("augmentor"): + dl_config['augmentor'] = config.get("augmentor") + + return self._setup_dataloader_from_config(config=DictConfig(dl_config)) + + def setup_training_data(self, train_data_config: Optional[DictConfig]): + self._update_dataset_config(dataset_name='train', config=train_data_config) + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + self._update_dataset_config(dataset_name='validation', config=val_data_config) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + self._update_dataset_config(dataset_name='test', config=test_data_config) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "prompt_indices": NeuralType(tuple('B'), LabelsType()), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + prompt_indices=None, + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) is False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." 
+ ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + encoded = torch.transpose(encoded, 1, 2) # B x D x T -> B x T x D + + if self.concat: + if prompt_indices is None: + raise ValueError("prompt_indices must be provided when concat mode is enabled.") + + batch_size = encoded.shape[0] + time_steps = encoded.shape[1] + num_prompts = self.num_prompts + + prompt = torch.zeros(batch_size, time_steps, num_prompts, dtype=encoded.dtype, device=encoded.device) + prompt.scatter_(2, prompt_indices.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + concat_enc_states = torch.cat([encoded, prompt], dim=-1) + encoded = self.prompt_kernel(concat_enc_states).to(out_dtype) + + encoded = torch.transpose(encoded, 1, 2) # B x T x D -> B x D x T + return encoded, encoded_len + + def training_step(self, batch, batch_nb): + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + signal, signal_len, transcript, transcript_len, prompt_indices = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices + ) + del signal + + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_nb + + if not self.joint.fuse_loss_wer: + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + loss_value = self.add_auxiliary_losses(loss_value) + + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if (sample_id + 1) % log_every_n_steps == 0: + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + _, scores, words = self.wer.compute() + self.wer.reset() + tensorboard_logs.update({'training_batch_wer': scores.float() / words}) + + else: + if (sample_id + 1) % log_every_n_steps == 0: + compute_wer = True + else: + compute_wer = False + + loss_value, wer, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=compute_wer, + ) + loss_value = self.add_auxiliary_losses(loss_value) + + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, 
dtype=torch.float32), + } + + if compute_wer: + tensorboard_logs.update({'training_batch_wer': wer}) + + self.log_dict(tensorboard_logs) + + if self._optim_normalize_joint_txu: + self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()] + + return {'loss': loss_value} + + def validation_pass(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len, prompt_indices = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices + ) + del signal + + tensorboard_logs = {} + + if not self.joint.fuse_loss_wer: + if self.compute_eval_loss: + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + tensorboard_logs['val_loss'] = loss_value + + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + else: + compute_wer = True + + if self.compute_eval_loss: + decoded, target_len, states = self.decoder(targets=transcript, target_length=transcript_len) + else: + decoded = None + target_len = transcript_len + + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + ) + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return tensorboard_logs + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + tensorboard_logs = self.validation_pass(batch, batch_idx, dataloader_idx) + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(tensorboard_logs) + else: + self.validation_step_outputs.append(tensorboard_logs) + return tensorboard_logs + + def test_step(self, batch, batch_idx, dataloader_idx=0): + logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) + test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(test_logs) + else: + self.test_step_outputs.append(test_logs) + return test_logs + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len, prompt_indices = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices + ) + 
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        signal, signal_len, transcript, transcript_len, prompt_indices = batch
+
+        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
+            encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len)
+        else:
+            encoded, encoded_len = self.forward(
+                input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices
+            )
+        del signal
+
+        best_hyp = self.decoding.rnnt_decoder_predictions_tensor(
+            encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
+        )
+
+        batch_size = signal_len.shape[0]
+        sample_id = torch.arange(batch_idx * batch_size, (batch_idx + 1) * batch_size).cpu().detach().numpy()
+
+        return list(zip(sample_id, best_hyp))
+
+    def _transcribe_forward(self, batch, trcfg: RNNTPromptTranscribeConfig) -> dict:
+        audio, audio_lens = batch[0], batch[1]
+        if len(batch) >= 5:
+            prompt_indices = batch[4]
+        else:
+            prompt_indices = None
+
+        batch_size = audio.shape[0]
+
+        if prompt_indices is None:
+            target_lang = trcfg.target_lang
+            prompt_dict = self.cfg.model_defaults.get('prompt_dictionary')
+
+            if not prompt_dict:
+                raise ValueError("Prompt dictionary is empty. Cannot create dynamic prompts.")
+
+            if target_lang not in prompt_dict:
+                available_keys = list(prompt_dict.keys())
+                raise ValueError(
+                    f"Unknown target language: '{target_lang}'. "
+                    f"Available languages: {available_keys[:10]}{'...' if len(available_keys) > 10 else ''}"
+                )
+
+            prompt_id = prompt_dict[target_lang]
+            prompt_indices = torch.full((batch_size,), prompt_id, dtype=torch.long, device=audio.device)
+
+        encoded, encoded_len = self.forward(
+            input_signal=audio, input_signal_length=audio_lens, prompt_indices=prompt_indices
+        )
+
+        return dict(encoded=encoded, encoded_len=encoded_len)
+
+    @torch.no_grad()
+    def transcribe(
+        self,
+        audio: List[str],
+        batch_size: int = 4,
+        return_hypotheses: bool = False,
+        partial_hypothesis: Optional[List['Hypothesis']] = None,
+        num_workers: int = 0,
+        channel_selector: Optional[ChannelSelectorType] = None,
+        augmentor: DictConfig = None,
+        verbose: bool = True,
+        timestamps: Optional[bool] = None,
+        override_config: Optional[RNNTPromptTranscribeConfig] = None,
+        **prompt,
+    ) -> TranscriptionReturnType:
+        if timestamps is not None:
+            decoding_cfg = self.cfg.decoding
+            if timestamps or (override_config is not None and override_config.timestamps):
+                return_hypotheses = True
+                with open_dict(decoding_cfg):
+                    decoding_cfg.compute_timestamps = True
+                    decoding_cfg.preserve_alignments = True
+            else:
+                with open_dict(decoding_cfg):
+                    decoding_cfg.compute_timestamps = False
+                    decoding_cfg.preserve_alignments = False
+            self.change_decoding_strategy(decoding_cfg, verbose=False)
+
+        if override_config is None:
+            target_lang = prompt.get('target_lang', 'auto')
+            prompt_field = prompt.get('prompt_field', 'target_lang')
+
+            trcfg = RNNTPromptTranscribeConfig(
+                batch_size=batch_size,
+                return_hypotheses=return_hypotheses,
+                num_workers=num_workers,
+                channel_selector=channel_selector,
+                augmentor=augmentor,
+                verbose=verbose,
+                timestamps=timestamps,
+                target_lang=target_lang,
+                prompt_field=prompt_field,
+            )
+        else:
+            if not isinstance(override_config, RNNTPromptTranscribeConfig):
+                raise ValueError(
+                    f"override_config must be of type {RNNTPromptTranscribeConfig}, "
+                    f"but got {type(override_config)}"
+                )
+            trcfg = override_config
+
+        return super().transcribe(
+            audio=audio,
+            batch_size=batch_size,
+            return_hypotheses=return_hypotheses,
+            partial_hypothesis=partial_hypothesis,
+            num_workers=num_workers,
+            channel_selector=channel_selector,
+            augmentor=augmentor,
+            verbose=verbose,
+            timestamps=timestamps,
+            override_config=trcfg,
+        )
+
+    @classmethod
+    def get_transcribe_config(cls) -> RNNTPromptTranscribeConfig:
+        return RNNTPromptTranscribeConfig()
+
+    @classmethod
+    def list_available_models(cls) -> List[PretrainedModelInfo]:
+        return None
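For reference, here is a minimal usage sketch of the prompt-conditioned `transcribe()` API defined above. The checkpoint and audio paths are placeholders, and `restore_from` is the standard NeMo entry point for loading a local `.nemo` file:

```python
# Minimal sketch (not part of the diff): transcribing with an explicit language
# prompt. The .nemo and .wav paths are placeholders.
from nemo.collections.asr.models import EncDecRNNTBPEModelWithPrompt

model = EncDecRNNTBPEModelWithPrompt.restore_from("/path/to/prompt_model.nemo")

# target_lang must be a key of the model's prompt_dictionary; it defaults to
# "auto", which asks the model to pick the language prompt itself.
transcripts = model.transcribe(
    audio=["/path/to/audio.wav"], batch_size=4, target_lang="en-US"
)
print(transcripts[0])
```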
diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py
index af973be3cc4c..a83467fdbb96 100644
--- a/nemo/collections/asr/parts/mixins/mixins.py
+++ b/nemo/collections/asr/parts/mixins/mixins.py
@@ -588,6 +588,11 @@ def change_subsampling_conv_chunking_factor(
         if update_config:
             with open_dict(self.cfg):
                 self.cfg.encoder.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
+
+    def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor:
+        """Hook for prompt-conditioned subclasses to inject a language prompt
+        into the encoder output. Default: no-op."""
+        return encoded
 
     def conformer_stream_step(
         self,
@@ -661,6 +666,8 @@ def conformer_stream_step(
             bypass_pre_encode=bypass_pre_encode,
         )
 
+        encoded = self._apply_prompt_to_encoded(encoded)
+
         if isinstance(self, asr_models.EncDecCTCModel) or (
             isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc"
         ):
diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
index c9a0989d1022..be3ffcb25f41 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -329,6 +329,10 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu
             punct_pattern = '|'.join([re.escape(p) for p in self.supported_punctuation])
             self.space_before_punct_pattern = re.compile(r'(\s)(' + punct_pattern + ')')
 
+        self.strip_lang_tags = self.cfg.get('strip_lang_tags', False)
+        if self.strip_lang_tags:
+            self.lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>')
+
         # initialize confidence-related fields
         self._init_confidence(self.cfg.get('confidence_cfg', None))
@@ -681,6 +685,12 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu
         # Update the joint fused batch size or disable it entirely if needed.
         self.update_joint_fused_batch_size()
+
+    def set_strip_lang_tags(self, strip_lang_tags: bool):
+        self.strip_lang_tags = strip_lang_tags
+        if strip_lang_tags:
+            logging.info("Setting strip_lang_tags=True and compiling lang_tag_pattern for locale tags such as <en-US>.")
+            self.lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>')
 
     @abstractproperty
     def tokenizer_type(self):
@@ -949,10 +958,13 @@ def decode_ids_to_str(self, tokens: List[int]) -> str:
     def decode_tokens_to_str_with_strip_punctuation(self, tokens: List[int]) -> str:
         """
         Decodes a list of tokens to a string and removes a space before supported punctuation marks.
+        Optionally strips language-ID tags (e.g. ``<en-US>``) when ``strip_lang_tags`` is enabled.
         """
         text = self.decode_ids_to_str(tokens)
         if self.supported_punctuation:
             text = self.space_before_punct_pattern.sub(r'\2', text)
+        if self.strip_lang_tags:
+            text = self.lang_tag_pattern.sub('', text).strip()
         return text
 
     def update_joint_fused_batch_size(self):
@@ -1855,6 +1867,10 @@ class RNNTDecodingConfig:
     # config for multiblank decoding.
     big_blank_durations: Optional[List[int]] = field(default_factory=list)
 
+    # Strip language-ID tags (e.g. <en-US>) from decoded output.
+    # Enable for prompt-conditioned models that emit locale tags in their output.
+    strip_lang_tags: bool = False
+
 
 @dataclass
 class RNNTBPEDecodingConfig(RNNTDecodingConfig):
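To make the new `strip_lang_tags` option concrete, here is a self-contained sketch of the regex behavior it enables; the pattern is copied from the diff, and the sample string is illustrative:

```python
import re

# Same pattern compiled in rnnt_decoding.py: optional whitespace followed by a
# locale tag such as <en-US> or <en-GB>.
lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>')

text = "<en-US> hello there <en-GB>"
# sub() removes every tag (and any whitespace directly before it); strip()
# drops the leading space left by the first tag.
print(lang_tag_pattern.sub('', text).strip())  # -> "hello there"
```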