diff --git a/examples/asr/conf/ssl/nest/nest_fast-conformer-v2-xlarge.yaml b/examples/asr/conf/ssl/nest/nest_fast-conformer-v2-xlarge.yaml new file mode 100644 index 000000000000..8bc55043d156 --- /dev/null +++ b/examples/asr/conf/ssl/nest/nest_fast-conformer-v2-xlarge.yaml @@ -0,0 +1,249 @@ +# This config contains the default values for self-supervised pre-training of a FastConformer model +# +# Here are the recommended configs for different variants of FastConformer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+----------------+--------------+------------+---------+ +# | Model | d_model | n_heads | n_layers |conv_kernel_size| weight_decay | xscaling | use_bias| +# +==============+=========+========+===========+================+==============+============+=========+ +# | Small (14M) | 176 | 4 | 16 | 9 | 0.0 | True | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | Medium (32M) | 256 | 4 | 16 | 9 | 1e-3 | True | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | Large (120M) | 512 | 8 | 17 | 9 | 1e-3 | True | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | XLarge (616M)| 1024 | 8 | 24 | 9 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | XXLarge(1.2B)| 1024 | 8 | 42 | 5 | 1e-3 | False | False | +# +--------------------------------------------------------------+--------------+------------+---------+ + + +name: "SSL-NEST-FastConformer-XL" + +model: + sample_rate: 16000 + num_classes: 8192 + num_books: 1 + code_dim: 16 + squeeze_single: false + mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + + train_ds: + input_cfg: null + manifest_filepath: null # path to training manifest, can be a string or list of strings + noise_manifest: null # the manifest for noise data, can be a string or list of strings + sample_rate: ${model.sample_rate} + batch_size: null + batch_duration: null + shuffle: true + num_workers: 8 + pin_memory: true + max_duration: 60.0 + min_duration: 1.0 + drop_last: true + skip_missing_manifest_entries: true + defer_setup: true + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 # prob of activating the augmentation + noise_ratio: 0.5 # prob of applying noise aug, otherwise apply speech augmentation + min_r_speech: 10.0 # min SNR when applying speech augmentation + max_r_speech: 20.0 # max SNR when applying speech augmentation + min_r_noise: -5.0 # min SNR when applying noise augmentation + max_r_noise: 20.0 # max SNR when applying noise augmentation + min_mix_rate: 0.3 # min ratio of the input audio that would be augmented + max_mix_rate: 0.6 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + validation_ds: + manifest_filepath: null + noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + 
pin_memory: true + use_start_end_token: false + max_duration: 60.0 + min_duration: 1.0 + defer_setup: true + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 + noise_ratio: 0.5 + min_r_speech: 10.0 # min SNR when applying speech augmentation + max_r_speech: 20.0 # max SNR when applying speech augmentation + min_r_noise: -5.0 # min SNR when applying noise augmentation + max_r_noise: 20.0 # max SNR when applying noise augmentation + min_mix_rate: 0.3 # min ratio of the input audio that would be augmented + max_mix_rate: 0.6 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 8 + pad_value: 0.0 + + # spec_augment is not actually used, just to avoid init error + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 0 # set to zero to disable it + time_masks: 0 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + masking: + _target_: nemo.collections.asr.modules.RandomBlockMasking + block_size: 40 # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + mask_prob: 0.01 # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the overall masked proportion + feat_in: ${model.preprocessor.features} + freeze: true + allow_overlap: true + + quantizer: + _target_: nemo.collections.asr.modules.RandomProjectionVectorQuantizer + feat_in: ${model.preprocessor.features} + code_dim: ${model.code_dim} + num_books: ${model.num_books} + num_classes: ${model.num_classes} + dist_fn: "l2" # choices=["l2", "cosine"] + freeze: true + squeeze_single: ${model.squeeze_single} + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 24 + d_model: 1024 + use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: false # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's 
params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.MultiSoftmaxDecoder + feat_in: ${model.encoder.d_model} + num_classes: ${model.num_classes} + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + use_bias: true + + loss: + _target_: nemo.collections.asr.losses.MultiMLMLoss + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + mask_threshold: 0.8 + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 25000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 500000 # computed at runtime if not set + val_check_interval: 2500 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. 
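+  # Illustrative launch example (hypothetical paths and values, not part of the recommended recipe),
+  # using the pretraining script updated in this PR:
+  #   python examples/asr/speech_pretraining/masked_token_pred_pretrain.py \
+  #     --config-path=<dir containing this file> --config-name=nest_fast-conformer-v2-xlarge \
+  #     model.train_ds.manifest_filepath=<train_manifest.json> \
+  #     model.train_ds.noise_manifest=<noise_manifest.json> \
+  #     model.validation_ds.manifest_filepath=<val_manifest.json> \
+  #     trainer.max_steps=500000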
+ enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}' + + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/ssl/nest/nest_streaming_fast-conformer-v2-xlarge.yaml b/examples/asr/conf/ssl/nest/nest_streaming_fast-conformer-v2-xlarge.yaml new file mode 100644 index 000000000000..ad5702127058 --- /dev/null +++ b/examples/asr/conf/ssl/nest/nest_streaming_fast-conformer-v2-xlarge.yaml @@ -0,0 +1,249 @@ +# This config contains the default values for self-supervised pre-training of a FastConformer model +# +# Here are the recommended configs for different variants of FastConformer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+----------------+--------------+------------+---------+ +# | Model | d_model | n_heads | n_layers |conv_kernel_size| weight_decay | xscaling | use_bias| +# +==============+=========+========+===========+================+==============+============+=========+ +# | Small (14M) | 176 | 4 | 16 | 9 | 0.0 | True | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | Medium (32M) | 256 | 4 | 16 | 9 | 1e-3 | True | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | Large (120M) | 512 | 8 | 17 | 9 | 1e-3 | True | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | XLarge (616M)| 1024 | 8 | 24 | 9 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+----------------+--------------+------------+---------+ +# | XXLarge(1.2B)| 1024 | 8 | 42 | 5 | 1e-3 | False | False | +# +--------------------------------------------------------------+--------------+------------+---------+ + + +name: "SSL-NEST-FastConformer-XL" + +model: + sample_rate: 16000 + num_classes: 8192 + num_books: 1 + code_dim: 16 + squeeze_single: false + mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + + train_ds: + input_cfg: null + manifest_filepath: null # path to training manifest, can be a string or list of strings + noise_manifest: null # the manifest for noise data, can be a string or list of strings + sample_rate: ${model.sample_rate} + batch_size: null + batch_duration: null + shuffle: true + num_workers: 8 + pin_memory: true + max_duration: 60.0 + 
min_duration: 1.0 + drop_last: true + skip_missing_manifest_entries: true + defer_setup: true + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 # prob of activating the augmentation + noise_ratio: 0.5 # prob of applying noise aug, otherwise apply speech augmentation + min_r_speech: 10.0 # min SNR when applying speech augmentation + max_r_speech: 20.0 # max SNR when applying speech augmentation + min_r_noise: -5.0 # min SNR when applying noise augmentation + max_r_noise: 20.0 # max SNR when applying noise augmentation + min_mix_rate: 0.3 # min ratio of the input audio that would be augmented + max_mix_rate: 0.6 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + validation_ds: + manifest_filepath: null + noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + max_duration: 60.0 + min_duration: 1.0 + defer_setup: true + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 + noise_ratio: 0.5 + min_r_speech: 10.0 # min SNR when applying speech augmentation + max_r_speech: 20.0 # max SNR when applying speech augmentation + min_r_noise: -5.0 # min SNR when applying noise augmentation + max_r_noise: 20.0 # max SNR when applying noise augmentation + min_mix_rate: 0.3 # min ratio of the input audio that would be augmented + max_mix_rate: 0.6 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 8 + pad_value: 0.0 + + # spec_augment is not actually used, just to avoid init error + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 0 # set to zero to disable it + time_masks: 0 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + masking: + _target_: nemo.collections.asr.modules.RandomBlockMasking + block_size: 40 # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + mask_prob: 0.01 # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the overall masked proportion + feat_in: ${model.preprocessor.features} + freeze: true + allow_overlap: true + + quantizer: + _target_: nemo.collections.asr.modules.RandomProjectionVectorQuantizer + feat_in: ${model.preprocessor.features} + code_dim: ${model.code_dim} + num_books: ${model.num_books} + num_classes: ${model.num_classes} + dist_fn: "l2" # choices=["l2", "cosine"] + freeze: true + squeeze_single: ${model.squeeze_single} + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio 
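+  # Informal note (comments only): the quantizer above produces BEST-RQ-style targets by stacking
+  # `combine_time_steps` (= encoder.subsampling_factor = 8) mel frames -- about 8 x 10 ms = 80 ms of
+  # audio per target token -- and projecting them through a frozen random matrix into `num_books`
+  # codebooks with `num_classes` entries each. A hypothetical multi-codebook variant could be
+  # requested with overrides such as (illustrative values only):
+  #   model.num_books=4 model.code_dim=16 model.num_classes=8192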
+ + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 24 + d_model: 1024 + use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: true + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [70, 13] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: false # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: "causal" + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.MultiSoftmaxDecoder + feat_in: ${model.encoder.d_model} + num_classes: ${model.num_classes} + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + use_bias: true + + loss: + _target_: nemo.collections.asr.losses.MultiMLMLoss + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + mask_threshold: 0.8 + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 25000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 500000 # computed at runtime if not set + val_check_interval: 2500 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. 
+ enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}' + + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/ssl/nest/nest_streaming_transformer_xlarge.yaml b/examples/asr/conf/ssl/nest/nest_streaming_transformer_xlarge.yaml new file mode 100644 index 000000000000..2c6a49ccbdf4 --- /dev/null +++ b/examples/asr/conf/ssl/nest/nest_streaming_transformer_xlarge.yaml @@ -0,0 +1,249 @@ +# This config contains the default values for self-supervised pre-training of a streaming ASR Transformer model +# +# Here are the recommended configs for different variants of ASR Transformer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+--------------+------------+---------+ +# | Model | d_model | n_heads | n_layers | weight_decay | xscaling | use_bias| +# +==============+=========+========+===========+==============+============+=========+ +# | Large (119M) | 512 | 8 | 33 | 1e-3 | False | True | +# +--------------+---------+--------+-----------+--------------+------------+---------+ +# | XLarge (612M)| 1024 | 8 | 44 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+--------------+------------+---------+ +# | XXLarge(1.2B)| 1024 | 8 | 88 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+--------------+------------+---------+ + + +name: "SSL-NEST-Streaming-Transformer-XL" + +model: + sample_rate: 16000 + num_classes: 8192 + num_books: 1 + code_dim: 16 + squeeze_single: false + mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + + train_ds: + manifest_filepath: ??? # path to training manifest, can be a string or list of strings + noise_manifest: ??? 
# the manifest for noise data, can be a string or list of strings + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 60.0 + min_duration: 1.0 + drop_last: true + is_concat: false + concat_sampling_technique: temperature + concat_sampling_temperature: 1.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 # prob of activating the augmentation + noise_ratio: 0.5 # prob of applying noise aug, otherwise apply speech augmentation + min_r_speech: -5.0 # min energy ratio when applying speech augmentation + max_r_speech: 5.0 # max energy ratio when applying speech augmentation + min_r_noise: -5.0 # min energy ratio when applying noise augmentation + max_r_noise: 20.0 # max energy ratio when applying noise augmentation + min_mix_rate: 0.5 # min ratio of the input audio that would be augmented + max_mix_rate: 0.5 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + validation_ds: + manifest_filepath: ??? + noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + max_duration: 60.0 + min_duration: 1.0 + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 + noise_ratio: 0.5 + min_r_speech: -5.0 + max_r_speech: 5.0 + min_r_noise: -5.0 + max_r_noise: 20.0 + min_mix_rate: 0.5 + max_mix_rate: 0.5 + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 8 + pad_value: 0.0 + + # spec_augment is not actually used, just to avoid init error + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 0 # set to zero to disable it + time_masks: 0 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + masking: + _target_: nemo.collections.asr.modules.RandomBlockMasking + block_size: 40 # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + mask_prob: 0.01 # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the overall masked proportion + feat_in: ${model.preprocessor.features} + freeze: true + allow_overlap: true + + quantizer: + _target_: nemo.collections.asr.modules.RandomProjectionVectorQuantizer + feat_in: ${model.preprocessor.features} + code_dim: ${model.code_dim} + num_books: 
${model.num_books} + num_classes: ${model.num_classes} + dist_fn: "l2" # choices=["l2", "cosine"] + freeze: true + squeeze_single: ${model.squeeze_single} + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio + + encoder: + _target_: nemo.collections.asr.modules.ASRTransformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 44 + d_model: 1024 + use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: true + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + post_ln: false # apply normalization before or after self-attention + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one + # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + att_context_size: [70, 13] # [left, right] context sizes + att_context_style: chunked_limited # regular or chunked_limited + + xscaling: false # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + ### regularization + dropout: 0.1 # The dropout used in most of the Transformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.MultiSoftmaxDecoder + feat_in: ${model.encoder.d_model} + num_classes: ${model.num_classes} + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + use_bias: true + + loss: + _target_: nemo.collections.asr.losses.MultiMLMLoss + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + mask_threshold: 0.8 + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 25000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 1000000 # computed at runtime if not set + val_check_interval: 2500 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. 
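+  # Note on the optim/sched block above (assuming NeMo's standard Noam formula): with NoamAnnealing,
+  # `lr` is a scale factor rather than the actual learning rate; the peak rate is roughly
+  # lr * d_model^-0.5 * warmup_steps^-0.5 = 5.0 / (sqrt(1024) * sqrt(25000)) ~= 1e-3, reached at
+  # step 25000 and decayed as ~step^-0.5 afterwards, never going below min_lr.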
+ enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + save_nemo_on_train_end: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: ${exp_manager.name} + project: null diff --git a/examples/asr/conf/ssl/nest/nest_transformer.yaml b/examples/asr/conf/ssl/nest/nest_transformer.yaml new file mode 100644 index 000000000000..0f55e64b1af1 --- /dev/null +++ b/examples/asr/conf/ssl/nest/nest_transformer.yaml @@ -0,0 +1,245 @@ +# This config contains the default values for self-supervised pre-training of a FastConformer model +# +# Here are the recommended configs for different variants of FastConformer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+--------------+------------+---------+ +# | Model | d_model | n_heads | n_layers | weight_decay | xscaling | use_bias| +# +==============+=========+========+===========+==============+============+=========+ +# | Large (119M) | 512 | 8 | 33 | 1e-3 | False | True | +# +--------------+---------+--------+-----------+--------------+------------+---------+ +# | XLarge (603M)| 1024 | 8 | 44 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+--------------+------------+---------+ +# | XXLarge(1.2B)| 1024 | 8 | 88 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+--------------+------------+---------+ + + +name: "SSL-NEST-Transformer" + +model: + sample_rate: 16000 + num_classes: 8192 + num_books: 1 + code_dim: 16 + squeeze_single: false + mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + + train_ds: + manifest_filepath: ??? # path to training manifest, can be a string or list of strings + noise_manifest: ??? 
# the manifest for noise data, can be a string or list of strings + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 60.0 + min_duration: 1.0 + drop_last: true + is_concat: false + concat_sampling_technique: temperature + concat_sampling_temperature: 1.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 # prob of activating the augmentation + noise_ratio: 0.5 # prob of applying noise aug, otherwise apply speech augmentation + min_r_speech: -5.0 # min energy ratio when applying speech augmentation + max_r_speech: 5.0 # max energy ratio when applying speech augmentation + min_r_noise: -5.0 # min energy ratio when applying noise augmentation + max_r_noise: 20.0 # max energy ratio when applying noise augmentation + min_mix_rate: 0.5 # min ratio of the input audio that would be augmented + max_mix_rate: 0.5 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + validation_ds: + manifest_filepath: ??? + noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + max_duration: 60.0 + min_duration: 1.0 + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 + noise_ratio: 0.5 + min_r_speech: -5.0 + max_r_speech: 5.0 + min_r_noise: -5.0 + max_r_noise: 20.0 + min_mix_rate: 0.5 + max_mix_rate: 0.5 + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 8 + pad_value: 0.0 + + # spec_augment is not actually used, just to avoid init error + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 0 # set to zero to disable it + time_masks: 0 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + masking: + _target_: nemo.collections.asr.modules.RandomBlockMasking + block_size: 40 # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + mask_prob: 0.01 # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the overall masked proportion + feat_in: ${model.preprocessor.features} + freeze: true + allow_overlap: true + + quantizer: + _target_: nemo.collections.asr.modules.RandomProjectionVectorQuantizer + feat_in: ${model.preprocessor.features} + code_dim: ${model.code_dim} + 
num_books: ${model.num_books} + num_classes: ${model.num_classes} + dist_fn: "l2" # choices=["l2", "cosine"] + freeze: true + squeeze_single: ${model.squeeze_single} + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio + + encoder: + _target_: nemo.collections.asr.modules.ASRTransformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 33 + d_model: 512 + use_bias: true # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + post_ln: false # apply normalization before or after self-attention + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: false # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.MultiSoftmaxDecoder + feat_in: ${model.encoder.d_model} + num_classes: ${model.num_classes} + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + use_bias: true + + loss: + _target_: nemo.collections.asr.losses.MultiMLMLoss + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + mask_threshold: 0.8 + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 25000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 1000000 # computed at runtime if not set + val_check_interval: 2500 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. 
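+  # Presumed semantics of `mask_threshold` in the loss block above: the per-frame mask is averaged
+  # over each group of `combine_time_steps` input frames, and a token only contributes to the MLM
+  # loss when that average exceeds 0.8, i.e. roughly 80% of its underlying frames were masked.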
+ enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + save_nemo_on_train_end: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: ${exp_manager.name} + project: null \ No newline at end of file diff --git a/examples/asr/conf/ssl/nest/nest_transformer_xlarge.yaml b/examples/asr/conf/ssl/nest/nest_transformer_xlarge.yaml new file mode 100644 index 000000000000..24e198211a99 --- /dev/null +++ b/examples/asr/conf/ssl/nest/nest_transformer_xlarge.yaml @@ -0,0 +1,246 @@ +# This config contains the default values for self-supervised pre-training of a FastConformer model +# +# Here are the recommended configs for different variants of FastConformer, other parameters are the same as in this config file. +# +# +--------------+---------+---------+----------+--------------+------------+---------+ +# | Model | d_model | n_heads | n_layers | weight_decay | xscaling | use_bias| +# +==============+=========+========+===========+==============+============+=========+ +# | Large (119M) | 512 | 8 | 33 | 1e-3 | False | True | +# +--------------+---------+--------+-----------+--------------+------------+---------+ +# | XLarge (612M)| 1024 | 8 | 44 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+--------------+------------+---------+ +# | XXLarge(1.2B)| 1024 | 8 | 88 | 1e-3 | False | False | +# +--------------+---------+--------+-----------+--------------+------------+---------+ + + +name: "SSL-NEST-Transformer-XL" + +model: + sample_rate: 16000 + num_classes: 8192 + num_books: 1 + code_dim: 16 + squeeze_single: false + mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + + train_ds: + manifest_filepath: ??? # path to training manifest, can be a string or list of strings + noise_manifest: ??? 
# the manifest for noise data, can be a string or list of strings + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 60.0 + min_duration: 1.0 + drop_last: true + is_concat: false + concat_sampling_technique: temperature + concat_sampling_temperature: 1.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 # prob of activating the augmentation + noise_ratio: 0.5 # prob of applying noise aug, otherwise apply speech augmentation + min_r_speech: -5.0 # min energy ratio when applying speech augmentation + max_r_speech: 5.0 # max energy ratio when applying speech augmentation + min_r_noise: -5.0 # min energy ratio when applying noise augmentation + max_r_noise: 20.0 # max energy ratio when applying noise augmentation + min_mix_rate: 0.5 # min ratio of the input audio that would be augmented + max_mix_rate: 0.5 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + validation_ds: + manifest_filepath: ??? + noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + max_duration: 60.0 + min_duration: 1.0 + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 + noise_ratio: 0.5 + min_r_speech: -5.0 + max_r_speech: 5.0 + min_r_noise: -5.0 + max_r_noise: 20.0 + min_mix_rate: 0.5 + max_mix_rate: 0.5 + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 8 + pad_value: 0.0 + + # spec_augment is not actually used, just to avoid init error + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 0 # set to zero to disable it + time_masks: 0 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + masking: + _target_: nemo.collections.asr.modules.RandomBlockMasking + block_size: 40 # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + mask_prob: 0.01 # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the overall masked proportion + feat_in: ${model.preprocessor.features} + freeze: true + allow_overlap: true + + quantizer: + _target_: nemo.collections.asr.modules.RandomProjectionVectorQuantizer + feat_in: ${model.preprocessor.features} + code_dim: ${model.code_dim} + 
num_books: ${model.num_books} + num_classes: ${model.num_classes} + dist_fn: "l2" # choices=["l2", "cosine"] + freeze: true + squeeze_single: ${model.squeeze_single} + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio + + encoder: + _target_: nemo.collections.asr.modules.ASRTransformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 44 + d_model: 1024 + use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + post_ln: false # apply normalization before or after self-attention + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: false # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.MultiSoftmaxDecoder + feat_in: ${model.encoder.d_model} + num_classes: ${model.num_classes} + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + use_bias: true + + loss: + _target_: nemo.collections.asr.losses.MultiMLMLoss + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + mask_threshold: 0.8 + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 25000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 1000000 # computed at runtime if not set + val_check_interval: 2500 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. 
+ enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + use_distributed_sampler: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + save_nemo_on_train_end: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: ${exp_manager.name} + project: null \ No newline at end of file diff --git a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py index 5ab224d9d33a..d4447b96ba0f 100644 --- a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py +++ b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import lightning.pytorch as pl from omegaconf import OmegaConf @@ -20,7 +19,7 @@ from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager - +from nemo.utils.trainer_utils import resolve_trainer_cfg """ # Example of training a self-supervised denoising masksed token prediction model @@ -51,12 +50,12 @@ def main(cfg): logging.info(f"Hydra config: {OmegaConf.to_yaml(cfg)}") - trainer = pl.Trainer(**cfg.trainer) + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecDenoiseMaskedTokenPredModel(cfg=cfg.model, trainer=trainer) # Initialize the weights of the model from another model, if provided via config - asr_model.maybe_init_from_pretrained_checkpoint(cfg) + asr_model.maybe_init_from_pretrained_checkpoint(cfg, weights_only=False) trainer.fit(asr_model) diff --git a/nemo/collections/asr/data/ssl_dataset.py b/nemo/collections/asr/data/ssl_dataset.py index 2fc644f19703..115a691a6dbc 100644 --- a/nemo/collections/asr/data/ssl_dataset.py +++ b/nemo/collections/asr/data/ssl_dataset.py @@ -22,7 +22,9 @@ import numpy as np import torch +from lhotse.cut import Cut, CutSet, MixedCut, MultiCut from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors from omegaconf import DictConfig, ListConfig, open_dict from torch import Tensor @@ -38,6 +40,16 @@ @dataclass class AudioNoiseItem: + """ + A single audio noise item. 
+    Args:
+        sample_id: id of the sample
+        audio: the audio tensor
+        audio_len: length of the audio in samples
+        noise: the noise tensor
+        noise_len: length of the noise in samples
+    """
+
     sample_id: str | None = None
     audio: Union[Tensor, None] = None
     audio_len: Union[Tensor, None] = None
@@ -49,6 +61,16 @@ class AudioNoiseItem:
 
 @dataclass
 class AudioNoiseBatch:
+    """
+    A batch of collated audio noise items.
+    Args:
+        sample_id: list of sample ids in the batch
+        audio: the batched audio tensor
+        audio_len: lengths of the audios in the batch
+        noise: the batched noise tensor
+        noise_len: lengths of the noises in the batch
+    """
+
     sample_id: list | None = None
     audio: Union[Tensor, None] = None
     audio_len: Union[Tensor, None] = None
@@ -95,7 +117,18 @@ def _parse_manifest_item(line: str, manifest_file: str) -> Dict[str, Any]:
     return item
 
 
-def _audio_noise_collate_fn(batch: List[AudioNoiseItem], batch_augmentor: Any = None) -> AudioNoiseBatch:
+def _audio_noise_collate_fn(
+    batch: List[AudioNoiseItem], batch_augmentor: Any = None, return_noise: bool = False
+) -> AudioNoiseBatch:
+    """
+    Collate a list of AudioNoiseItem into a single AudioNoiseBatch.
+    Args:
+        batch: the list of AudioNoiseItem to collate
+        batch_augmentor: the batch augmentor to apply, if any
+        return_noise: whether to keep the noise tensors in the returned batch
+    Returns:
+        the collated AudioNoiseBatch
+    """
     audios = [x.audio for x in batch]
     audio_lengths = [x.audio_len for x in batch]
     max_audio_len = max(audio_lengths).item()
@@ -137,6 +170,10 @@ def _audio_noise_collate_fn(batch: List[AudioNoiseItem], batch_augmentor: Any =
         output.noisy_audio = output.audio + output.noise
         output.noisy_audio_len = output.audio_len
 
+    if not return_noise:
+        output.noise = None
+        output.noise_len = None
+
     return output
 
 
@@ -166,7 +203,7 @@ def load_noise_audio(
     pad_to_max: bool = True,
     min_white_noise_db: int = -90,
     max_white_noise_db: int = -46,
-    max_trial: int = 100,
+    max_trial: int = 1,
 ):
     """
     Load noise audio from the manifest item, and apply white noise if the loaded noise audio is empty.
@@ -230,7 +267,7 @@ def load_noise_audio(
     return noise, noise_len
 
 
-def sample_noise(noise_data: List[Dict], sample_rate: int, max_audio_len: int | None = None, max_trial: int = 20):
+def sample_noise(noise_data: List[Dict], sample_rate: int, max_audio_len: int | None = None, max_trial: int = 1):
     """
     Randomly sample noise audio from the noise manifest.
     Args:
@@ -293,6 +330,7 @@ def __init__(
         batch_augmentor: Any | None = None,
         min_audio_len_secs: float = 1.0,
         pad_audio_mode: str = 'repeat',
+        return_noise: bool = False,
         **kwargs,
     ):
         # add bos_id=0 to avoid empty text token
@@ -302,6 +340,7 @@ def __init__(
         self.noise_data = load_noise_manifest(noise_manifest)
         self.min_audio_len_secs = min_audio_len_secs
         self.pad_audio_mode = pad_audio_mode
+        self.return_noise = return_noise
 
     def __getitem__(self, index) -> AudioNoiseItem:
         sample = self.manifest_processor.collection[index]
@@ -336,7 +375,7 @@ def __getitem__(self, index) -> AudioNoiseItem:
         return item
 
     def _collate_fn(self, batch: List[AudioNoiseItem]) -> AudioNoiseBatch:
-        return _audio_noise_collate_fn(batch, self.batch_augmentor)
+        return _audio_noise_collate_fn(batch, self.batch_augmentor, self.return_noise)
 
 
 class TarredAudioNoiseDataset(audio_to_text.TarredAudioToCharDataset):
@@ -351,6 +390,7 @@ def __init__(
         batch_augmentor: Any | None = None,
         min_audio_len_secs: float = 1.0,
         pad_audio_mode: str = 'repeat',
+        return_noise: bool = False,
         **kwargs,
     ):
         """
@@ -359,6 +399,7 @@ def __init__(
            batch_augmentor: the batch augmentor
            min_audio_len_secs: the minimum audio length in seconds, audios shorter than this will be padded
            pad_audio_mode: the padding mode for audios shorter than min_audio_len_secs, either 'repeat' or 'zero'
+           return_noise: whether to return the noise in the output batch; default is False
            **kwargs: other arguments for TarredAudioToCharDataset
         """
@@ -368,6 +409,7 @@ def __init__(
         self.noise_data = load_noise_manifest(noise_manifest)
         self.min_audio_len_secs = min_audio_len_secs
         self.pad_audio_mode = pad_audio_mode
+        self.return_noise = return_noise
 
     def _build_sample(self, tup):
         """Builds the training sample by combining the data from the WebDataset with the manifest info."""
@@ -425,11 +467,96 @@ def _pad_audio(self, audio: Tensor) -> Tensor:
         return audio
 
     def _collate_fn(self, batch: List[AudioNoiseItem]) -> AudioNoiseBatch:
-        return _audio_noise_collate_fn(batch, self.batch_augmentor)
+        return _audio_noise_collate_fn(batch, self.batch_augmentor, self.return_noise)
+
+
+def maybe_convert_cuts_to_mono(cuts: CutSet) -> CutSet:
+    """
+    Convert the cuts to mono if they are not already mono; cuts that cannot be converted are skipped.
+    Args:
+        cuts: the cuts to convert
+    Returns:
+        the converted cuts
+    """
+    resolved_cuts = []
+    for cut in cuts:
+        try:
+            resolved_cuts.append(cut.move_to_memory())
+        except Exception:
+            if isinstance(cut, MixedCut):
+                cut.first_non_padding_cut.recording.sources[0].channel_ids = [0, 1]
+                cut.first_non_padding_cut = MultiCut.from_dict(cut.first_non_padding_cut.to_dict())
+            else:
+                cut.recording.sources[0].channel_ids = [0, 1]
+                cut = MultiCut.from_dict(cut.to_dict())
+            try:
+                resolved_cuts.append(cut.to_mono(mono_downmix=True))
+            except Exception as e:
+                logging.warning(f"Error converting cut to mono: {cut}, with exception: {e}. Skipping this cut.")
+                continue
+    resolved_cuts = CutSet(resolved_cuts)
+    return resolved_cuts
+
+
+def safe_load_and_convert_to_mono(cut: Cut) -> Tensor:
+    """
+    Load the audio from a cut and downmix it to mono.
+    Args:
+        cut: the cut to load
+    Returns:
+        the loaded mono audio, or None if loading failed
+    """
+    try:
+        audio = cut.load_audio()
+        if audio.ndim == 2:
+            audio = audio.mean(axis=0)
+        return audio
+    except Exception as e:
+        logging.warning(f"Error loading audio: {cut}, with exception: {e}. Skipping this cut.")
+        return None
+
+
+def safe_collate_audios(cuts: CutSet) -> tuple[Tensor, Tensor, CutSet]:
+    """
+    Collate the audios safely.
+ Args: + cuts: the cuts to collate + Returns: + the collated audios, audio lengths, and cuts + """ + loaded_audios = [] + loaded_audio_lens = [] + loaded_cuts = [] + for cut in cuts: + audio = safe_load_and_convert_to_mono(cut) + if audio is not None: + loaded_audios.append(audio) + loaded_audio_lens.append(audio.shape[0]) + loaded_cuts.append(cut) + + if len(loaded_audios) == 0: + return None, None, None + loaded_audios = collate_vectors(loaded_audios) + loaded_audio_lens = torch.tensor(loaded_audio_lens).long() + loaded_cuts = CutSet(loaded_cuts) + return loaded_audios, loaded_audio_lens, loaded_cuts class LhotseAudioNoiseDataset(torch.utils.data.Dataset): - def __init__(self, noise_manifest: str | None = None, batch_augmentor_cfg: DictConfig = None): + def __init__( + self, + cfg: DictConfig, + noise_manifest: Optional[Union[str, ListConfig]] = None, + batch_augmentor_cfg: DictConfig = None, + return_noise: bool = False, + ): + """ + Args: + cfg: the dataset config + noise_manifest: the noise manifest file or list of noise manifest files + batch_augmentor_cfg: the batch augmentor config + return_noise: whether to return the noise in output batch, default is False + """ super().__init__() if batch_augmentor_cfg: @@ -439,13 +569,23 @@ def __init__(self, noise_manifest: str | None = None, batch_augmentor_cfg: DictC self.batch_augmentor = batch_augmentor self.noise_data = load_noise_manifest(noise_manifest) - self.load_audio = AudioSamples(fault_tolerant=True) + self.load_audio = AudioSamples(fault_tolerant=True, use_batch_loader=True, mono_downmix=True) + self.return_noise = return_noise + self.cfg = cfg + + def __getitem__(self, cuts: CutSet) -> AudioNoiseBatch: + if self.cfg.get("use_ais_get_batch", False): + cuts = cuts.to_eager() + audios, audio_lens, cuts = self.load_audio(cuts) + else: + audios, audio_lens, cuts = safe_collate_audios(cuts) - def __getitem__(self, cuts): + if audios is None: + return None - audios, audio_lens, cuts = self.load_audio(cuts) + max_audio_len = audios.shape[1] if len(self.noise_data) > 0: - sampled_noises = [sample_noise(self.noise_data, cut.sampling_rate, cut.num_samples) for cut in cuts] + sampled_noises = [sample_noise(self.noise_data, self.cfg["sample_rate"], max_audio_len) for _ in cuts] sampled_noises, sampled_noises_lens = zip(*sampled_noises) sampled_noises = torch.stack(sampled_noises).float() sampled_noises_lens = torch.tensor(sampled_noises_lens).long() @@ -466,12 +606,23 @@ def __getitem__(self, cuts): output.noisy_audio = output.audio + output.noise output.noisy_audio_len = output.audio_len + if not self.return_noise: + output.noise = None + output.noise_len = None + return output def get_audio_noise_dataset( - config: Dict[str, Any], augmentor: Any = None, batch_augmentor: Any = None + config: DictConfig, augmentor: Any = None, batch_augmentor: Any = None, return_noise: bool = False ) -> AudioNoiseDataset: + """ + Args: + config: the dataset config + augmentor: the audio augmentor + batch_augmentor: the batch augmentor + return_noise: whether to return the noise in output batch, default is False + """ dataset = AudioNoiseDataset( noise_manifest=config.get('noise_manifest', None), batch_augmentor=batch_augmentor, @@ -484,13 +635,28 @@ def get_audio_noise_dataset( min_duration=config.get('min_duration', None), trim=config.get('trim_silence', False), channel_selector=config.get('channel_selector', None), + return_noise=return_noise, ) return dataset def get_concat_audio_noise_dataset( - config: Dict[str, Any], global_rank: int, world_size: int, 
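
The `safe_load_and_convert_to_mono` / `safe_collate_audios` pair above trades strictness for robustness: unreadable or multi-channel cuts are downmixed or skipped with a warning instead of failing the training step, and a batch in which nothing survives comes back as `None`. Below is a sketch of the same pattern on plain file paths, using `soundfile` purely for illustration; the NeMo code operates on lhotse `Cut` objects, and all names and paths here are hypothetical.

```python
import logging
from typing import List, Optional, Tuple

import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F


def safe_load_mono(path: str) -> Optional[np.ndarray]:
    try:
        audio, _sr = sf.read(path, dtype="float32")
        if audio.ndim == 2:  # (T, C) -> average channels down to mono
            audio = audio.mean(axis=1)
        return audio
    except Exception as e:  # one corrupt file should not kill the epoch
        logging.warning(f"Skipping {path}: {e}")
        return None


def safe_collate(paths: List[str]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    audios = [a for a in (safe_load_mono(p) for p in paths) if a is not None]
    if not audios:
        return None, None  # the caller must tolerate a fully-skipped batch
    lens = torch.tensor([len(a) for a in audios])
    max_len = int(lens.max())
    batch = torch.stack([F.pad(torch.from_numpy(a), (0, max_len - len(a))) for a in audios])
    return batch, lens
```
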
augmentor: Any = None, batch_augmentor: Any = None + config: DictConfig, + global_rank: int, + world_size: int, + augmentor: Any = None, + batch_augmentor: Any = None, + return_noise: bool = False, ) -> ConcatDataset: + """ + Args: + config: the dataset config + global_rank: the global rank + world_size: the global world size + augmentor: the audio augmentor + batch_augmentor: the batch augmentor + return_noise: whether to return the noise in output batch, default is False + """ manifest_filepaths = config['manifest_filepath'] datasets = [] @@ -504,7 +670,9 @@ def get_concat_audio_noise_dataset( conf = copy.deepcopy(config) conf['manifest_filepath'] = manifest_filepath - dataset = get_audio_noise_dataset(config=conf, augmentor=augmentor) + dataset = get_audio_noise_dataset( + config=conf, augmentor=augmentor, batch_augmentor=batch_augmentor, return_noise=return_noise + ) datasets.append(dataset) dataset = ConcatDataset( @@ -521,7 +689,25 @@ def get_concat_audio_noise_dataset( return dataset -def get_tarred_audio_noise_dataset(config, shuffle_n, global_rank, world_size, augmentor, batch_augmentor: Any = None): +def get_tarred_audio_noise_dataset( + config: DictConfig, + shuffle_n: int, + global_rank: int, + world_size: int, + augmentor: Any = None, + batch_augmentor: Any = None, + return_noise: bool = False, +): + """ + Args: + config: the dataset config + shuffle_n: the number of samples to look ahead and load to be shuffled + global_rank: the global rank + world_size: the global world size + augmentor: the audio augmentor + batch_augmentor: the batch augmentor + return_noise: whether to return the noise in output batch, default is False + """ tarred_audio_filepaths = config['tarred_audio_filepaths'] manifest_filepaths = config['manifest_filepath'] datasets = [] @@ -568,6 +754,7 @@ def get_tarred_audio_noise_dataset(config, shuffle_n, global_rank, world_size, a shard_manifests=is_sharded_manifest, global_rank=global_rank, world_size=world_size, + return_noise=return_noise, ) if bucketing_weights: [datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])] @@ -578,8 +765,24 @@ def get_tarred_audio_noise_dataset(config, shuffle_n, global_rank, world_size, a def get_concat_tarred_audio_noise_dataset( - config, shuffle_n, global_rank, world_size, augmentor, batch_augmentor: Any = None + config: DictConfig, + shuffle_n: int, + global_rank: int, + world_size: int, + augmentor: Any = None, + batch_augmentor: Any = None, + return_noise: bool = False, ): + """ + Args: + config: the dataset config + shuffle_n: the number of samples to look ahead and load to be shuffled + global_rank: the global rank + world_size: the global world size + augmentor: the audio augmentor + batch_augmentor: the batch augmentor + return_noise: whether to return the noise in output batch, default is False + """ tarred_audio_filepaths = config['tarred_audio_filepaths'] manifest_filepaths = config['manifest_filepath'] datasets = [] @@ -596,6 +799,7 @@ def get_concat_tarred_audio_noise_dataset( world_size=world_size, augmentor=augmentor, batch_augmentor=batch_augmentor, + return_noise=return_noise, ) datasets.append(dataset) @@ -614,10 +818,18 @@ def get_concat_tarred_audio_noise_dataset( def get_audio_noise_dataset_from_config( - config, + config: DictConfig, global_rank: int, world_size: int, + return_noise: bool = False, ): + """ + Args: + config: the dataset config + global_rank: the global rank + world_size: the global world size + return_noise: whether to return the noise in output batch, default is False 
+ """ if 'augmentor' in config: augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size) else: @@ -670,6 +882,7 @@ def get_audio_noise_dataset_from_config( world_size=world_size, augmentor=augmentor, batch_augmentor=batch_augmentor, + return_noise=return_noise, ) else: dataset = get_tarred_audio_noise_dataset( @@ -679,6 +892,7 @@ def get_audio_noise_dataset_from_config( world_size=world_size, augmentor=augmentor, batch_augmentor=batch_augmentor, + return_noise=return_noise, ) else: if 'manifest_filepath' in config and config['manifest_filepath'] is None: @@ -691,7 +905,10 @@ def get_audio_noise_dataset_from_config( world_size=world_size, augmentor=augmentor, batch_augmentor=batch_augmentor, + return_noise=return_noise, ) else: - dataset = get_audio_noise_dataset(config=config, augmentor=augmentor, batch_augmentor=batch_augmentor) + dataset = get_audio_noise_dataset( + config=config, augmentor=augmentor, batch_augmentor=batch_augmentor, return_noise=return_noise + ) return dataset diff --git a/nemo/collections/asr/losses/ssl_losses/mlm.py b/nemo/collections/asr/losses/ssl_losses/mlm.py index 4ed6f580bbb2..1415fc6a5b72 100644 --- a/nemo/collections/asr/losses/ssl_losses/mlm.py +++ b/nemo/collections/asr/losses/ssl_losses/mlm.py @@ -61,6 +61,18 @@ def __init__( def forward( self, decoder_outputs, targets, decoder_lengths=None, target_lengths=None, spec_masks=None, masks=None ): + """ + Args: + decoder_outputs: (B, T, D) + targets: (B, T) + decoder_lengths: (B,) + target_lengths: (B,) + spec_masks: (B, D, T) + masks: (B, D, T) + + Returns: + loss: (1,) + """ if masks is None: masks = spec_masks @@ -71,11 +83,22 @@ def forward( # B,D,T -> B,T,D masks = masks.transpose(1, 2) - masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1) - masks = masks.mean(-1) > self.mask_threshold + masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1) + masks = masks.mean(-1) > self.mask_threshold # (B, T) + + # Truncate the length of decoder_outputs, masks, and targets to the minimum length of the three. + loss_length = min(decoder_outputs.shape[1], masks.shape[1], targets.shape[1]) + + decoder_outputs = decoder_outputs[:, :loss_length] + masks = masks[:, :loss_length] + targets = targets[:, :loss_length] + + if decoder_lengths is not None: + decoder_lengths = torch.clamp(decoder_lengths, max=loss_length) + if target_lengths is not None: + target_lengths = torch.clamp(target_lengths, max=loss_length) out_masked_only = decoder_outputs[masks] - targets = F.pad(targets, (0, masks.shape[-1] - targets.shape[-1])) targets_masked_only = targets[masks] loss = self.nll_loss(out_masked_only, targets_masked_only) diff --git a/nemo/collections/asr/models/ssl_models.py b/nemo/collections/asr/models/ssl_models.py index 6e149c3c17b8..f59ac0d84c4f 100644 --- a/nemo/collections/asr/models/ssl_models.py +++ b/nemo/collections/asr/models/ssl_models.py @@ -13,7 +13,7 @@ # limitations under the License. 
from math import ceil -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -761,6 +761,16 @@ def training_step(self, batch, batch_idx=0): 'train_loss': loss_value, } + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_idx + + if self.cfg.get("log_codebook_coverage", False) and (sample_id + 1) % log_every_n_steps == 0: + self.log_codebook_coverage(tokens, encoded_len) + return {'loss': loss_value, 'log': tensorboard_logs} def inference_pass(self, batch, batch_idx=0, dataloader_idx=0, mode='val', apply_mask=False): @@ -821,6 +831,65 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): tensorboard_logs = {'test_loss': test_loss_mean} return {'test_loss': test_loss_mean, 'log': tensorboard_logs} + def log_codebook_coverage(self, tokens: torch.Tensor, encoded_len: torch.Tensor): + if tokens.ndim == 2: + _tokens = tokens.unsqueeze(-1) # shape [B, T, 1] + else: + _tokens = tokens # shape [B, T, N] + + # find the number of unique tokens in the batch, excluding padding; count per codebook (batch collapsed) -> [N] + # encoded_len is of shape [B]; valid positions: t < encoded_len[b] for each batch item b + T, N = _tokens.size(1), _tokens.size(2) + valid_mask = torch.arange(T, device=_tokens.device, dtype=encoded_len.dtype).unsqueeze( + 0 + ) < encoded_len.unsqueeze(1) + + # Initialize cumulative token count if it doesn't exist, it's okay that each training job will reset it, + # as long as the curve for each job is showing increasing cumulative coverage. + if getattr(self, 'cumulative_token_count', None) is None: + self.cumulative_token_count = torch.zeros( + [N, int(self.cfg.num_classes)], device=_tokens.device, dtype=torch.long + ) + + # This batch's token counts per codebook [N, num_classes]; will be summed across ranks then added to cumulative + all_valid = _tokens[valid_mask] # (num_valid, N) — index once outside loop + batch_token_count = torch.zeros([N, int(self.cfg.num_classes)], device=_tokens.device, dtype=torch.long) + for n in range(N): + valid_tokens_n = all_valid[:, n] + if valid_tokens_n.numel() > 0: + batch_token_count[n] = torch.bincount(valid_tokens_n.long(), minlength=int(self.cfg.num_classes)) + # Derive unique count from bincount — no separate .unique() call + num_unique_tokens = (batch_token_count > 0).sum(dim=1) + + # Sync batch counts across ranks (SUM) so cumulative is global + if torch.distributed.is_initialized(): + torch.distributed.all_reduce(batch_token_count, op=torch.distributed.ReduceOp.SUM) + self.cumulative_token_count += batch_token_count + + # Get the maximum number of unique tokens for each codebook across all ranks + if torch.distributed.is_initialized(): + torch.distributed.all_reduce(num_unique_tokens, op=torch.distributed.ReduceOp.MAX) + + # coverage = fraction of codebook entries used + codebook_coverage = num_unique_tokens / float(self.cfg.num_classes) + cumulative_codebook_coverage = torch.count_nonzero(self.cumulative_token_count, dim=1) / float( + self.cfg.num_classes + ) + cumulative_token_prob = self.cumulative_token_count / self.cumulative_token_count.sum( + dim=1, keepdim=True + ) # shape [N, num_classes] + # Entropy H = -sum(p * log(p)) per codebook; use clamp to avoid log(0) + cumulative_codebook_entropy = -(cumulative_token_prob * 
torch.log(cumulative_token_prob.clamp(min=1e-10))).sum( + dim=1 + ) + for n in range(N): + self.log(f"codebook_coverage_cb{n}", codebook_coverage[n].float()) + self.log(f"cumul_codebook_coverage_cb{n}", cumulative_codebook_coverage[n].float()) + self.log(f"cumul_codebook_entropy_cb{n}", cumulative_codebook_entropy[n].float()) + self.log(f"codebook_coverage_avg", codebook_coverage.mean().float()) + self.log(f"cumul_codebook_coverage_avg", cumulative_codebook_coverage.mean().float()) + self.log(f"cumul_codebook_entropy_avg", cumulative_codebook_entropy.mean().float()) + class EncDecDenoiseMaskedTokenPredModel(EncDecMaskedTokenPredModel): """ @@ -839,8 +908,6 @@ def oomptimizer_schema(self) -> dict: "inputs": [ {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "audio"}, {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "audio_len"}, - {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "noise"}, - {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "noise_len"}, {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "noisy_audio"}, {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "noisy_audio_len"}, ], @@ -858,8 +925,10 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): global_rank=self.global_rank, world_size=self.world_size, dataset=ssl_dataset.LhotseAudioNoiseDataset( + cfg=config, noise_manifest=config.get('noise_manifest', None), batch_augmentor_cfg=config.get('batch_augmentor', None), + return_noise=False, ), ) @@ -867,6 +936,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config, global_rank=self.global_rank, world_size=self.world_size, + return_noise=False, ) shuffle = config['shuffle'] @@ -945,7 +1015,29 @@ def forward( processed_noisy_input_signal=None, processed_noisy_input_signal_length=None, apply_mask=False, - ): + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for the model. + Args: + input_signal: Input signal of shape [B, T]. + input_signal_length: Lengths of the input signal of shape [B]. + processed_signal: Processed signal of shape [B, D, T]. + processed_signal_length: Lengths of the processed signal of shape [B]. + noise_signal: Noise signal of shape [B, T]. + noise_signal_length: Lengths of the noise signal of shape [B]. + processed_noise_signal: Processed noise signal of shape [B, D, T]. + processed_noise_signal_length: Lengths of the processed noise signal of shape [B]. + noisy_input_signal: Noisy input signal of shape [B, T]. + noisy_input_signal_length: Lengths of the noisy input signal of shape [B]. + processed_noisy_input_signal: Processed noisy input signal of shape [B, D, T]. + processed_noisy_input_signal_length: Lengths of the processed noisy input signal of shape [B]. + apply_mask: Whether to apply masking to the input signal. + Returns: + log_probs: Log probabilities of the model of shape [B, T, C]. + encoded_len: Lengths of the encoded signal of shape [B]. + masks: Masks of the model of shape [B, D, T]. + tokens: Target tokens of the model of shape [B, T, N] or [B, T] if num_books == 1 and squeeze_single is True. 
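
On `log_codebook_coverage` above: per codebook it tracks how many of the `num_classes` codes the quantizer actually emits (coverage) and the entropy of the cumulative code distribution, which together flag a collapsed or under-used codebook. A standalone sketch of the arithmetic with toy sizes, omitting the Lightning logging and the distributed all_reduce:

```python
import torch

num_classes, N = 16, 2                                # codebook size, number of codebooks
tokens = torch.randint(0, num_classes, (4, 50, N))    # (B, T, N) quantizer codes
encoded_len = torch.tensor([50, 43, 50, 27])          # valid frames per utterance

valid = torch.arange(tokens.shape[1]).unsqueeze(0) < encoded_len.unsqueeze(1)  # (B, T)
valid_tokens = tokens[valid]                                                    # (num_valid, N)

counts = torch.zeros(N, num_classes, dtype=torch.long)
for n in range(N):
    counts[n] = torch.bincount(valid_tokens[:, n], minlength=num_classes)

coverage = (counts > 0).sum(dim=1).float() / num_classes      # fraction of codes used, per codebook
probs = counts.float() / counts.sum(dim=1, keepdim=True)      # empirical code distribution
entropy = -(probs * probs.clamp(min=1e-10).log()).sum(dim=1)  # high entropy = codes used evenly
print(coverage, entropy)
```
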
+ """ has_input_signal = input_signal is not None and input_signal_length is not None has_processed_signal = processed_signal is not None and processed_signal_length is not None if (has_input_signal ^ has_processed_signal) == False: @@ -1033,6 +1125,16 @@ def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int): 'train_loss': loss_value, } + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_idx + + if self.cfg.get("log_codebook_coverage", False) and (sample_id + 1) % log_every_n_steps == 0: + self.log_codebook_coverage(tokens, encoded_len) + return {'loss': loss_value, 'log': tensorboard_logs} def inference_pass( diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py index 7259d077809e..51846e51bd8e 100644 --- a/nemo/collections/asr/modules/__init__.py +++ b/nemo/collections/asr/modules/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.asr.modules.asr_transformer_encoder import ASRTransformerEncoder from nemo.collections.asr.modules.audio_preprocessing import ( # noqa: F401 AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor, diff --git a/nemo/collections/asr/modules/asr_transformer_encoder.py b/nemo/collections/asr/modules/asr_transformer_encoder.py new file mode 100644 index 000000000000..bd847caba794 --- /dev/null +++ b/nemo/collections/asr/modules/asr_transformer_encoder.py @@ -0,0 +1,1250 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import random +from collections import OrderedDict +from dataclasses import dataclass +from typing import List, Optional, Set, Tuple + +import torch +import torch.distributed +import torch.nn as nn +from omegaconf import DictConfig, ListConfig, open_dict + +from nemo.collections.asr.models.configs import CacheAwareStreamingConfig +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder +from nemo.collections.asr.parts.submodules.asr_transformer_modules import ASRTransformerLayer +from nemo.collections.asr.parts.submodules.multi_head_attention import ( + LocalAttRelPositionalEncoding, + MultiHeadAttention, + PositionalEncoding, + RelPositionalEncoding, + RelPositionMultiHeadAttention, + RelPositionMultiHeadAttentionLongformer, +) +from nemo.collections.asr.parts.submodules.subsampling import ( + ConvSubsampling, + StackingSubsampling, + SubsamplingReductionModule, +) +from nemo.collections.asr.parts.utils import adapter_utils +from nemo.collections.asr.parts.utils.regularization_utils import compute_stochastic_depth_drop_probs +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import AccessMixin, adapter_mixins +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging + +__all__ = ['ASRTransformerEncoder'] + + +class ASRTransformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): + """ + ASR encoder based on Transformer. + + Args: + feat_in (int): the size of feature channels + n_layers (int): number of layers of ASRTransformerLayer + d_model (int): the hidden size of the model + feat_out (int): the size of the output features + Defaults to -1 (means feat_out is d_model) + subsampling (str): the method of subsampling, choices=['vggnet', 'striding', 'dw-striding', 'stacking', 'stacking_norm'] + Defaults to striding. + subsampling_factor (int): the subsampling factor which should be power of 2 + Defaults to 4. + subsampling_conv_chunking_factor(int): optionally, force chunk inputs (helpful for large inputs) + Should be power of 2, 1 (auto-chunking, default), or -1 (no chunking) + subsampling_conv_channels (int): the size of the convolutions in the subsampling module + Defaults to -1 which would set it to d_model. + reduction (str, Optional): the method of reduction, choices=['pooling', 'striding']. If no value + is passed, then no reduction is performed and the models runs with the original 4x subsampling. + reduction_position (int, Optional): the index of the layer to apply reduction. If -1, apply reduction + at the end. + reduction_factor (int): the reduction factor which should be either 1 or a power of 2 + Defaults to 1. + ff_expansion_factor (int): the expansion factor in feed forward layers + Defaults to 4. + self_attention_model (str): type of the attention layer and positional encoding + + 'rel_pos': + relative positional embedding and Transformer-XL + + 'rel_pos_local_attn': + relative positional embedding and Transformer-XL with local attention using + overlapping chunks. Attention context is determined by att_context_size parameter. + + 'abs_pos': + absolute positional embedding and Transformer + + Default is rel_pos. + pos_emb_max_len (int): the maximum length of positional embeddings + Defaults to 5000 + n_heads (int): number of heads in multi-headed attention layers + Defaults to 4. 
+ att_context_size (List[Union[List[int],int]]): specifies the context sizes on each side. Each context size should be a list of two integers like [100,100]. + A list of context sizes like [[100,100],[100,50]] can also be passed. -1 means unlimited context. + Defaults to [-1,-1] + att_context_probs (List[float]): a list of probabilities of each one of the att_context_size when a list of them is passed. If not specified, uniform distribution is being used. + Defaults to None + att_context_style (str): 'regular' or 'chunked_limited'. + Defaults to 'regular' + xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model) + Defaults to True. + untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL + Defaults to True. + conv_kernel_size (int): the size of the convolutions in the convolutional modules + Defaults to 31. + conv_norm_type (str): the type of the normalization in the convolutional modules + Defaults to 'batch_norm'. + conv_context_size (list): it can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size. + None means [(conv_kernel_size-1)//2, (conv_kernel_size-1)//2], and 'causal' means [(conv_kernel_size-1), 0]. + Defaults to None. + conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used. When enables, the left half of the convolution kernel would get masked in streaming cases. + Defaults to False + use_bias (bool): Use bias in all Linear ASRTransformerLayer to improve activation flow and stabilize training of huge models. + Defaults to True. + dropout (float): the dropout rate used in all layers except the attention layers + Defaults to 0.1. + dropout_pre_encoder (float): the dropout rate used before the encoder + Defaults to 0.1. + dropout_emb (float): the dropout rate used for the positional embeddings + Defaults to 0.1. + dropout_att (float): the dropout rate used for the attention layer + Defaults to 0.0. + stochastic_depth_drop_prob (float): if non-zero, will randomly drop + layers during training. The higher this value, the more often layers + are dropped. Defaults to 0.0. + stochastic_depth_mode (str): can be either "linear" or "uniform". If + set to "uniform", all layers have the same probability of drop. If + set to "linear", the drop probability grows linearly from 0 for the + first layer to the desired value for the final layer. Defaults to + "linear". + stochastic_depth_start_layer (int): starting layer for stochastic depth. + All layers before this will never be dropped. Note that drop + probability will be adjusted accordingly if mode is "linear" when + start layer is > 1. Defaults to 1. + global_tokens (int): number of tokens to be used for global attention. + Only relevant if self_attention_model is 'rel_pos_local_attn'. + Defaults to 0. + global_tokens_spacing (int): how far apart the global tokens are + Defaults to 1. + global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate. + Defaults to False. + use_pytorch_sdpa (bool): use torch sdpa instead of manual attention. + Defaults to False. + use_pytorch_sdpa_backends (list[str]): list of backend names to use in sdpa. None or empty list means all backends. e.g. ["MATH"] + Defaults to None + sync_max_audio_length (bool): when true, performs NCCL all_reduce to allocate the same amount of memory for + positional encoding buffers on all GPUs. 
Disabling this setting may help with deadlocks in certain + scenarios such as model parallelism, or generally when this module is not being ran on some GPUs + as a part of the training step. + + """ + + def input_example(self, max_batch=1, max_dim=256): + """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. + """ + dev = next(self.parameters()).device + if self.export_cache_support: + window_size = max_dim + if self.streaming_cfg is not None: + if isinstance(self.streaming_cfg.chunk_size, list): + chunk_size = self.streaming_cfg.chunk_size[1] + else: + chunk_size = self.streaming_cfg.chunk_size + if isinstance(self.streaming_cfg.pre_encode_cache_size, list): + pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size[1] + else: + pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size + window_size = chunk_size + pre_encode_cache_size + input_example = torch.randn(max_batch, self._feat_in, window_size, device=dev) + input_example_length = torch.randint( + window_size // 4, window_size, (max_batch,), device=dev, dtype=torch.int64 + ) + cache_last_channel, cache_last_time, cache_last_channel_len = self.get_initial_cache_state( + batch_size=max_batch, device=dev, max_dim=max_dim + ) + all_input_example = tuple( + [ + input_example, + input_example_length, + cache_last_channel.transpose(0, 1), + cache_last_time.transpose(0, 1), + cache_last_channel_len, + ] + ) + else: + input_example = torch.randn(max_batch, self._feat_in, max_dim, device=dev) + input_example_length = torch.randint(max_dim // 4, max_dim, (max_batch,), device=dev, dtype=torch.int64) + all_input_example = tuple([input_example, input_example_length]) + + return all_input_example + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + "cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True), + "cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True), + "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + ) + + @property + def input_types_for_export(self): + """Returns definitions of module input ports.""" + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + "cache_last_channel": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True), + "cache_last_time": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True), + "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + ) + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "cache_last_channel_next": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True), + "cache_last_time_next": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True), + "cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + ) + + @property + def output_types_for_export(self): + """Returns definitions of module output ports.""" + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "cache_last_channel_next": 
NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True), + "cache_last_time_next": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True), + "cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + ) + + @property + def disabled_deployment_input_names(self): + if not self.export_cache_support: + return set(["cache_last_channel", "cache_last_time", "cache_last_channel_len"]) + else: + return set() + + @property + def disabled_deployment_output_names(self): + if not self.export_cache_support: + return set(["cache_last_channel_next", "cache_last_time_next", "cache_last_channel_next_len"]) + else: + return set() + + def __init__( + self, + feat_in, + n_layers, + d_model, + feat_out=-1, + ffn_act='relu', + post_ln=True, + causal_downsampling=False, + subsampling='striding', + subsampling_factor=4, + subsampling_conv_chunking_factor=1, + subsampling_conv_channels=-1, + reduction=None, + reduction_position=None, + reduction_factor=1, + ff_expansion_factor=4, + self_attention_model='rel_pos', + n_heads=4, + att_context_size=None, + att_context_probs=None, + att_context_style='regular', + xscaling=True, + untie_biases=True, + pos_emb_max_len=5000, + use_bias=True, + dropout=0.1, + dropout_pre_encoder=0.1, + dropout_emb=0.1, + dropout_att=0.0, + stochastic_depth_drop_prob: float = 0.0, + stochastic_depth_mode: str = "linear", + stochastic_depth_start_layer: int = 1, + global_tokens: int = 0, + global_tokens_spacing: int = 1, + global_attn_separate: bool = False, + use_pytorch_sdpa: bool = False, + use_pytorch_sdpa_backends=None, + sync_max_audio_length: bool = True, + ): + super().__init__() + d_ff = d_model * ff_expansion_factor + self.d_model = d_model + self.n_layers = n_layers + self._feat_in = feat_in + self.att_context_style = att_context_style + self.subsampling_factor = subsampling_factor + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + self.self_attention_model = self_attention_model + self.global_tokens = global_tokens + self.global_attn_separate = global_attn_separate + self.global_tokens_spacing = global_tokens_spacing + self.use_pytorch_sdpa = use_pytorch_sdpa + if use_pytorch_sdpa_backends is None: + use_pytorch_sdpa_backends = [] + self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends + self.sync_max_audio_length = sync_max_audio_length + + # Setting up the att_context_size + ( + self.att_context_size_all, + self.att_context_size, + self.att_context_probs, + self.conv_context_size, + ) = self._calc_context_sizes( + att_context_style=att_context_style, + att_context_size=att_context_size, + att_context_probs=att_context_probs, + conv_context_size=None, + conv_kernel_size=1, + ) + + if xscaling: + self.xscale = math.sqrt(d_model) + else: + self.xscale = None + + # Subsampling + if subsampling_conv_channels == -1: + subsampling_conv_channels = d_model + if subsampling and subsampling_factor > 1: + if subsampling in ['stacking', 'stacking_norm']: + # stacking_norm has an extra layer norm after stacking comparing to stacking + self.pre_encode = StackingSubsampling( + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + norm=True if subsampling == 'stacking_norm' else False, + ) + else: + self.pre_encode = ConvSubsampling( + subsampling=subsampling, + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + conv_channels=subsampling_conv_channels, + subsampling_conv_chunking_factor=subsampling_conv_chunking_factor, + activation=nn.ReLU(True), + 
is_causal=causal_downsampling, + ) + else: + self.pre_encode = nn.Linear(feat_in, d_model) + + # Reduction + if reduction and reduction_factor > 1: + assert reduction_position >= -1 and reduction_position < n_layers + self.reduction_subsampling = SubsamplingReductionModule( + reduction=reduction, + d_model=d_model, + reduction_factor=reduction_factor, + ) + self.reduction_position = reduction_position + else: + self.reduction_subsampling = None + self.reduction_position = None + + self._feat_out = d_model + + # Biases for relative positional encoding + if not untie_biases and self_attention_model == "rel_pos": + d_head = d_model // n_heads + pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head)) + pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head)) + nn.init.zeros_(pos_bias_u) + nn.init.zeros_(pos_bias_v) + else: + pos_bias_u = None + pos_bias_v = None + + # Positional encodings + self.pos_emb_max_len = pos_emb_max_len + if self_attention_model == "rel_pos": + self.pos_enc = RelPositionalEncoding( + d_model=d_model, + dropout_rate=dropout_pre_encoder, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=dropout_emb, + ) + elif self_attention_model == 'rel_pos_local_attn': + if max(att_context_size) <= 0: + raise ValueError("When using local attention, context size must be set > 0") + self.pos_enc = LocalAttRelPositionalEncoding( + att_context_size=att_context_size, + d_model=d_model, + dropout_rate=dropout, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=dropout_emb, + ) + elif self_attention_model == "abs_pos": + pos_bias_u = None + pos_bias_v = None + self.pos_enc = PositionalEncoding( + d_model=d_model, dropout_rate=dropout_pre_encoder, max_len=pos_emb_max_len, xscale=self.xscale + ) + else: + raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + + self.layers = nn.ModuleList() + for i in range(n_layers): + layer = ASRTransformerLayer( + d_model=d_model, + d_ff=d_ff, + self_attention_model=self_attention_model, + global_tokens=global_tokens, + global_tokens_spacing=global_tokens_spacing, + global_attn_separate=global_attn_separate, + n_heads=n_heads, + dropout=dropout, + dropout_att=dropout_att, + pos_bias_u=pos_bias_u, + pos_bias_v=pos_bias_v, + att_context_size=self.att_context_size, + use_bias=use_bias, + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ffn_act=ffn_act, + post_ln=post_ln, + ) + self.layers.append(layer) + + if feat_out > 0 and feat_out != self._feat_out: + self.out_proj = nn.Linear(self._feat_out, feat_out) + self._feat_out = feat_out + else: + self.out_proj = None + self._feat_out = d_model + + self.set_max_audio_length(self.pos_emb_max_len) + self.use_pad_mask = True + + self.setup_streaming_params() + self.export_cache_support = False + + self.layer_drop_probs = compute_stochastic_depth_drop_probs( + len(self.layers), stochastic_depth_drop_prob, stochastic_depth_mode, stochastic_depth_start_layer + ) + # will be set in self.forward() if defined in AccessMixin config + self.interctc_capture_at_layers = None + + def forward_for_export( + self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ): + if cache_last_channel is not None: + cache_last_channel = cache_last_channel.transpose(0, 1) + cache_last_time = cache_last_time.transpose(0, 1) + + rets = self.forward_internal( + audio_signal, + length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + 
cache_last_channel_len=cache_last_channel_len, + ) + rets = self.streaming_post_process(rets, keep_all_outputs=False) + if len(rets) == 2: + return rets + elif rets[2] is None and rets[3] is None and rets[4] is None: + return (rets[0], rets[1]) + else: + return ( + rets[0], + rets[1], + rets[2].transpose(0, 1), + rets[3].transpose(0, 1), + rets[4], + ) + + def streaming_post_process(self, rets, keep_all_outputs=True): + if len(rets) == 2: + return rets[0], rets[1], None, None, None + + (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) = rets + + if cache_last_channel_next is not None and self.streaming_cfg.last_channel_cache_size >= 0: + if self.streaming_cfg.last_channel_cache_size > 0: + cache_last_channel_next = cache_last_channel_next[ + :, :, -self.streaming_cfg.last_channel_cache_size :, : + ] + + if self.streaming_cfg.valid_out_len > 0 and (not keep_all_outputs or self.att_context_style == "regular"): + encoded = encoded[:, :, : self.streaming_cfg.valid_out_len] + encoded_len = torch.clamp(encoded_len, max=self.streaming_cfg.valid_out_len) + + return (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) + + @typecheck() + def forward( + self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ): + self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) + return self.forward_internal( + audio_signal, + length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + def forward_internal( + self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ): + if length is None: + length = audio_signal.new_full( + (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device + ) + + # select a random att_context_size with the distribution specified by att_context_probs during training + # for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size + if self.training and len(self.att_context_size_all) > 1: + cur_att_context_size = random.choices(self.att_context_size_all, weights=self.att_context_probs)[0] + else: + cur_att_context_size = self.att_context_size + + audio_signal = torch.transpose(audio_signal, 1, 2) + + if isinstance(self.pre_encode, nn.Linear): + audio_signal = self.pre_encode(audio_signal) + else: + audio_signal, length = self.pre_encode(x=audio_signal, lengths=length) + length = length.to(torch.int64) + # self.streaming_cfg is set by setup_streaming_cfg(), called in the init + if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None: + audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :] + length = (length - self.streaming_cfg.drop_extra_pre_encoded).clamp(min=0) + + if self.reduction_position is not None and cache_last_channel is not None: + raise ValueError("Caching with reduction feature is not supported yet!") + + max_audio_length = audio_signal.size(1) + if cache_last_channel is not None: + cache_len = self.streaming_cfg.last_channel_cache_size + cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size + max_audio_length = max_audio_length + cache_len + padding_length = length + cache_len + offset = torch.neg(cache_last_channel_len) + cache_len + else: + padding_length = length + cache_last_channel_next = None + cache_len = 0 + 
offset = None + + audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) + + # Create the self-attention and padding masks + pad_mask, att_mask = self._create_masks( + att_context_size=cur_att_context_size, + padding_length=padding_length, + max_audio_length=max_audio_length, + offset=offset, + device=audio_signal.device, + ) + + if cache_last_channel is not None: + pad_mask = pad_mask[:, cache_len:] + if att_mask is not None: + att_mask = att_mask[:, cache_len:] + # Convert caches from the tensor to list + cache_last_time_next = [] + cache_last_channel_next = [] + + for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)): + original_signal = audio_signal + if cache_last_channel is not None: + cache_last_channel_cur = cache_last_channel[lth] + cache_last_time_cur = cache_last_time[lth] + else: + cache_last_channel_cur = None + cache_last_time_cur = None + audio_signal = layer( + x=audio_signal, + att_mask=att_mask, + pos_emb=pos_emb, + pad_mask=pad_mask, + cache_last_channel=cache_last_channel_cur, + cache_last_time=cache_last_time_cur, + ) + + if cache_last_channel_cur is not None: + (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal + cache_last_channel_next.append(cache_last_channel_cur) + cache_last_time_next.append(cache_last_time_cur) + + # applying stochastic depth logic from https://arxiv.org/abs/2102.03216 + if self.training and drop_prob > 0.0: + should_drop = torch.rand(1) < drop_prob + # adjusting to match expectation + if should_drop: + # that's not efficient, but it's hard to implement distributed + # version of dropping layers without deadlock or random seed meddling + # so multiplying the signal by 0 to ensure all weights get gradients + audio_signal = audio_signal * 0.0 + original_signal + else: + # not doing this operation if drop prob is 0 as it's identity in that case + audio_signal = (audio_signal - original_signal) / (1.0 - drop_prob) + original_signal + + if self.reduction_position == lth: + audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length) + max_audio_length = audio_signal.size(1) + # Don't update the audio_signal here because then it will again scale the audio_signal + # and cause an increase in the WER + _, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) + pad_mask, att_mask = self._create_masks( + att_context_size=cur_att_context_size, + padding_length=length, + max_audio_length=max_audio_length, + offset=offset, + device=audio_signal.device, + ) + + # saving tensors if required for interctc loss + if self.is_access_enabled(getattr(self, "model_guid", None)): + if self.interctc_capture_at_layers is None: + self.interctc_capture_at_layers = self.access_cfg.get('interctc', {}).get('capture_layers', []) + if lth in self.interctc_capture_at_layers: + lth_audio_signal = audio_signal + if self.out_proj is not None: + lth_audio_signal = self.out_proj(audio_signal) + # shape is the same as the shape of audio_signal output, i.e. 
[B, D, T] + self.register_accessible_tensor( + name=f'interctc/layer_output_{lth}', tensor=torch.transpose(lth_audio_signal, 1, 2) + ) + self.register_accessible_tensor(name=f'interctc/layer_length_{lth}', tensor=length) + + if self.out_proj is not None: + audio_signal = self.out_proj(audio_signal) + + # Reduction + if self.reduction_position == -1: + audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length) + + audio_signal = torch.transpose(audio_signal, 1, 2) + length = length.to(dtype=torch.int64) + + if cache_last_channel is not None: + cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0) + cache_last_time_next = torch.stack(cache_last_time_next, dim=0) + return ( + audio_signal, + length, + cache_last_channel_next, + cache_last_time_next, + torch.clamp(cache_last_channel_len + cache_keep_size, max=cache_len), + ) + else: + return audio_signal, length + + def update_max_seq_length(self, seq_length: int, device): + # Find global max audio length across all nodes + if self.sync_max_audio_length and torch.distributed.is_initialized(): + global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) + + # Update across all ranks in the distributed system + torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX) + + seq_length = global_max_len.int().item() + + if seq_length > self.max_audio_length: + self.set_max_audio_length(seq_length) + + def set_max_audio_length(self, max_audio_length): + """ + Sets maximum input length. + Pre-calculates internal seq_range mask. + """ + self.max_audio_length = max_audio_length + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + self.pos_enc.extend_pe(max_audio_length, device, dtype) + + def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device): + if self.self_attention_model != "rel_pos_local_attn": + att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device) + + if self.att_context_style == "regular": + if att_context_size[0] >= 0: + att_mask = att_mask.triu(diagonal=-att_context_size[0]) + if att_context_size[1] >= 0: + att_mask = att_mask.tril(diagonal=att_context_size[1]) + elif self.att_context_style == "chunked_limited": + # When right context is unlimited, just the left side of the masking need to get updated + if att_context_size[1] == -1: + if att_context_size[0] >= 0: + att_mask = att_mask.triu(diagonal=-att_context_size[0]) + else: + chunk_size = att_context_size[1] + 1 + # left_chunks_num specifies the number of chunks to be visible by each chunk on the left side + if att_context_size[0] >= 0: + left_chunks_num = att_context_size[0] // chunk_size + else: + left_chunks_num = 10000 + + chunk_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device) + chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="trunc") + diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0) + chunked_limited_mask = torch.logical_and( + torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0) + ) + att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) + else: + att_mask = None + + # pad_mask is the masking to be used to ignore paddings + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(-1) + + if offset is not None: + pad_mask_off = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) >= offset.unsqueeze(-1) + 
pad_mask = pad_mask_off.logical_and(pad_mask) + + if att_mask is not None: + # pad_mask_for_att_mask is the mask which helps to ignore paddings + pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1]) + pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2)) + # att_mask is the masking to be used by the MHA layers to ignore the tokens not supposed to be visible + att_mask = att_mask[:, :max_audio_length, :max_audio_length] + # paddings should also get ignored, so pad_mask_for_att_mask is used to ignore their corresponding scores + att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device)) + att_mask = ~att_mask + + pad_mask = ~pad_mask + return pad_mask, att_mask + + def enable_pad_mask(self, on=True): + # On inference, user may choose to disable pad mask + mask = self.use_pad_mask + self.use_pad_mask = on + return mask + + def _calc_context_sizes( + self, att_context_size, att_context_probs, att_context_style, conv_context_size, conv_kernel_size + ): + # convert att_context_size to a standard list of lists + if att_context_size: + att_context_size_all = list(att_context_size) + if isinstance(att_context_size_all[0], int): + att_context_size_all = [att_context_size_all] + for i, att_cs in enumerate(att_context_size_all): + if isinstance(att_cs, ListConfig): + att_context_size_all[i] = list(att_cs) + if att_context_style == "chunked_limited": + if att_cs[0] > 0 and att_cs[0] % (att_cs[1] + 1) > 0: + raise ValueError(f"att_context_size[{i}][0] % (att_context_size[{i}][1] + 1) should be zero!") + if att_cs[1] < 0 and len(att_context_size_all) <= 1: + raise ValueError( + f"Right context (att_context_size[{i}][1]) can not be unlimited for chunked_limited style!" + ) + else: + att_context_size_all = [[-1, -1]] + + if att_context_probs: + if len(att_context_probs) != len(att_context_size_all): + raise ValueError("The size of the att_context_probs should be the same as att_context_size.") + att_context_probs = list(att_context_probs) + if sum(att_context_probs) != 1: + raise ValueError( + "The sum of numbers in att_context_probs should be equal to one to be a distribution." + ) + else: + att_context_probs = [1.0 / len(att_context_size_all)] * len(att_context_size_all) + + if conv_context_size is not None: + if isinstance(conv_context_size, ListConfig): + conv_context_size = list(conv_context_size) + if not isinstance(conv_context_size, list) and not isinstance(conv_context_size, str): + raise ValueError( + f"Invalid conv_context_size! It should be the string 'causal' or a list of two integers." 
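
On `_create_masks` above: for the `chunked_limited` style the visibility rule is expressed on chunk indices rather than frames; a query frame may attend to keys whose chunk is the same as, or at most `left_chunks_num` chunks before, its own. A small standalone reproduction of that mask with toy sizes (the real code additionally ANDs in the padding mask and inverts the result before handing it to the attention layers):

```python
import torch

T, chunk_size, left_chunks_num = 8, 2, 1
chunk_idx = torch.div(torch.arange(T), chunk_size, rounding_mode="trunc")  # (T,) chunk id per frame
diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)              # (T, T) query chunk - key chunk
att_mask = torch.logical_and(diff_chunks <= left_chunks_num, diff_chunks >= 0)
print(att_mask.int())
# Row i (query) may attend column j (key) when the chunk of j is the same as,
# or at most left_chunks_num chunks before, the chunk of i.
```
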
+ ) + if conv_context_size == "causal": + conv_context_size = [conv_kernel_size - 1, 0] + else: + if conv_context_size[0] + conv_context_size[1] + 1 != conv_kernel_size: + raise ValueError(f"Invalid conv_context_size: {self.conv_context_size}!") + else: + conv_context_size = [(conv_kernel_size - 1) // 2, (conv_kernel_size - 1) // 2] + return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size + + def set_default_att_context_size(self, att_context_size): + if att_context_size not in self.att_context_size_all: + logging.warning( + f"att_context_size={att_context_size} is not among the list of the supported look-aheads: {self.att_context_size_all}" + ) + if att_context_size is not None: + self.att_context_size = att_context_size + + self.setup_streaming_params() + + def setup_streaming_params( + self, + chunk_size: int = None, + shift_size: int = None, + left_chunks: int = None, + att_context_size: list = None, + max_context: int = 10000, + ): + """ + This function sets the needed values and parameters to perform streaming. The configuration would be stored in self.streaming_cfg. + The streaming configuration is needed to simulate streaming inference. + + Args: + chunk_size (int): overrides the chunk size + shift_size (int): overrides the shift size for chunks + left_chunks (int): overrides the number of left chunks visible to each chunk + max_context (int): the value used for the cache size of last_channel layers if left context is set to infinity (-1) + Defaults to -1 (means feat_out is d_model) + """ + streaming_cfg = CacheAwareStreamingConfig() + + # When att_context_size is not specified, it uses the default_att_context_size + if att_context_size is None: + att_context_size = self.att_context_size + + if chunk_size is not None: + if chunk_size < 1: + raise ValueError("chunk_size needs to be a number larger or equal to one.") + lookahead_steps = chunk_size - 1 + streaming_cfg.cache_drop_size = chunk_size - shift_size + elif self.att_context_style == "chunked_limited": + lookahead_steps = att_context_size[1] + streaming_cfg.cache_drop_size = 0 + elif self.att_context_style == "regular": + lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers + streaming_cfg.cache_drop_size = lookahead_steps + else: + streaming_cfg.cache_drop_size = 0 + lookahead_steps = None + + if chunk_size is None: + streaming_cfg.last_channel_cache_size = att_context_size[0] if att_context_size[0] >= 0 else max_context + else: + if left_chunks is None: + raise ValueError("left_chunks can not be None when chunk_size is set.") + streaming_cfg.last_channel_cache_size = left_chunks * chunk_size + + if hasattr(self.pre_encode, "get_sampling_frames"): + sampling_frames = self.pre_encode.get_sampling_frames() + else: + sampling_frames = 0 + + if isinstance(sampling_frames, list): + streaming_cfg.chunk_size = [ + sampling_frames[0] + self.subsampling_factor * lookahead_steps, + sampling_frames[1] + self.subsampling_factor * lookahead_steps, + ] + else: + streaming_cfg.chunk_size = sampling_frames * (1 + lookahead_steps) + + if isinstance(sampling_frames, list): + streaming_cfg.shift_size = [ + sampling_frames[0] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size), + sampling_frames[1] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size), + ] + else: + streaming_cfg.shift_size = sampling_frames * (1 + lookahead_steps - streaming_cfg.cache_drop_size) + + if isinstance(streaming_cfg.shift_size, list): + 
streaming_cfg.valid_out_len = ( + streaming_cfg.shift_size[1] - sampling_frames[1] + ) // self.subsampling_factor + 1 + else: + streaming_cfg.valid_out_len = streaming_cfg.shift_size // self.subsampling_factor + + if hasattr(self.pre_encode, "get_streaming_cache_size"): + streaming_cfg.pre_encode_cache_size = self.pre_encode.get_streaming_cache_size() + else: + streaming_cfg.pre_encode_cache_size = 0 + + if isinstance(streaming_cfg.pre_encode_cache_size, list): + if streaming_cfg.pre_encode_cache_size[1] >= 1: + streaming_cfg.drop_extra_pre_encoded = ( + 1 + (streaming_cfg.pre_encode_cache_size[1] - 1) // self.subsampling_factor + ) + else: + streaming_cfg.drop_extra_pre_encoded = 0 + else: + streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor + + for m in self.layers.modules(): + if hasattr(m, "_max_cache_len"): + if isinstance(m, MultiHeadAttention): + m.cache_drop_size = streaming_cfg.cache_drop_size + + self.streaming_cfg = streaming_cfg + + def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None, max_dim=0): + if device is None: + device = next(self.parameters()).device + if max_dim > 0: + create_tensor = torch.randn + else: + create_tensor = torch.zeros + last_time_cache_size = self.conv_context_size[0] + cache_last_channel = create_tensor( + ( + len(self.layers), + batch_size, + self.streaming_cfg.last_channel_cache_size, + self.d_model, + ), + device=device, + dtype=dtype, + ) + cache_last_time = create_tensor( + (len(self.layers), batch_size, self.d_model, last_time_cache_size), + device=device, + dtype=dtype, + ) + if max_dim > 0: + cache_last_channel_len = torch.randint( + 0, + min(max_dim, self.streaming_cfg.last_channel_cache_size), + (batch_size,), + device=device, + dtype=torch.int64, + ) + for i in range(batch_size): + cache_last_channel[:, i, cache_last_channel_len[i] :, :] = 0 + # what is the right rule to zero out cache_last_time? + if cache_last_channel_len[i] == 0: + cache_last_time[:, i, :, :] = 0 + else: + cache_last_channel_len = torch.zeros(batch_size, device=device, dtype=torch.int64) + return cache_last_channel, cache_last_time, cache_last_channel_len + + def change_attention_model( + self, + self_attention_model: str = None, + att_context_size: List[int] = None, + update_config: bool = True, + device: torch.device = None, + ): + """ + Update the self_attention_model which changes the positional encoding and attention layers. + + Args: + self_attention_model (str): type of the attention layer and positional encoding + + 'rel_pos': + relative positional embedding and Transformer-XL + + 'rel_pos_local_attn': + relative positional embedding and Transformer-XL with local attention using + overlapping windows. Attention context is determined by att_context_size parameter. + + 'abs_pos': + absolute positional embedding and Transformer + + If None is provided, the self_attention_model isn't changed. Defaults to None. + att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, + or None to keep as it is. Defaults to None. + update_config (bool): Whether to update the config or not with the new attention model. + Defaults to True. + device (torch.device): If provided, new layers will be moved to the device. + Defaults to None. 
+ """ + + if att_context_size: + att_context_size = list(att_context_size) + else: + att_context_size = self.att_context_size + + if self_attention_model is None: + self_attention_model = self.self_attention_model + + if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0: + raise ValueError("When using local attention, context size must be set > 0") + + if self_attention_model == "rel_pos": + new_pos_enc = RelPositionalEncoding( + d_model=self._cfg.d_model, + dropout_rate=self._cfg.dropout, + max_len=self._cfg.pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=self._cfg.dropout_emb, + ) + elif self_attention_model == 'rel_pos_local_attn': + new_pos_enc = LocalAttRelPositionalEncoding( + att_context_size=att_context_size, + d_model=self._cfg.d_model, + dropout_rate=self._cfg.dropout, + max_len=self._cfg.pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=self._cfg.dropout_emb, + ) + elif self_attention_model == "abs_pos": + new_pos_enc = PositionalEncoding( + d_model=self._cfg.d_model, + dropout_rate=self._cfg.dropout, + max_len=self._cfg.pos_emb_max_len, + xscale=self.xscale, + ) + else: + raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + + if device is not None: + new_pos_enc = new_pos_enc.to(device=device) + del self.pos_enc + self.pos_enc = new_pos_enc + self.self_attention_model = self_attention_model + self.att_context_size = att_context_size + self.set_max_audio_length(self.pos_emb_max_len) + + for name, m in self.named_modules(): + if type(m) == ASRTransformerLayer: + if self_attention_model == 'rel_pos': + new_attn = RelPositionMultiHeadAttention( + n_head=self._cfg.n_heads, + n_feat=self._cfg.d_model, + dropout_rate=self._cfg.dropout_att, + max_cache_len=att_context_size[0], + pos_bias_u=None, + pos_bias_v=None, + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) + elif self_attention_model == 'rel_pos_local_attn': + new_attn = RelPositionMultiHeadAttentionLongformer( + n_head=self._cfg.n_heads, + n_feat=self._cfg.d_model, + dropout_rate=self._cfg.dropout_att, + max_cache_len=att_context_size[0], + att_context_size=att_context_size, + pos_bias_u=None, + pos_bias_v=None, + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) + elif self_attention_model == 'abs_pos': + new_attn = MultiHeadAttention( + n_head=self._cfg.n_heads, + n_feat=self._cfg.d_model, + dropout_rate=self._cfg.dropout_att, + max_cache_len=att_context_size[0], + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) + else: + raise ValueError( + f"'{self_attention_model}' is not not a valid value for 'self_attention_model', " + f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']" + ) + if device is not None: + new_attn = new_attn.to(device=device) + new_attn.load_state_dict(m.self_attn.state_dict(), strict=False) + del m.self_attn + m.self_attn = new_attn + m.self_attention_model = self_attention_model + + if update_config: + with open_dict(self._cfg): + self._cfg.self_attention_model = self_attention_model + self._cfg.att_context_size = att_context_size + + def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): + """ + Update the conv_chunking_factor (int) + Default is 1 (auto) + Set it to -1 (disabled) or to a specific value (power of 2) if you OOM in the conv subsampling layers + + + Args: + subsampling_conv_chunking_factor (int) + """ + + 
+        if not hasattr(self.pre_encode, "change_subsampling_conv_chunking_factor"):
+            logging.info("Model's pre_encode module doesn't have a change_subsampling_conv_chunking_factor method")
+            return
+
+        self.pre_encode.change_subsampling_conv_chunking_factor(
+            subsampling_conv_chunking_factor=subsampling_conv_chunking_factor
+        )
+
+
+class ASRTransformerEncoderAdapter(ASRTransformerEncoder, adapter_mixins.AdapterModuleMixin):
+
+    # Higher level forwarding
+    def add_adapter(self, name: str, cfg: dict):
+        cfg = self._update_adapter_cfg_input_dim(cfg)
+        for layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
+            layer.add_adapter(name, cfg)
+
+    def is_adapter_available(self) -> bool:
+        return any([layer.is_adapter_available() for layer in self.layers])
+
+    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
+        for layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
+            layer.set_enabled_adapters(name=name, enabled=enabled)
+
+    def get_enabled_adapters(self) -> List[str]:
+        names = set([])
+        for layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
+            names.update(layer.get_enabled_adapters())
+
+        names = sorted(list(names))
+        return names
+
+    def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
+        cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model)
+        return cfg
+
+    def get_accepted_adapter_types(
+        self,
+    ) -> Set[type]:
+        types = super().get_accepted_adapter_types()
+
+        if len(types) == 0:
+            self.set_accepted_adapter_types(
+                [
+                    adapter_utils.LINEAR_ADAPTER_CLASSPATH,
+                    adapter_utils.MHA_ADAPTER_CLASSPATH,
+                    adapter_utils.RELMHA_ADAPTER_CLASSPATH,
+                ]
+            )
+            types = self.get_accepted_adapter_types()
+        return types
+
+
+class ASRTransformerFeatureExtractor(NeuralModule, Exportable, AccessMixin):
+    """
+    A wrapper module that extracts features from multiple layers of an ASRTransformerEncoder,
+    by reusing the existing mechanism for the interctc loss.
+    To use it, set `layer_idx_list` to specify the indices of the layers to extract from.
+    You can also specify an `aggregator` module to combine the features from the selected layers; by default no aggregation is applied.
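A minimal usage sketch of the wrapper described above; `encoder` stands for an already-built ASRTransformerEncoder and the shapes are illustrative, neither is defined in this hunk:

```python
# Sketch only: `encoder` is assumed to be an existing ASRTransformerEncoder instance.
import torch

extractor = ASRTransformerFeatureExtractor(
    encoder=encoder,
    layer_idx_list=[3, 7, 11],  # capture the outputs of layers 3, 7 and 11
    detach=True,                # detach captured tensors from the encoder graph
)

specs = torch.randn(2, 128, 400)      # (B, feat_in, T) features
spec_lens = torch.tensor([400, 320])

# With aggregator=None the wrapper returns per-layer lists:
#   encoded_list:     List[Tensor[B, D, T']]
#   encoded_len_list: List[Tensor[B]]
encoded_list, encoded_len_list = extractor(audio_signal=specs, length=spec_lens)
```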
+ """ + + def __init__( + self, + encoder: ASRTransformerEncoder, + layer_idx_list: Optional[List[int]] = None, + aggregator: NeuralModule = None, + detach: bool = False, + convert_to_cpu: bool = False, + ): + super().__init__() + self.encoder = encoder + if layer_idx_list is None: + self.layer_idx_list = [i for i in range(len(encoder.layers))] + else: + self.layer_idx_list = [int(l) for l in layer_idx_list] + for x in self.layer_idx_list: + if x < 0 or x >= len(encoder.layers): + raise ValueError(f"layer index {x} out of range [0, {len(encoder.layers)})") + self.enc_access_cfg = { + "interctc": { + "capture_layers": self.layer_idx_list, + }, + "detach": detach, + "convert_to_cpu": convert_to_cpu, + } + self.aggregator = aggregator + + def forward( + self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + old_access_flag = self.is_access_enabled(guid=getattr(self, "model_guid", None)) + self.update_access_cfg(self.enc_access_cfg, guid=getattr(self, "model_guid", None)) + self.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) + + _ = self.encoder( + audio_signal=audio_signal, + length=length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + ### chunk of code adapted from ConformerEncoder.forward_internal() + total_registry = {} + for module_registry in self.get_module_registry(self.encoder).values(): + for key in module_registry: + if key.startswith("interctc/") and key in total_registry: + raise RuntimeError(f"layer {key} has been logged multiple times!") + total_registry.update(module_registry) + + encoded_list = [] + encoded_len_list = [] + for layer_idx in self.layer_idx_list: + try: + layer_outputs = total_registry[f"interctc/layer_output_{layer_idx}"] + layer_lengths = total_registry[f"interctc/layer_length_{layer_idx}"] + except KeyError: + raise RuntimeError( + f"Intermediate layer {layer_idx} was not captured! Check the layer index and the number of ConformerEncoder layers." + ) + if len(layer_outputs) > 1 or len(layer_lengths) > 1: + raise RuntimeError("Make sure encoder.forward is called exactly one time") + encoded_list.append(layer_outputs[0]) # [B, D, T] + encoded_len_list.append(layer_lengths[0]) # [B] + + self.encoder.reset_registry() + self.set_access_enabled(access_enabled=old_access_flag, guid=getattr(self, "model_guid", None)) + ### end of adapted chunk + + if self.aggregator is not None: + return self.aggregator(encoded_list, encoded_len_list) # Tensor[B,D*L,T], Tensor[B] + else: + return encoded_list, encoded_len_list # List[Tensor[B,D,T]], List[Tensor[B]] + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(ASRTransformerEncoder) is None: + adapter_mixins.register_adapter(base_class=ASRTransformerEncoder, adapter_class=ASRTransformerEncoderAdapter) + + +@dataclass +class ASRTransformerChangeConfig: + # Change self_attention_model for Conformer + # Options: + # 'rel_pos': relative positional embedding and Transformer-XL + # 'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using + # overlapping chunks. Attention context is determined by att_context_size parameter. + # 'abs_pos': absolute positional embedding and Transformer + # If None is provided, self_attention_model is not changed. 
+ self_attention_model: Optional[str] = None + + # Change the attention context size by providing 2 integers, + # corresponding to left and right context, or -1 for full context. + # If None is provided, the attention context size isn't changed. + att_context_size: Optional[List[int]] = None diff --git a/nemo/collections/asr/modules/ssl_modules/augmentation.py b/nemo/collections/asr/modules/ssl_modules/augmentation.py index cd665634f841..ccc1f8cc947f 100644 --- a/nemo/collections/asr/modules/ssl_modules/augmentation.py +++ b/nemo/collections/asr/modules/ssl_modules/augmentation.py @@ -14,8 +14,6 @@ import math import random -from collections import Counter - import torch from nemo.collections.asr.data.ssl_dataset import AudioNoiseBatch @@ -99,33 +97,38 @@ def __call__(self, batch: AudioNoiseBatch) -> AudioNoiseBatch: # randomly select position to start the mixing mix_start_idx = random.randint(0, audio_lengths[i] - mix_len - 1) - # randomly select the energy ratio between speech and noise + # randomly select the energy ratio and noise source if random.random() < self.noise_ratio or batch_size == 1: energy_ratio = random.uniform(self.min_r_noise, self.max_r_noise) + src = noise[i] + src_len = noise_len[i] else: energy_ratio = random.uniform(self.min_r_speech, self.max_r_speech) - j = random.choice([x for x in range(batch_size) if x != i]) - noise[i] = audio_signal[j].clone() - noise_len[i] = audio_lengths[j] - - # repeat noise to match the length of audio mix length if necessary - if noise_len[i] <= mix_len: - # repeat noise to match the length of audio mix length - noise_start_idx = 0 - noise[i] = self.pad_or_trim_noise(self.repeat_noise(noise[i], noise_len[i], mix_len), max_audio_len) - noise_len[i] = mix_len + j = random.randrange(batch_size - 1) + if j >= i: + j += 1 + src = audio_signal[j] + src_len = audio_lengths[j] + + # extract noise clip directly from source + if src_len <= mix_len: + noise_clip = self.repeat_noise(src, src_len, mix_len) else: - # randomly select a segment of noise - noise_start_idx = random.randint(0, noise_len[i] - mix_len - 1) + noise_start_idx = random.randint(0, src_len - mix_len - 1) + noise_clip = src[noise_start_idx : noise_start_idx + mix_len] # calculate the scale factor for noise - audio_energy = torch.sum(audio_signal[i, : audio_lengths[i]] ** 2) / audio_lengths[i] - noise_energy = torch.sum(noise[i, : noise_len[i]] ** 2) / noise_len[i] if noise_len[i] > 0 else 0 + audio_slice = audio_signal[i, : audio_lengths[i]] + audio_energy = torch.dot(audio_slice, audio_slice) / audio_lengths[i] + if src_len > 0: + src_slice = src[:src_len] + noise_energy = torch.dot(src_slice, src_slice) / src_len + else: + noise_energy = 0 mix_scale = math.sqrt(audio_energy / (10 ** (energy_ratio / 10) * noise_energy)) if noise_energy > 0 else 0 - # get the residual signal to be added to original audio - noise_clip = noise[i, noise_start_idx : noise_start_idx + mix_len] - noise_signal = torch.zeros_like(audio_signal[i]) + # place scaled noise clip into noise signal + noise_signal = torch.zeros(max_audio_len, device=audio_signal.device, dtype=audio_signal.dtype) noise_signal[mix_start_idx : mix_start_idx + mix_len] = mix_scale * noise_clip noise[i] = noise_signal @@ -196,8 +199,14 @@ def __call__(self, batch: AudioNoiseBatch) -> AudioNoiseBatch: num_speakers = random.randint(self.min_num_speakers, self.max_num_speakers) num_speakers = min(num_speakers, batch_size) - # randomly chunk mix_len into num_segments - segment_lens = list(Counter(random.choices(range(num_segments), 
k=mix_len)).values()) + # randomly chunk mix_len into num_segments using sorted breakpoints + if num_segments == 1 or mix_len <= 1: + segment_lens = [mix_len] + else: + k = min(num_segments - 1, mix_len - 1) + breakpoints = sorted(random.sample(range(1, mix_len), k)) + segment_lens = [b - a for a, b in zip([0] + breakpoints, breakpoints + [mix_len])] + num_segments = len(segment_lens) # randomly select the energy ratio between speech and noise if random.random() < self.noise_ratio or batch_size == 1: @@ -220,8 +229,10 @@ def __call__(self, batch: AudioNoiseBatch) -> AudioNoiseBatch: max_start_idx += segment_lens[j] # calculate the scale factor for noise - audio_energy = torch.sum(audio_signal[i, : audio_lengths[i]] ** 2) / audio_lengths[i] - noise_energy = torch.sum(noise_signal[: audio_lengths[i]] ** 2) / audio_lengths[i] + audio_slice = audio_signal[i, : audio_lengths[i]] + noise_slice = noise_signal[: audio_lengths[i]] + audio_energy = torch.dot(audio_slice, audio_slice) / audio_lengths[i] + noise_energy = torch.dot(noise_slice, noise_slice) / audio_lengths[i] mix_scale = math.sqrt(audio_energy / (10 ** (energy_ratio / 10) * noise_energy)) if noise_energy > 0 else 0 # get the residual signal to be added to original audio @@ -239,21 +250,19 @@ def __call__(self, batch: AudioNoiseBatch) -> AudioNoiseBatch: noisy_audio_len=noise_len, ) - def get_noise_segments(self, batch_idx, batch, segment_lens, num_speakers, mode): + def get_noise_segments(self, batch_idx, batch: AudioNoiseBatch, segment_lens, num_speakers, mode): audio_signal = batch.audio audio_lengths = batch.audio_len noise = batch.noise noise_len = batch.noise_len batch_size = noise.size(0) - max_audio_len = audio_signal.size(1) noise_segments = [] if mode == "noise": - noise_padded = self.pad_or_trim_noise( - self.repeat_noise(noise[batch_idx], noise_len[batch_idx], max_audio_len), max_audio_len - ) + total_len = sum(segment_lens) + noise_repeated = self.repeat_noise(noise[batch_idx], noise_len[batch_idx], total_len) start_idx = 0 for segment_len in segment_lens: - noise_segments.append(noise_padded[start_idx : start_idx + segment_len]) + noise_segments.append(noise_repeated[start_idx : start_idx + segment_len]) start_idx += segment_len return noise_segments diff --git a/nemo/collections/asr/modules/ssl_modules/masking.py b/nemo/collections/asr/modules/ssl_modules/masking.py index c491c56fa829..6e7fff42a8f3 100644 --- a/nemo/collections/asr/modules/ssl_modules/masking.py +++ b/nemo/collections/asr/modules/ssl_modules/masking.py @@ -97,41 +97,36 @@ def forward_without_overlap(self, input_feats: torch.Tensor, input_lengths: torc masked_feats (Tensor): masked features, shape=(batch, features, time) masks (Tensor): the generated masks, shape=(batch, features, time) """ - batch_size = input_feats.size(0) - masks = torch.zeros_like(input_feats) - masked_feats = input_feats - indices = [] - for i in range(batch_size): - if self.block_size >= input_lengths[i] * self.max_mask_ratio: - # handle case where audio is too short + B, D, T = input_feats.shape + device = input_feats.device + + mask = torch.zeros(B, T, dtype=torch.bool, device=device) + + for i in range(B): + length = input_lengths[i] + if self.block_size >= length * self.max_mask_ratio: block_size = 8 num_patches = 1 - patch_indices = torch.tensor([0]) + patch_indices = torch.zeros(1, dtype=torch.long, device=device) offset = 0 else: - num_patches = torch.ceil(input_lengths[i] * self.mask_prob / self.block_size).int() - offset = torch.randint(0, self.block_size, (1,))[0] + 
num_patches = int(torch.ceil(length * self.mask_prob / self.block_size).item()) + offset = torch.randint(0, self.block_size, (1,)).item() block_size = self.block_size - if (num_patches + 1) * self.block_size > input_lengths[i]: - block_size = torch.div(input_lengths[i], (num_patches + 1), rounding_mode='trunc') - max_num_patches = torch.div(input_lengths[i], block_size, rounding_mode='trunc') - patch_indices = torch.randperm(max_num_patches - 1)[:num_patches] + if (num_patches + 1) * self.block_size > length: + block_size = int(length // (num_patches + 1)) + max_num_patches = int(length // block_size) + patch_indices = torch.randperm(max_num_patches - 1, device=device)[:num_patches] if num_patches: starts = patch_indices * block_size + offset - ends = starts + block_size - positions = torch.cat([torch.arange(s, e) for s, e in zip(starts, ends)]).reshape(-1, 1) - batch_index = torch.full((positions.shape[0], 1), i, dtype=positions.dtype) - positions = torch.cat([batch_index, positions], dim=1) - indices.append(positions.unique(dim=0)) - - if indices: - indices = torch.cat(indices, dim=0).unbind(1) - masks = masks.permute(0, 2, 1) - masked_feats = masked_feats.permute(0, 2, 1) + positions = starts.unsqueeze(1) + torch.arange(block_size, device=device).unsqueeze(0) + positions = positions.reshape(-1).clamp(0, T - 1) + mask[i, positions] = True - masks = masks.index_put(indices, values=torch.tensor(1.0)).permute(0, 2, 1) - masked_feats = masked_feats.index_put(indices, values=self.mask_embedding).permute(0, 2, 1) + mask_3d = mask.unsqueeze(1) + masks = mask_3d.float().expand_as(input_feats) + masked_feats = torch.where(mask_3d, self.mask_embedding.view(1, -1, 1), input_feats) return masked_feats, masks @@ -144,38 +139,35 @@ def forward_with_overlap(self, input_feats: torch.Tensor, input_lengths: torch.T masked_feats (Tensor): masked features, shape=(batch, features, time) masks (Tensor): the generated masks, shape=(batch, features, time) """ - batch_size = input_feats.size(0) - masks = torch.zeros_like(input_feats) - masked_feats = input_feats + B, D, T = input_feats.shape + device = input_feats.device + + mask = torch.zeros(B, T, dtype=torch.bool, device=device) mask_prob = torch.tensor(self.mask_prob) - indices = [] - for i in range(batch_size): + block_offsets = torch.arange(self.block_size, device=device) + + for i in range(B): input_length = input_lengths[i].item() if self.block_size >= input_length * self.max_mask_ratio: - # handle case where audio is too short block_size = 8 num_patches = 1 - patch_indices = torch.tensor([0]) + patch_indices = torch.zeros(1, dtype=torch.long, device=device) else: block_size = self.block_size count = max(0, input_length - self.block_size) num_patches = torch.binomial(torch.tensor(count).float(), mask_prob).long() - patch_indices = torch.randperm(count) - patch_indices = patch_indices[:num_patches] + patch_indices = torch.randperm(count, device=device)[:num_patches] if num_patches: - ends = torch.clamp(patch_indices + block_size, max=input_length) - positions = torch.cat([torch.arange(s, e) for s, e in zip(patch_indices, ends)]).reshape(-1, 1) - batch_index = torch.full((positions.shape[0], 1), i, dtype=positions.dtype) - positions = torch.cat([batch_index, positions], dim=1) - indices.append(positions.unique(dim=0)) - - if indices: - indices = torch.cat(indices, dim=0).unbind(1) - masks = masks.permute(0, 2, 1) - masked_feats = masked_feats.permute(0, 2, 1) - - masks = masks.index_put(indices, values=torch.tensor(1.0)).permute(0, 2, 1) - masked_feats = 
masked_feats.index_put(indices, values=self.mask_embedding).permute(0, 2, 1) + if block_size == self.block_size: + positions = patch_indices.unsqueeze(1) + block_offsets.unsqueeze(0) + else: + positions = patch_indices.unsqueeze(1) + torch.arange(block_size, device=device).unsqueeze(0) + positions = positions.reshape(-1).clamp(0, T - 1) + mask[i, positions] = True + + mask_3d = mask.unsqueeze(1) + masks = mask_3d.float().expand_as(input_feats) + masked_feats = torch.where(mask_3d, self.mask_embedding.view(1, -1, 1), input_feats) return masked_feats, masks diff --git a/nemo/collections/asr/modules/ssl_modules/quantizers.py b/nemo/collections/asr/modules/ssl_modules/quantizers.py index 8a53a7c00098..8da088f9ea75 100644 --- a/nemo/collections/asr/modules/ssl_modules/quantizers.py +++ b/nemo/collections/asr/modules/ssl_modules/quantizers.py @@ -17,7 +17,7 @@ from torch import nn from nemo.core import NeuralModule -from nemo.core.classes import Exportable, NeuralModule, typecheck +from nemo.core.classes import Exportable, NeuralModule from nemo.core.neural_types import LabelsType, NeuralType, SpectrogramType @@ -35,6 +35,8 @@ def __init__( freeze: bool = True, squeeze_single: bool = False, combine_time_steps: int = 1, + learnable_norm: bool = False, + xavier_normal_init: bool = False, ): """Vector quantization using random projection proposed in BEST-RQ paper: 'Self-Supervised Learning with Random-Projection Quantizer for Speech Recognition' @@ -48,6 +50,8 @@ def __init__( time_ahead: if Ture, the input is of shape (B, T, D), otherwise (B, D, T) freeze: whether to freeze the projection matrix squeeze_single: if True, squeeze codebook dimension if num_books is 1 + learnable_norm: if True, use LayerNorm with learnable affine params; otherwise plain standardization + xavier_normal_init: if True, use Xavier normal initialization for the projection matrix; otherwise Xavier uniform """ super().__init__() @@ -62,18 +66,31 @@ def __init__( self.time_ahead = time_ahead self.squeeze_single = squeeze_single self.combine_time_steps = combine_time_steps + self.input_norm = nn.LayerNorm(self.feat_in, elementwise_affine=learnable_norm) # (B, T, D) -> (B, T, num_books, code_dim) self.proj = nn.Linear(self.feat_in * combine_time_steps, self.num_books * self.code_dim, bias=False) - torch.nn.init.xavier_normal_(self.proj.weight) + if xavier_normal_init: + torch.nn.init.xavier_normal_(self.proj.weight) + else: + torch.nn.init.xavier_uniform_(self.proj.weight) # (num_books, num_classes, hid_dim) - codebooks = torch.randn(self.num_books, self.num_classes, self.code_dim).double() - torch.nn.init.normal_(codebooks, mean=0, std=1) + codebooks = torch.randn(self.num_books, self.num_classes, self.code_dim) codebooks = F.normalize(codebooks, dim=-1) self.codebooks = nn.Parameter(codebooks) + # Pre-computed offset for multi-book embedding lookup: [0, num_classes, 2*num_classes, ...] 
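Two details of this quantizer are easy to sanity-check in isolation: because both the projected features and the codebooks are L2-normalized, the highest dot product selects the same entry as the smallest L2 distance, and the `book_offsets` buffer simply shifts each codebook's ids into its own range. A small self-contained check with arbitrary shapes:

```python
# Self-contained check of two properties the quantizer relies on:
# (1) for unit vectors, argmin ||a - b||^2 == argmax a.b, since ||a - b||^2 = 2 - 2*a.b;
# (2) per-book token ids are flattened with offsets of num_classes per codebook.
import torch
import torch.nn.functional as F

num_books, num_classes, code_dim = 2, 8, 4
codebooks = F.normalize(torch.randn(num_books, num_classes, code_dim), dim=-1)
x = F.normalize(torch.randn(3, 5, num_books, code_dim), dim=-1)  # (B, T, num_books, code_dim)

# nearest entry via dot product (what the updated forward() computes)
dot_ids = torch.einsum('btdh,dch->btdc', x, codebooks).argmax(dim=-1)

# nearest entry via explicit L2 distance (the old "l2" branch)
l2_ids = (x.unsqueeze(-2) - codebooks.unsqueeze(0).unsqueeze(0)).norm(dim=-1).argmin(dim=-1)
assert torch.equal(dot_ids, l2_ids)

# flatten per-book ids: book k owns ids in [k * num_classes, (k + 1) * num_classes)
book_offsets = num_classes * torch.arange(num_books).reshape(1, 1, num_books)
flat_ids = dot_ids + book_offsets  # (B, T, num_books), values in [0, num_books * num_classes)
```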
+ self.register_buffer( + 'book_offsets', + self.num_classes * torch.arange(self.num_books).reshape(1, 1, self.num_books), + persistent=False, + ) if freeze: self.freeze() + if learnable_norm: + # unfreeze the layernorm parameters + self.input_norm.weight.requires_grad = True + self.input_norm.bias.requires_grad = True @property def input_types(self): @@ -105,7 +122,6 @@ def output_types(self): "xid": NeuralType(('B', 'T', 'H'), LabelsType()), } - @typecheck() def forward(self, input_signal): """ Args: @@ -120,6 +136,8 @@ def forward(self, input_signal): B, T, _ = input_signal.size() + input_signal = self.input_norm(input_signal) + if self.combine_time_steps > 1: input_signal = input_signal.contiguous().reshape(B, T // self.combine_time_steps, -1) T = T // self.combine_time_steps @@ -127,25 +145,18 @@ def forward(self, input_signal): # (B, T, D) -> (B, T, num_books*code_dim) x = self.proj(input_signal) - # normalize each feature vector + # normalize each projected vector # (B, T, num_books*code_dim) -> (B, T, num_books, code_dim) x = F.normalize(x.view(B, T, self.num_books, self.code_dim), dim=-1) # get tokens (xid) of shape (B, T, num_books) - if self.dist_fn == "cosine": - # (B, T, num_books, code_dim) -> (B, T, num_books, num_classes) - xid = torch.einsum('btdh,dch->btdc', x, self.codebooks) - # (B, T, num_books, num_classes) -> (B, T, num_books) - xid = xid.max(dim=-1)[1] - elif self.dist_fn == "l2": - # (B, T, num_books, code_dim) -> (B, T, num_books, code_dim, num_classes) - xid = x.unsqueeze(-1) - self.codebooks.transpose(1, 2).unsqueeze(0).unsqueeze(0) - xid = xid.norm(dim=-2).argmin(dim=-1) - else: - raise ValueError(f"Unknown distance function {self.dist_fn}, must be one of {self.DIST_FN_LIST}") + # Both x and codebooks are L2-normalized, so for both "cosine" and "l2": + # argmax(dot) == argmax(cosine) == argmin(||a-b||^2) since ||a-b||^2 = 2 - 2*a·b + xid = torch.einsum('btdh,dch->btdc', x, self.codebooks) + xid = xid.max(dim=-1)[1] # xid2: (B, T, num_books) -> (B, T, num_books) - xid2 = xid + self.num_classes * torch.arange(self.num_books, device=xid.device).unsqueeze(0).unsqueeze(0) + xid2 = xid + self.book_offsets # xid2: (B, T, num_books) -> (B*num_books, T) xid2 = xid2.transpose(1, 2).contiguous().view(-1, T) diff --git a/nemo/collections/asr/modules/transformer/transformer_modules.py b/nemo/collections/asr/modules/transformer/transformer_modules.py index 5c45aca92237..1281c461ea5c 100644 --- a/nemo/collections/asr/modules/transformer/transformer_modules.py +++ b/nemo/collections/asr/modules/transformer/transformer_modules.py @@ -18,12 +18,12 @@ import numpy as np import torch +import torch.nn.functional as F from torch import nn -from torch.nn.functional import gelu from nemo.collections.common.parts import form_attention_mask -__all__ = ["TransformerEmbedding", "AttentionBridge"] +__all__ = ["TransformerEmbedding", "AttentionBridge", "PositionWiseFF"] class FixedPositionalEncoding(nn.Module): @@ -231,15 +231,19 @@ class PositionWiseFF(nn.Module): net, usually is (4-8 x hidden_size) in the papers ffn_dropout: probability of dropout applied to net output hidden_act: activation function used between two linear layers + use_bias: whether to use bias in the linear layers """ - def __init__(self, hidden_size, inner_size, ffn_dropout=0.0, hidden_act="relu"): + def __init__(self, hidden_size, inner_size, ffn_dropout=0.0, hidden_act="relu", use_bias=True): super().__init__() - self.dense_in = nn.Linear(hidden_size, inner_size) - self.dense_out = nn.Linear(inner_size, hidden_size) 
+ self.dense_in = nn.Linear(hidden_size, inner_size, bias=use_bias) + self.dense_out = nn.Linear(inner_size, hidden_size, bias=use_bias) self.layer_dropout = nn.Dropout(ffn_dropout) - ACT2FN = {"gelu": gelu, "relu": torch.relu} - self.act_fn = ACT2FN[hidden_act] + if isinstance(hidden_act, str): + if not hasattr(F, hidden_act): + raise ValueError(f"Activation function {hidden_act} not found in torch.nn.functional") + act_fn = getattr(F, hidden_act) + self.act_fn = act_fn def forward(self, hidden_states): output_states = self.dense_in(hidden_states) diff --git a/nemo/collections/asr/parts/submodules/asr_transformer_modules.py b/nemo/collections/asr/parts/submodules/asr_transformer_modules.py new file mode 100644 index 000000000000..f6c25331ae62 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/asr_transformer_modules.py @@ -0,0 +1,279 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +from torch import nn as nn +from torch.nn import LayerNorm + +from nemo.collections.asr.modules.transformer.transformer_modules import PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.submodules.multi_head_attention import ( + MultiHeadAttention, + RelPositionMultiHeadAttention, + RelPositionMultiHeadAttentionLongformer, +) +from nemo.core.classes.mixins import AccessMixin + +__all__ = ['ASRTransformerLayer'] + + +class ASRTransformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): + """A single block of the ASR Transformer encoder. + + Args: + d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward + d_ff (int): hidden dimension of PositionwiseFeedForward + self_attention_model (str): type of the attention layer and positional encoding + 'rel_pos': relative positional embedding and Transformer-XL + 'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using + overlapping chunks. Attention context is determined by att_context_size parameter. + 'abs_pos': absolute positional embedding and Transformer + Default is rel_pos. + global_tokens (int): number of tokens to be used for global attention. + Only relevant if self_attention_model is 'rel_pos_local_attn'. + Defaults to 0. + global_tokens_spacing (int): how far apart the global tokens are + Defaults to 1. + global_attn_separate (bool): whether the q, k, v layers used for global tokens should be separate. + Defaults to False. + n_heads (int): number of heads for multi-head attention + dropout (float): dropout probabilities for linear layers + dropout_att (float): dropout probabilities for attention distributions + use_bias (bool): Apply bias to all Linear layers from each ASRTransformerLayer to improve activation flow and stabilize training of huge models. + Defaults to True. 
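The pre-LN/post-LN choice this layer exposes through `post_ln` comes down to where LayerNorm sits relative to the residual additions. A stripped-down sketch of the two orderings implemented below, with the attention and feed-forward sub-modules replaced by placeholder linear layers:

```python
# Minimal sketch of pre-LN vs post-LN residual ordering; `attn` and `ffn` stand in
# for the self-attention and PositionWiseFF sub-modules (dropout omitted).
import torch
from torch import nn

d_model = 8
attn = nn.Linear(d_model, d_model)   # placeholder for self-attention
ffn = nn.Linear(d_model, d_model)    # placeholder for the feed-forward module
norm_attn, norm_ffn = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

def pre_ln(x):
    x = x + attn(norm_attn(x))       # normalize *before* each sub-module
    x = x + ffn(norm_ffn(x))
    return x

def post_ln(x):
    x = norm_attn(x + attn(x))       # normalize *after* each residual sum
    x = norm_ffn(x + ffn(x))
    return x

x = torch.randn(2, 4, d_model)
assert pre_ln(x).shape == post_ln(x).shape == x.shape
```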
+ """ + + def __init__( + self, + d_model, + d_ff, + self_attention_model='rel_pos', + global_tokens=0, + global_tokens_spacing=1, + global_attn_separate=False, + n_heads=4, + dropout=0.1, + dropout_att=0.1, + pos_bias_u=None, + pos_bias_v=None, + att_context_size=[-1, -1], + use_bias=True, + use_pytorch_sdpa=False, + use_pytorch_sdpa_backends=None, + ffn_act='relu', + post_ln=False, + ): + super().__init__() + AccessMixin.__init__(self) + + self.use_pytorch_sdpa = use_pytorch_sdpa + if use_pytorch_sdpa_backends is None: + use_pytorch_sdpa_backends = [] + self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends + self.self_attention_model = self_attention_model + self.n_heads = n_heads + self.post_ln = post_ln + + # multi-headed self-attention module + self.norm_self_att = LayerNorm(d_model) + MHA_max_cache_len = att_context_size[0] + + if self_attention_model == 'rel_pos': + self.self_attn = RelPositionMultiHeadAttention( + n_head=n_heads, + n_feat=d_model, + dropout_rate=dropout_att, + pos_bias_u=pos_bias_u, + pos_bias_v=pos_bias_v, + max_cache_len=MHA_max_cache_len, + use_bias=use_bias, + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) + elif self_attention_model == 'rel_pos_local_attn': + self.self_attn = RelPositionMultiHeadAttentionLongformer( + n_head=n_heads, + n_feat=d_model, + dropout_rate=dropout_att, + pos_bias_u=pos_bias_u, + pos_bias_v=pos_bias_v, + max_cache_len=MHA_max_cache_len, + att_context_size=att_context_size, + global_tokens=global_tokens, + global_tokens_spacing=global_tokens_spacing, + global_attn_separate=global_attn_separate, + use_bias=use_bias, + ) + elif self_attention_model == 'abs_pos': + self.self_attn = MultiHeadAttention( + n_head=n_heads, + n_feat=d_model, + dropout_rate=dropout_att, + max_cache_len=MHA_max_cache_len, + use_bias=use_bias, + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) + else: + raise ValueError( + f"'{self_attention_model}' is not not a valid value for 'self_attention_model', " + f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']" + ) + + # second feed forward module + self.norm_feed_forward = LayerNorm(d_model) + self.feed_forward = PositionWiseFF( + hidden_size=d_model, inner_size=d_ff, ffn_dropout=dropout, use_bias=use_bias, hidden_act=ffn_act + ) + + self.dropout = nn.Dropout(dropout) + + def _forward_pre_ln( + self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None + ): + """ + Apply LayerNorm before Self-attention + """ + residual = x + + x = self.norm_self_att(residual) + + if self.self_attention_model == 'rel_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel) + elif self.self_attention_model == 'rel_pos_local_attn': + x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel) + elif self.self_attention_model == 'abs_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel) + else: + x = None + + if x is not None and cache_last_channel is not None: + (x, cache_last_channel) = x + + residual = residual + self.dropout(x) + + if self.is_adapter_available(): + # Call the MHA adapters + pack_input = { + 'x': residual, + 'loc': 'mha', + 'att_mask': att_mask, + 'pos_emb': pos_emb, + } + pack_input = self.forward_enabled_adapters(pack_input) + residual = pack_input['x'] + + x = self.norm_feed_forward(residual) + x = 
self.feed_forward(x) + residual = residual + self.dropout(x) + + x = residual + + if self.is_adapter_available(): + # Call the adapters + pack_input = { + 'x': x, + 'loc': 'post', + } + pack_input = self.forward_enabled_adapters(pack_input) + x = pack_input['x'] + + if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get( + 'save_encoder_tensors', False + ): + self.register_accessible_tensor(name='encoder', tensor=x) + if cache_last_channel is None: + return x + else: + return x, cache_last_channel, cache_last_time + + def _forward_post_ln( + self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None + ): + """ + Apply LayerNorm after Self-attention + """ + residual = x + + if self.self_attention_model == 'rel_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel) + elif self.self_attention_model == 'rel_pos_local_attn': + x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel) + elif self.self_attention_model == 'abs_pos': + x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel) + else: + x = None + + if x is not None and cache_last_channel is not None: + (x, cache_last_channel) = x + + residual = residual + self.dropout(x) + residual = self.norm_self_att(residual) + + if self.is_adapter_available(): + # Call the MHA adapters + pack_input = { + 'x': residual, + 'loc': 'mha', + 'att_mask': att_mask, + 'pos_emb': pos_emb, + } + pack_input = self.forward_enabled_adapters(pack_input) + residual = pack_input['x'] + + x = residual + x = self.feed_forward(x) + residual = residual + self.dropout(x) + + residual = self.norm_feed_forward(residual) + + x = residual + + if self.is_adapter_available(): + # Call the adapters + pack_input = { + 'x': x, + 'loc': 'post', + } + pack_input = self.forward_enabled_adapters(pack_input) + x = pack_input['x'] + + if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get( + 'save_encoder_tensors', False + ): + self.register_accessible_tensor(name='encoder', tensor=x) + if cache_last_channel is None: + return x + else: + return x, cache_last_channel, cache_last_time + + def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None): + """ + Args: + x (torch.Tensor): input signals (B, T, d_model) + att_mask (torch.Tensor): attention masks(B, T, T) + pos_emb (torch.Tensor): (L, 1, d_model) + pad_mask (torch.tensor): padding mask + cache_last_channel (torch.tensor) : cache for MHA layers (B, T_cache, d_model) + cache_last_time (torch.tensor) : cache for convolutional layers (B, d_model, T_cache) + Returns: + x (torch.Tensor): (B, T, d_model) + cache_last_channel (torch.tensor) : next cache for MHA layers (B, T_cache, d_model) + cache_last_time (torch.tensor) : next cache for convolutional layers (B, d_model, T_cache) + """ + if self.post_ln: + return self._forward_post_ln(x, att_mask, pos_emb, pad_mask, cache_last_channel, cache_last_time) + else: + return self._forward_pre_ln(x, att_mask, pos_emb, pad_mask, cache_last_channel, cache_last_time) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index f7838aa35d62..88c5d2e7e7d8 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -17,12 +17,13 @@ from copy import deepcopy from dataclasses import dataclass from 
functools import partial -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union import lhotse import numpy as np import torch from lhotse import CutSet, RecordingSet +from lhotse.ais.batch_loader import AISBatchLoader from lhotse.cut import Cut from lhotse.dataset import ( ClippingTransform, @@ -67,6 +68,28 @@ from nemo.utils import logging +class AISBatchedDataPrefetcher: + def __init__(self, iterator: Iterable, buffer_size: int = 10): + self.iterator = iterator + self.get_batch = AISBatchLoader() + self.buffer_size = buffer_size + + def __iter__(self): + buffer = [] + for item in self.iterator: + buffer.append(item) + if len(buffer) < self.buffer_size: + continue + buffer = self.get_batch(buffer) + yield from buffer + buffer = [] + yield from buffer + + +def prefetch_data(cut: Cut) -> Cut: + return cut.move_to_memory() + + @dataclass class LhotseDataLoadingConfig: """ @@ -254,6 +277,8 @@ class LhotseDataLoadingConfig: # The first K examples will actually be read and then discarded, incurring the IO cost, due to # our support of object stores and gzipped files that generally don't have indexes of byte offsets per line. slice_length: Optional[int] = None + ais_batch_prefetch_buffer_size: int = 10 + use_ais_get_batch: bool = False def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: @@ -593,7 +618,12 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No offset_type=config.truncate_offset_type, keep_excessive_supervisions=config.keep_excessive_supervisions, ) + if config.cut_into_windows_duration is not None: + # if config.use_ais_get_batch: + # cuts = CutSet(AISBatchedDataPrefetcher(cuts, buffer_size=config.ais_batch_prefetch_buffer_size)) + # else: + # cuts = cuts.map(prefetch_data) cuts = cuts.cut_into_windows( duration=config.cut_into_windows_duration, hop=config.cut_into_windows_hop, diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index 69ca3d66c041..0e0650475b7d 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -35,7 +35,7 @@ from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator -from lhotse.serialization import open_best +from lhotse.serialization import decode_json_line, open_best from lhotse.utils import compute_num_samples, ifnone from nemo.collections.common.parts.preprocessing.manifest import get_full_path @@ -43,6 +43,25 @@ from nemo.utils.data_utils import is_datastore_path +class LhotseLazyJsonlIterator(LazyJsonlIterator): + def __init__(self, path: str | Path | list[str]): + super().__init__(path) + + def __iter__(self): + tot = 0 + with open_best(self.path, "r") as f: + for line in f: + try: + data = decode_json_line(line) + yield data + tot += 1 + except Exception as e: + logging.error(f"Error decoding JSON line `{line}` in file `{self.path}`: {e}") + raise e + if self._len is None: + self._len = tot + + class LazyNeMoIterator: """ ``LazyNeMoIterator`` reads a NeMo (non-tarred) JSON manifest and converts it on the fly to an ``Iterable[Cut]``. 
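The `AISBatchedDataPrefetcher` added above is a buffer-and-flush iterator: it accumulates `buffer_size` items, hands the whole buffer to the AIS batch loader, and yields the results. A generic, self-contained version of the same pattern, with a plain callable standing in for `AISBatchLoader`; note that, unlike the class above, this sketch also routes the final partial buffer through the loader:

```python
# Generic buffer-and-flush iterator: collect `buffer_size` items, hand the whole
# buffer to a batch-level loader, then yield the loaded results one by one.
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def batched_prefetch(
    items: Iterable[T],
    load_batch: Callable[[List[T]], List[T]],
    buffer_size: int = 10,
) -> Iterator[T]:
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) < buffer_size:
            continue
        yield from load_batch(buffer)   # e.g. fetch all buffered objects in one request
        buffer = []
    if buffer:                          # flush the incomplete final buffer
        yield from load_batch(buffer)

# usage with a no-op "loader"
assert list(batched_prefetch(range(7), load_batch=lambda b: b, buffer_size=3)) == list(range(7))
```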
@@ -103,10 +122,10 @@ def __init__( paths = expand_sharded_filepaths(path) if len(paths) == 1: - self.source = LazyJsonlIterator(paths[0]) + self.source = LhotseLazyJsonlIterator(paths[0]) else: self.source = LazyIteratorChain( - *(LazyJsonlIterator(p) for p in paths), shuffle_iters=self.shuffle_shards, seed=self.shard_seed + *(LhotseLazyJsonlIterator(p) for p in paths), shuffle_iters=self.shuffle_shards, seed=self.shard_seed ) self.text_field = text_field self.lang_field = lang_field @@ -123,10 +142,18 @@ def __iter__(self) -> Generator[Cut, None, None]: if data.get("_skipme", False): continue audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path), force_cache=False) + sampling_rate = data.pop("sampling_rate", None) + if sampling_rate is None: + sampling_rate = data.pop("sample_rate", None) duration = data.pop("duration") offset = data.pop("offset", None) + num_channels = data.pop("num_channels", None) cut = self._create_cut( - audio_path=audio_path, offset=offset, duration=duration, sampling_rate=data.pop("sampling_rate", None) + audio_path=audio_path, + offset=offset, + duration=duration, + sampling_rate=sampling_rate, + num_channels=num_channels, ) # Note that start=0 and not start=offset because supervision's start if relative to the # start of the cut; and cut.start is already set to offset @@ -158,9 +185,12 @@ def _create_cut( offset: float, duration: float, sampling_rate: int | None = None, + num_channels: int | None = None, ) -> Cut: + if num_channels is None: + num_channels = 1 # default to single channel if not specified if not self.metadata_only: - recording = self._create_recording(audio_path, duration, sampling_rate) + recording = self._create_recording(audio_path, duration, sampling_rate, num_channels) cut = recording.to_cut() if offset is not None: cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) @@ -180,7 +210,7 @@ def _create_cut( supervisions=[], recording=Recording( id=audio_path, - sources=[AudioSource(type="dummy", channels=[0], source="")], + sources=[AudioSource(type="dummy", channels=list(range(num_channels)), source="")], sampling_rate=sr, duration=offset + duration, num_samples=compute_num_samples(offset + duration, sr), @@ -193,18 +223,17 @@ def _create_recording( audio_path: str, duration: float, sampling_rate: int | None = None, + num_channels: int = 1, ) -> Recording: if sampling_rate is not None: - # TODO(pzelasko): It will only work with single-channel audio in the current shape. 
- source_type = "url" if is_datastore_path(audio_path) else "file" return Recording( id=audio_path, - sources=[AudioSource(type=source_type, channels=[0], source=audio_path)], + sources=[AudioSource(type=source_type, channels=list(range(num_channels)), source=audio_path)], sampling_rate=sampling_rate, num_samples=compute_num_samples(duration, sampling_rate), duration=duration, - channel_ids=[0], + channel_ids=list(range(num_channels)), ) else: return Recording.from_file(audio_path) diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py index a71a99be11a5..9ddf078574fe 100644 --- a/nemo/collections/common/parts/preprocessing/manifest.py +++ b/nemo/collections/common/parts/preprocessing/manifest.py @@ -278,8 +278,13 @@ def get_full_path( if data_dir is None: data_dir = os.path.dirname(manifest_file) - # assume audio_file path is relative to data_dir - audio_file_path = os.path.join(data_dir, audio_file) + if not is_datastore_path(audio_file): + # if audio_file is not a datastore path, assume it is relative to data_dir + # assume audio_file path is relative to data_dir + audio_file_path = os.path.join(data_dir, audio_file) + else: + # if audio_file is a datastore path, use the original path + audio_file_path = audio_file if is_datastore_path(audio_file_path): # If audio was originally on an object store, use locally-cached path. diff --git a/nemo/collections/speechlm2/parts/nsight.py b/nemo/collections/speechlm2/parts/nsight.py index 4bb9c168b065..d2dc66aa6e84 100644 --- a/nemo/collections/speechlm2/parts/nsight.py +++ b/nemo/collections/speechlm2/parts/nsight.py @@ -28,7 +28,7 @@ class NsightProfiling(Callback): ... callbacks: - _target_: nemo.collections.speechlm2.parts.nsight.NsightProfiling - start_step: 5 + begin_step: 5 end_step: 10 gen_shape: true nvtx_ranges: true diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 027ca47a4e82..7e2e2caf020e 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1259,7 +1259,9 @@ def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string ) @rank_zero_only - def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'): + def maybe_init_from_pretrained_checkpoint( + self, cfg: OmegaConf, map_location: str = 'cpu', weights_only: bool = True + ): """ Initializes a given model with the parameters obtained via specific config arguments. The state dict of the provided model will be updated with `strict=False` setting so as to prevent @@ -1301,6 +1303,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st cfg: The config used to instantiate the model. It need only contain one of the above keys. map_location: str or torch.device() which represents where the intermediate state dict (from the pretrained model or checkpoint) will be loaded. + weights_only: bool flag passed to torch.load(), set to False if fails to load OmegaConf related fields. 
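For context on the new `weights_only` argument: checkpoints whose hyperparameters contain OmegaConf objects are rejected by the restricted unpickler that `torch.load(..., weights_only=True)` uses, which is the failure mode the docstring mentions. A hedged sketch of the fallback it describes, with a placeholder checkpoint path:

```python
# Placeholder path; only fall back to weights_only=False for checkpoints you trust,
# since that allows arbitrary pickled objects to execute during loading.
import torch

ckpt_path = "model.ckpt"

try:
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
except Exception:
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

state_dict = ckpt["state_dict"]
```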
""" args = [ @@ -1411,7 +1414,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st if isinstance(cfg.init_from_ptl_ckpt, str): # Restore checkpoint ckpt_path = cfg.pop('init_from_ptl_ckpt') - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=weights_only) # Restore checkpoint into current model self.load_state_dict(ckpt['state_dict'], strict=False) @@ -1425,7 +1428,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st for model_load_cfg in model_load_dict.values(): ckpt_path = model_load_cfg.path # Restore model - ckpt = torch.load(ckpt_path, map_location=map_location) + ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=weights_only) include = model_load_cfg.pop('include', [""]) exclude = model_load_cfg.pop('exclude', []) diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 65e17fda7bdf..71d6459eb1fc 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -357,6 +357,20 @@ def oomptimizer( This may be required in very complex setups where there are additional GPU RAM loads that can't be anticipated through the combination of training_step and optimizer update. """ + click.echo("Starting OOMptimizer...") + click.echo("Arguments:") + click.echo(f" pretrained_name: {pretrained_name}") + click.echo(f" module_name: {module_name}") + click.echo(f" config_path: {config_path}") + click.echo(f" optimizer_name: {optimizer_name}") + click.echo(f" buckets: {buckets}") + click.echo(f" threshold: {threshold}") + click.echo(f" start_batch_size: {start_batch_size}") + click.echo(f" ratio: {ratio}") + click.echo(f" memory_fraction: {memory_fraction}") + click.echo(f" device: {device}") + click.echo(f" dtype: {dtype}") + click.echo(f" ddp: {ddp}") if all(opt is None for opt in (pretrained_name, module_name, config_path)): click.secho( "You need to provide either PRETRAINED_NAME or the pair of MODULE_NAME and CONFIG_PATH.", fg="yellow" diff --git a/tests/collections/asr/test_ssl_modules.py b/tests/collections/asr/test_ssl_modules.py new file mode 100644 index 000000000000..ae74bbf801ee --- /dev/null +++ b/tests/collections/asr/test_ssl_modules.py @@ -0,0 +1,306 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib +import sys + +import pytest +import torch + +# Import the masking module directly to avoid triggering the full nemo.collections.asr import chain +_spec = importlib.util.spec_from_file_location( + "masking", + "nemo/collections/asr/modules/ssl_modules/masking.py", + submodule_search_locations=[], +) +_mod = importlib.util.module_from_spec(_spec) +# Ensure nemo.core dependencies are available +sys.modules.setdefault("masking", _mod) +_spec.loader.exec_module(_mod) +RandomBlockMasking = _mod.RandomBlockMasking + +# Import the quantizer module +_q_spec = importlib.util.spec_from_file_location( + "quantizers", + "nemo/collections/asr/modules/ssl_modules/quantizers.py", + submodule_search_locations=[], +) +_q_mod = importlib.util.module_from_spec(_q_spec) +sys.modules.setdefault("quantizers", _q_mod) +_q_spec.loader.exec_module(_q_mod) +RandomProjectionVectorQuantizer = _q_mod.RandomProjectionVectorQuantizer + + +class TestRandomBlockMasking: + @pytest.fixture(params=[False, True], ids=["no_overlap", "overlap"]) + def masking_module(self, request): + return RandomBlockMasking( + feat_in=16, + mask_prob=0.5, + block_size=8, + mask_value=0.0, + allow_overlap=request.param, + ) + + def test_output_shapes(self, masking_module): + B, D, T = 4, 16, 100 + feats = torch.randn(B, D, T) + lengths = torch.tensor([100, 80, 60, 100]) + + masked_feats, masks = masking_module(feats, lengths) + + assert masked_feats.shape == (B, D, T) + assert masks.shape == (B, D, T) + + def test_mask_is_binary(self, masking_module): + feats = torch.randn(4, 16, 100) + lengths = torch.tensor([100, 80, 60, 100]) + + _, masks = masking_module(feats, lengths) + + unique_vals = masks.unique() + assert all(v in [0.0, 1.0] for v in unique_vals) + + def test_mask_consistent_across_features(self, masking_module): + """Mask should be the same for all feature dimensions at a given time step.""" + feats = torch.randn(4, 16, 100) + lengths = torch.tensor([100, 80, 60, 100]) + + _, masks = masking_module(feats, lengths) + + # All feature dims should have the same mask pattern + for b in range(4): + first_feat_mask = masks[b, 0, :] + for d in range(1, 16): + assert torch.equal(masks[b, d, :], first_feat_mask) + + def test_masked_positions_get_mask_value(self, masking_module): + """Where mask is 1, features should equal mask_embedding.""" + feats = torch.randn(4, 16, 100) + lengths = torch.tensor([100, 80, 60, 100]) + + masked_feats, masks = masking_module(feats, lengths) + + mask_2d = masks[:, 0, :] # (B, T) - same across features + for b in range(4): + masked_times = mask_2d[b].bool() + if masked_times.any(): + # masked positions should have the mask_embedding value + expected = masking_module.mask_embedding.unsqueeze(1).expand(-1, masked_times.sum()) + actual = masked_feats[b, :, masked_times] + assert torch.allclose(actual, expected) + + def test_unmasked_positions_unchanged(self, masking_module): + """Where mask is 0, features should be unchanged from input.""" + feats = torch.randn(4, 16, 100) + lengths = torch.tensor([100, 80, 60, 100]) + + masked_feats, masks = masking_module(feats, lengths) + + mask_2d = masks[:, 0, :] + for b in range(4): + unmasked_times = ~mask_2d[b].bool() + if unmasked_times.any(): + assert torch.equal(masked_feats[b, :, unmasked_times], feats[b, :, unmasked_times]) + + def test_some_positions_are_masked(self, masking_module): + """With mask_prob=0.5, we should get some masked positions.""" + feats = torch.randn(4, 16, 200) + lengths = torch.tensor([200, 200, 200, 200]) + + _, masks = 
masking_module(feats, lengths) + + assert masks.sum() > 0 + + def test_short_audio(self): + """Audio shorter than block_size * max_mask_ratio should still work.""" + module = RandomBlockMasking(feat_in=16, block_size=48, mask_value=0.0) + feats = torch.randn(2, 16, 20) + lengths = torch.tensor([20, 15]) + + masked_feats, masks = module(feats, lengths) + + assert masked_feats.shape == feats.shape + assert masks.shape == feats.shape + + def test_batch_size_one(self, masking_module): + feats = torch.randn(1, 16, 100) + lengths = torch.tensor([100]) + + masked_feats, masks = masking_module(feats, lengths) + + assert masked_feats.shape == feats.shape + assert masks.shape == feats.shape + + def test_learnable_mask_embedding(self): + """When mask_value is None, mask_embedding should be learnable.""" + module = RandomBlockMasking(feat_in=16, mask_value=None, freeze=False) + feats = torch.randn(2, 16, 100) + lengths = torch.tensor([100, 80]) + + masked_feats, masks = module(feats, lengths) + + assert masked_feats.shape == feats.shape + assert module.mask_embedding.requires_grad + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_gpu(self, masking_module): + masking_module = masking_module.cuda() + feats = torch.randn(4, 16, 100, device='cuda') + lengths = torch.tensor([100, 80, 60, 100], device='cuda') + + masked_feats, masks = masking_module(feats, lengths) + + assert masked_feats.device.type == 'cuda' + assert masks.device.type == 'cuda' + + +class TestRandomProjectionVectorQuantizer: + @pytest.fixture(params=["cosine", "l2"]) + def quantizer(self, request): + return RandomProjectionVectorQuantizer( + feat_in=32, + code_dim=8, + num_classes=64, + num_books=2, + dist_fn=request.param, + ) + + def test_output_shapes(self, quantizer): + B, D, T = 2, 32, 10 + x = torch.randn(B, D, T) + xq, xid = quantizer(input_signal=x) + assert xq.shape == (B, 8, T, 2) # (B, code_dim, T, num_books) + assert xid.shape == (B, T, 2) # (B, T, num_books) + + def test_output_shapes_time_ahead(self): + q = RandomProjectionVectorQuantizer( + feat_in=32, + code_dim=8, + num_classes=64, + num_books=2, + dist_fn="l2", + time_ahead=True, + ) + B, T, D = 2, 10, 32 + x = torch.randn(B, T, D) + xq, xid = q(input_signal=x) + assert xq.shape == (B, T, 8, 2) # (B, T, code_dim, num_books) + assert xid.shape == (B, T, 2) + + def test_squeeze_single(self): + q = RandomProjectionVectorQuantizer( + feat_in=32, + code_dim=8, + num_classes=64, + num_books=1, + dist_fn="cosine", + squeeze_single=True, + ) + x = torch.randn(2, 32, 10) + xq, xid = q(input_signal=x) + assert xq.shape == (2, 8, 10) # squeezed book dim + assert xid.shape == (2, 10) + + def test_combine_time_steps(self): + q = RandomProjectionVectorQuantizer( + feat_in=16, + code_dim=8, + num_classes=64, + num_books=1, + dist_fn="l2", + combine_time_steps=2, + ) + x = torch.randn(2, 16, 10) # T=10, will become T=5 + xq, xid = q(input_signal=x) + assert xq.shape == (2, 8, 5, 1) + assert xid.shape == (2, 5, 1) + + def test_codebooks_are_float32(self): + q = RandomProjectionVectorQuantizer(feat_in=16, code_dim=8, num_classes=32, num_books=1) + assert q.codebooks.dtype == torch.float32 + + def test_l2_nearest_neighbor_correctness(self): + """Verify L2 picks the closest codebook entry on a small example.""" + q = RandomProjectionVectorQuantizer( + feat_in=4, + code_dim=4, + num_classes=3, + num_books=1, + dist_fn="l2", + time_ahead=True, + squeeze_single=True, + ) + # Set projection to identity so x passes through unchanged + with 
torch.no_grad(): + q.proj.weight.copy_(torch.eye(4)) + # Set codebook to known vectors (already normalized) + cb = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + ).unsqueeze( + 0 + ) # (1, 3, 4) + q.codebooks.copy_(cb) + + # Input close to codebook entry 2 (z-axis) + inp = torch.tensor([[[0.05, 0.05, 0.9, 0.0]]]) # (1, 1, 4) + _, xid = q(input_signal=inp) + assert xid.item() == 2 + + # Input close to codebook entry 0 (x-axis) + inp = torch.tensor([[[0.9, 0.1, 0.0, 0.0]]]) + _, xid = q(input_signal=inp) + assert xid.item() == 0 + + def test_cosine_nearest_neighbor_correctness(self): + """Verify cosine picks the most similar codebook entry.""" + q = RandomProjectionVectorQuantizer( + feat_in=4, + code_dim=4, + num_classes=3, + num_books=1, + dist_fn="cosine", + time_ahead=True, + squeeze_single=True, + ) + with torch.no_grad(): + q.proj.weight.copy_(torch.eye(4)) + cb = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + ).unsqueeze(0) + q.codebooks.copy_(cb) + + inp = torch.tensor([[[0.05, 0.05, 0.9, 0.0]]]) + _, xid = q(input_signal=inp) + assert xid.item() == 2 + + inp = torch.tensor([[[0.9, 0.1, 0.0, 0.0]]]) + _, xid = q(input_signal=inp) + assert xid.item() == 0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_gpu(self, quantizer): + quantizer = quantizer.cuda() + x = torch.randn(2, 32, 10, device='cuda') + xq, xid = quantizer(input_signal=x) + assert xq.device.type == 'cuda' + assert xid.device.type == 'cuda'
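Since the new test file loads `masking.py` and `quantizers.py` directly via `importlib`, these tests can run without pulling in the full ASR collection. One way to invoke them from the repository root, assuming a working checkout with pytest installed:

```python
# Run the SSL-module tests added in this PR from the repository root.
import pytest

raise SystemExit(pytest.main(["-q", "tests/collections/asr/test_ssl_modules.py"]))
```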