Skip to content

Commit

Permalink
Replace wav intermediate processing format to mp3
Browse files Browse the repository at this point in the history
  • Loading branch information
AliOsm committed Jun 26, 2024
1 parent 793214b commit 9a4a03e
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 86 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ wit = [
"numpy>=1.26.4",
"pydub>=0.25.1",
"requests>=2.32.0",
"scipy>=1.13.0",
]
whisper = [
"faster-whisper>=1.0.2",
Expand Down
115 changes: 43 additions & 72 deletions tafrigh/audio_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import numpy as np

from auditok import AudioRegion
from auditok.core import split
from scipy.io import wavfile
from pydub import AudioSegment
from pydub.generators import WhiteNoise


class AudioSplitter:
Expand All @@ -18,92 +20,61 @@ def split(
energy_threshold: float = 50,
expand_segments_with_noise: bool = False,
noise_seconds: int = 1,
noise_amplitude: int = 10,
noise_amplitude: int = 0,
) -> list[tuple[str, float, float]]:
sampling_rate, data = self._read_audio(file_path)
temp_file_name = self._write_temp_audio(sampling_rate, data)
segments = self._split_audio(temp_file_name, min_dur, max_dur, max_silence, energy_threshold)

os.remove(temp_file_name)

if expand_segments_with_noise:
expanded_segments = self._expand_segments_with_noise(
segments,
noise_seconds,
noise_amplitude,
sampling_rate,
data.dtype,
)
else:
expanded_segments = [(segment, segment.meta.start, segment.meta.end) for segment in segments]

return self._save_segments(output_dir, sampling_rate, expanded_segments)

def _read_audio(self, file_path: str) -> tuple[int, np.ndarray]:
sampling_rate, data = wavfile.read(file_path)

if len(data.shape) > 1 and data.shape[1] > 1:
data = np.mean(data, axis=1)

return sampling_rate, data

def _write_audio(self, file_path: str, sampling_rate: int, data: np.ndarray) -> None:
wavfile.write(file_path, sampling_rate, data.astype(np.int16))

def _write_temp_audio(self, sampling_rate: int, data: np.ndarray) -> str:
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
temp_file_name = temp_file.name
self._write_audio(temp_file_name, sampling_rate, data)

return temp_file_name

def _split_audio(
self,
temp_file_name: str,
min_dur: float,
max_dur: float,
max_silence: float,
energy_threshold: float,
):
return split(
temp_file_name,
segments = split(
file_path,
min_dur=min_dur,
max_dur=max_dur,
max_silence=max_silence,
energy_threshold=energy_threshold,
)

def _expand_segments_with_noise(
if expand_segments_with_noise:
segments = [
(
self._expand_segment_with_noise(segment, noise_seconds, noise_amplitude),
segment.meta.start,
segment.meta.end,
) for segment in segments
]

return self._save_segments(output_dir, segments)

def _expand_segment_with_noise(
self,
segments: list,
segment: AudioRegion,
noise_seconds: int,
noise_amplitude: int,
sampling_rate: int,
dtype: np.dtype,
) -> list[tuple[np.ndarray, float, float]]:
expanded_segments = []
) -> AudioSegment:

for segment in segments:
# Having different noise at the beginning and the end gave us better results :).
prepend_noise = np.random.normal(0, noise_amplitude, int(noise_seconds * sampling_rate)).astype(dtype)
append_noise = np.random.normal(0, noise_amplitude, int(noise_seconds * sampling_rate)).astype(dtype)
audio_segment = AudioSegment(
segment._data,
frame_rate=segment.sampling_rate,
sample_width=segment.sample_width,
channels=segment.channels,
)

expanded_segment = np.concatenate((prepend_noise, segment, append_noise))
expanded_segments.append((expanded_segment, segment.meta.start, segment.meta.end))
pre_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude)
post_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude)

return expanded_segments
return pre_noise + audio_segment + post_noise

