From 668bf55395534019ed041104bad593c6636815ca Mon Sep 17 00:00:00 2001
From: Karel Vesely
Date: Thu, 7 Nov 2024 13:38:38 +0100
Subject: [PATCH] Add dithering to the `Speech2TextFeatureExtractor` API.

- in Kaldi: https://github.com/kaldi-asr/kaldi/blob/4a8b7f673275597fef8a15b160124bd0985b59bd/src/feat/feature-window.cc#L145
- with dithering enabled and no fixed seed, the features become
  non-deterministic, because small Gaussian noise is added to the audio
  (i.e. two runs produce slightly different outputs); see the sketch below
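A minimal sketch of the new `dither` argument at the `audio_utils.spectrogram` level (the dummy waveform, window choice, and dither value are only illustrative):

```python
import numpy as np

from transformers.audio_utils import spectrogram, window_function

waveform = np.random.randn(16000).astype(np.float32)  # dummy 1-second signal at 16 kHz
window = window_function(400, "hann")

# With dither > 0 and no fixed numpy seed, two runs give slightly different features.
spec_a = spectrogram(waveform, window, frame_length=400, hop_length=160, fft_length=512, dither=0.01)
spec_b = spectrogram(waveform, window, frame_length=400, hop_length=160, fft_length=512, dither=0.01)
print(np.abs(spec_a - spec_b).max())  # > 0, since fresh Gaussian noise is drawn per frame
```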
---
 src/transformers/audio_utils.py          | 14 ++++++++++++++
 .../feature_extraction_speech_to_text.py | 13 ++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py
index d46b0eb62e0e7e..b5c7c53edb28a6 100644
--- a/src/transformers/audio_utils.py
+++ b/src/transformers/audio_utils.py
@@ -390,6 +390,7 @@ def spectrogram(
     center: bool = True,
     pad_mode: str = "reflect",
     onesided: bool = True,
+    dither: float = 0.0,
     preemphasis: Optional[float] = None,
     mel_filters: Optional[np.ndarray] = None,
     mel_floor: float = 1e-10,
@@ -460,6 +461,9 @@ def spectrogram(
         onesided (`bool`, *optional*, defaults to `True`):
             If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
             frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
+        dither (`float`, *optional*, defaults to `0.0`):
+            Adds dithering, i.e. small Gaussian noise, to each frame. E.g. a value of 4.0 adds noise with a
+            standard deviation of 4.0 to each frame; 0.0 means no dithering.
         preemphasis (`float`, *optional*)
             Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
         mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
@@ -540,6 +544,9 @@ def spectrogram(
     for frame_idx in range(num_frames):
         buffer[:frame_length] = waveform[timestep : timestep + frame_length]
 
+        if dither != 0.0:
+            buffer[:frame_length] += dither * np.random.randn(*buffer[:frame_length].shape)
+
         if remove_dc_offset:
             buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
 
@@ -591,6 +598,7 @@ def spectrogram_batch(
     center: bool = True,
     pad_mode: str = "reflect",
     onesided: bool = True,
+    dither: float = 0.0,
     preemphasis: Optional[float] = None,
     mel_filters: Optional[np.ndarray] = None,
     mel_floor: float = 1e-10,
@@ -653,6 +661,9 @@ def spectrogram_batch(
             The padding strategy when `center` is `True`.
         onesided (`bool`, *optional*, defaults to `True`):
             If True, returns a one-sided spectrogram for real input signals.
+        dither (`float`, *optional*, defaults to `0.0`):
+            Adds dithering, i.e. small Gaussian noise, to each frame. E.g. a value of 4.0 adds noise with a
+            standard deviation of 4.0 to each frame; 0.0 means no dithering.
         preemphasis (`float`, *optional*):
             Applies a pre-emphasis filter to each frame.
         mel_filters (`np.ndarray`, *optional*):
@@ -745,6 +756,9 @@ def spectrogram_batch(
         timestep = frame_idx * hop_length
         buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
 
+        if dither != 0.0:
+            buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
+
         if remove_dc_offset:
             buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
 
diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
index b8a2b6bfb29738..e0e89da8431a46 100644
--- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -52,6 +52,9 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
             Number of Mel-frequency bins.
         padding_value (`float`, *optional*, defaults to 0.0):
             The value that is used to fill the padding vectors.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering, i.e. small Gaussian noise, to each frame. E.g. a value of 4.0 adds noise with a
+            standard deviation of 4.0 to each frame; 0.0 means no dithering.
         do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
             Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
         normalize_means (`bool`, *optional*, defaults to `True`):
@@ -68,6 +71,7 @@ def __init__(
         sampling_rate=16000,
         num_mel_bins=80,
         padding_value=0.0,
+        dither=0.0,
         do_ceptral_normalize=True,
         normalize_means=True,
         normalize_vars=True,
@@ -75,6 +79,7 @@
     ):
         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
         self.num_mel_bins = num_mel_bins
+        self.dither = dither
         self.do_ceptral_normalize = do_ceptral_normalize
         self.normalize_means = normalize_means
         self.normalize_vars = normalize_vars
@@ -106,7 +111,12 @@ def _extract_fbank_features(
         waveform = waveform * (2**15)  # Kaldi compliance: 16-bit signed integers
         if is_speech_available():
             waveform = torch.from_numpy(waveform).unsqueeze(0)
-            features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
+            features = ta_kaldi.fbank(
+                waveform,
+                dither=self.dither,
+                num_mel_bins=self.num_mel_bins,
+                sample_frequency=self.sampling_rate,
+            )
             features = features.numpy()
         else:
             waveform = np.squeeze(waveform)
@@ -118,6 +128,7 @@
                 fft_length=512,
                 power=2.0,
                 center=False,
+                dither=self.dither,
                 preemphasis=0.97,
                 mel_filters=self.mel_filters,
                 log_mel="log",
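A usage sketch of the new extractor-level parameter, assuming this patch is applied (the dummy audio and the dither value are only illustrative):

```python
import numpy as np

from transformers import Speech2TextFeatureExtractor

extractor = Speech2TextFeatureExtractor(dither=1.0)  # 1.0 is an illustrative value
audio = np.random.randn(16000).astype(np.float32)    # dummy 1-second clip at 16 kHz

# With dither > 0 and no fixed seed, two extractions of the same clip differ slightly.
feat_a = extractor(audio, sampling_rate=16000, return_tensors="np").input_features
feat_b = extractor(audio, sampling_rate=16000, return_tensors="np").input_features
print(np.abs(feat_a - feat_b).max())  # > 0

# For reproducible features, either keep dither=0.0 (the default) or fix the RNG seed
# before each call: torch.manual_seed(...) when torchaudio's Kaldi fbank path is used,
# np.random.seed(...) for the numpy fallback.
```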