Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dithering to the Speech2TextFeatureExtractor API. #34638

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/transformers/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
Expand Down Expand Up @@ -460,6 +461,9 @@ def spectrogram(
onesided (`bool`, *optional*, defaults to `True`):
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
dither (`float`):
Add dithering (add small Gaussian noise to each frame).
E.g. use 4 to add dithering, 0.0 means no dithering.
preemphasis (`float`, *optional*)
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
Expand Down Expand Up @@ -540,6 +544,9 @@ def spectrogram(
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]

if dither != 0.0:
buffer[:frame_length] += dither * np.random.randn(*buffer[:frame_length].shape)

if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

Expand Down Expand Up @@ -591,6 +598,7 @@ def spectrogram_batch(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
Expand Down Expand Up @@ -653,6 +661,9 @@ def spectrogram_batch(
The padding strategy when `center` is `True`.
onesided (`bool`, *optional*, defaults to `True`):
If True, returns a one-sided spectrogram for real input signals.
dither (`float`):
Add dithering (add small Gaussian noise to each frame).
E.g. use 4 to add dithering, 0.0 means no dithering.
preemphasis (`float`, *optional*):
Applies a pre-emphasis filter to each frame.
mel_filters (`np.ndarray`, *optional*):
Expand Down Expand Up @@ -745,6 +756,9 @@ def spectrogram_batch(
timestep = frame_idx * hop_length
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

if dither != 0.0:
buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)

if remove_dc_offset:
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
Number of Mel-frequency bins.
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding vectors.
dither (`float`, *optional*, defaults to 0.0):
Add dithering (add small Gaussian noise to each frame).
E.g. use 4 to add dithering, 0.0 means no dithering.
do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
normalize_means (`bool`, *optional*, defaults to `True`):
Expand All @@ -68,13 +71,15 @@ def __init__(
sampling_rate=16000,
num_mel_bins=80,
padding_value=0.0,
dither=0.0,
do_ceptral_normalize=True,
normalize_means=True,
normalize_vars=True,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.num_mel_bins = num_mel_bins
self.dither = dither
self.do_ceptral_normalize = do_ceptral_normalize
self.normalize_means = normalize_means
self.normalize_vars = normalize_vars
Expand Down Expand Up @@ -106,7 +111,12 @@ def _extract_fbank_features(
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
features = ta_kaldi.fbank(
waveform,
dither=self.dither,
num_mel_bins=self.num_mel_bins,
sample_frequency=self.sampling_rate,
)
features = features.numpy()
else:
waveform = np.squeeze(waveform)
Expand All @@ -118,6 +128,7 @@ def _extract_fbank_features(
fft_length=512,
power=2.0,
center=False,
dither=self.dither,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
Expand Down