Add dithering to the Speech2TextFeatureExtractor API.
- in Kaldi: https://github.com/kaldi-asr/kaldi/blob/4a8b7f673275597fef8a15b160124bd0985b59bd/src/feat/feature-window.cc#L145
- with dithering and no fixed seed, the features become non-deterministic because
  small Gaussian noise is added to the audio, i.e. two runs produce slightly
  different outputs (see the sketch below)
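A minimal sketch (not part of the commit) of the behaviour described above: unseeded Gaussian dithering makes two runs differ slightly, while fixing torch's seed restores determinism. The `dither_frame` helper is hypothetical and only mirrors the noise term introduced by this change.

```python
import torch

def dither_frame(frame: torch.Tensor, dither: float = 4.0) -> torch.Tensor:
    # Hypothetical helper: add zero-mean Gaussian noise scaled by `dither`; 0.0 is a no-op.
    if dither == 0.0:
        return frame
    return frame + dither * torch.randn(frame.shape, dtype=frame.dtype)

frame = torch.zeros(400)

# Two unseeded calls draw different noise, so the resulting features differ slightly.
print(torch.allclose(dither_frame(frame), dither_frame(frame)))  # False

# Seeding before each call makes the noise, and thus the output, reproducible.
torch.manual_seed(0)
a = dither_frame(frame)
torch.manual_seed(0)
b = dither_frame(frame)
print(torch.allclose(a, b))  # True
```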
KarelVesely84 committed Nov 7, 2024
1 parent 7bbc624 commit b7cb796
Showing 2 changed files with 36 additions and 1 deletion.
24 changes: 24 additions & 0 deletions src/transformers/audio_utils.py
@@ -20,6 +20,8 @@
import warnings
from typing import List, Optional, Tuple, Union

import torch

import numpy as np


@@ -390,6 +392,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
@@ -460,6 +463,9 @@ def spectrogram(
onesided (`bool`, *optional*, defaults to `True`):
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering, i.e. small Gaussian noise applied to each frame. For example, use 4.0 to
enable dithering; 0.0 means no dithering.
preemphasis (`float`, *optional*)
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
@@ -540,6 +546,13 @@ def spectrogram(
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]

if dither != 0.0:
buffer[:frame_length] += dither * torch.randn(
buffer[:frame_length].shape,
device=buffer.device,
dtype=buffer.dtype,
)

if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

@@ -591,6 +604,7 @@ def spectrogram_batch(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
@@ -653,6 +667,9 @@ def spectrogram_batch(
The padding strategy when `center` is `True`.
onesided (`bool`, *optional*, defaults to `True`):
If True, returns a one-sided spectrogram for real input signals.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering, i.e. small Gaussian noise applied to each frame. For example, use 4.0 to
enable dithering; 0.0 means no dithering.
preemphasis (`float`, *optional*):
Applies a pre-emphasis filter to each frame.
mel_filters (`np.ndarray`, *optional*):
@@ -745,6 +762,13 @@ def spectrogram_batch(
timestep = frame_idx * hop_length
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

if dither != 0.0:
buffer[:, :frame_length] += dither * torch.randn(
buffer[:, :frame_length].shape,
device=buffer.device,
dtype=buffer.dtype,
)

if remove_dc_offset:
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

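For orientation, a minimal NumPy sketch (not part of the diff, which calls `torch.randn` on the frame buffer) of the per-frame dithering step performed inside the framing loop, using a seedable generator so the noise can be made reproducible:

```python
import numpy as np

def apply_dither(frame: np.ndarray, dither: float, rng: np.random.Generator) -> np.ndarray:
    # Add zero-mean Gaussian noise scaled by `dither` to one analysis frame; 0.0 is a no-op.
    if dither == 0.0:
        return frame
    return frame + (dither * rng.standard_normal(frame.shape)).astype(frame.dtype)

rng = np.random.default_rng(0)            # fixed seed -> reproducible noise and spectrogram
frame = np.zeros(400, dtype=np.float32)
print(apply_dither(frame, dither=4.0, rng=rng)[:5])
```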
13 changes: 12 additions & 1 deletion src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -52,6 +52,9 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
Number of Mel-frequency bins.
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding vectors.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering, i.e. small Gaussian noise applied to each frame. For example, use 4.0 to
enable dithering; 0.0 means no dithering.
do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
normalize_means (`bool`, *optional*, defaults to `True`):
@@ -68,13 +71,15 @@ def __init__(
sampling_rate=16000,
num_mel_bins=80,
padding_value=0.0,
dither=0.0,
do_ceptral_normalize=True,
normalize_means=True,
normalize_vars=True,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.num_mel_bins = num_mel_bins
self.dither = dither
self.do_ceptral_normalize = do_ceptral_normalize
self.normalize_means = normalize_means
self.normalize_vars = normalize_vars
@@ -106,7 +111,12 @@ def _extract_fbank_features(
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
features = ta_kaldi.fbank(
waveform,
dither=self.dither,
num_mel_bins=self.num_mel_bins,
sample_frequency=self.sampling_rate,
)
features = features.numpy()
else:
waveform = np.squeeze(waveform)
@@ -118,6 +128,7 @@
fft_length=512,
power=2.0,
center=False,
dither=self.dither,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
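A hedged usage sketch (not part of the commit) of the new extractor-level parameter, assuming torchaudio is installed so the Kaldi `fbank` path is taken; the parameter values are illustrative only.

```python
import numpy as np
from transformers import Speech2TextFeatureExtractor

# Enable dithering so small Gaussian noise is added to each frame before the filter banks.
extractor = Speech2TextFeatureExtractor(
    feature_size=80,
    sampling_rate=16000,
    num_mel_bins=80,
    dither=4.0,
)

audio = np.zeros(16000, dtype=np.float32)
a = extractor(audio, sampling_rate=16000, return_tensors="np").input_features
b = extractor(audio, sampling_rate=16000, return_tensors="np").input_features

# Without fixing torch's global seed the two runs differ slightly; dither=0.0 restores
# the previous deterministic behaviour.
print(np.allclose(a, b))
```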