def _save_segments(
self,
output_dir: str,
sampling_rate: int,
expanded_segments: list[tuple[np.ndarray, float, float]],
segments: list[AudioSegment | tuple[AudioSegment, float, float]],
) -> list[tuple[str, float, float]]:
segments = []

for i, (expanded_segment, start, end) in enumerate(expanded_segments):
output_file = os.path.join(output_dir, f"segment_{i + 1}.wav")
self._write_audio(output_file, sampling_rate, expanded_segment)
segments.append((output_file, start, end))

return segments
segment_paths = []

for i, segment in enumerate(segments):
output_file = os.path.join(output_dir, f'segment_{i + 1}.mp3')

if isinstance(segment, tuple):
segment[0].export(output_file, format='mp3')
segment_paths.append((output_file, segment[1], segment[2]))
else:
segment.save(output_file)
segment_paths.append((output_file, segment.meta.start, segment.meta.end))

return segment_paths
10 changes: 5 additions & 5 deletions tafrigh/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def process_local(
file_path = str(file['file_path'].absolute())

if config.use_wit():
wav_file_path = str(wit_file_utils.convert_to_wav(file['file_path']).absolute())
recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(wav_file_path, config.wit)
mp3_file_path = str(wit_file_utils.convert_to_mp3(file['file_path']).absolute())
recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(mp3_file_path, config.wit)
else:
recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize(
file_path,
Expand All @@ -165,8 +165,8 @@ def process_local(
segments = exception.value
break

if config.use_wit() and file['file_path'].suffix != '.wav':
Path(wav_file_path).unlink(missing_ok=True)
if config.use_wit() and file['file_path'].suffix != '.mp3':
Path(mp3_file_path).unlink(missing_ok=True)

writer.write_all(Path(file['file_name']).stem, segments, config.output)

Expand Down Expand Up @@ -218,7 +218,7 @@ def process_url(

continue

file_path = os.path.join(config.output.output_dir, f"{element['id']}.wav")
file_path = os.path.join(config.output.output_dir, f"{element['id']}.mp3")

if config.use_wit():
recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(file_path, config.wit)
Expand Down
5 changes: 3 additions & 2 deletions tafrigh/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ def _config(self, download_archive: Union[str, bool]) -> dict[str, Any]:
return {
'quiet': True,
'verbose': False,
'format': 'wav/bestaudio/best',
'format': 'bestaudio',
'extract_audio': True,
'outtmpl': os.path.join(self.output_dir, '%(id)s.%(ext)s'),
'ignoreerrors': True,
'download_archive': download_archive,
'playlist_items': self.playlist_items,
'postprocessors': [
{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'wav',
'preferredcodec': 'mp3',
},
],
}
Expand Down
6 changes: 3 additions & 3 deletions tafrigh/recognizers/wit_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def _process_segment(

segment_file_path, start, end = segment

with open(segment_file_path, 'rb') as wav_file:
audio_content = wav_file.read()
with open(segment_file_path, 'rb') as mp3_file:
audio_content = mp3_file.read()

retries = 5

Expand All @@ -127,7 +127,7 @@ def _process_segment(
'https://api.wit.ai/speech',
headers={
'Accept': 'application/vnd.wit.20200513+json',
'Content-Type': 'audio/wav',
'Content-Type': 'audio/mpeg3',
'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}',
},
data=audio_content,
Expand Down
6 changes: 3 additions & 3 deletions tafrigh/utils/wit/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from pydub import AudioSegment


def convert_to_wav(file: Path) -> Path:
def convert_to_mp3(file: Path) -> Path:
audio_file = AudioSegment.from_file(str(file))
converted_file_path = file.with_suffix('.wav')
audio_file.export(str(converted_file_path), format='wav')
converted_file_path = file.with_suffix('.mp3')
audio_file.export(str(converted_file_path), format='mp3')
return converted_file_path

0 comments on commit 9a4a03e

Please sign in to comment.