From 028a390c2b3f317b9b8b05243d194ce7ce54d0ca Mon Sep 17 00:00:00 2001 From: Ali Hamdi Ali Fadel Date: Wed, 26 Jun 2024 21:54:49 +0000 Subject: [PATCH] Remove tafrigh folder --- tafrigh/__init__.py | 17 -- tafrigh/audio_splitter.py | 80 ------ tafrigh/cli.py | 264 -------------------- tafrigh/config.py | 118 --------- tafrigh/downloader.py | 58 ----- tafrigh/recognizers/__init__.py | 0 tafrigh/recognizers/whisper_recognizer.py | 104 -------- tafrigh/recognizers/wit_calling_throttle.py | 38 --- tafrigh/recognizers/wit_recognizer.py | 158 ------------ tafrigh/types/__init__.py | 0 tafrigh/types/transcript_type.py | 15 -- tafrigh/types/whisper/__init__.py | 0 tafrigh/types/whisper/type_hints.py | 11 - tafrigh/utils/__init__.py | 0 tafrigh/utils/cli_utils.py | 200 --------------- tafrigh/utils/file_utils.py | 20 -- tafrigh/utils/time_utils.py | 15 -- tafrigh/utils/whisper/__init__.py | 0 tafrigh/utils/whisper/whisper_utils.py | 15 -- tafrigh/utils/wit/__init__.py | 0 tafrigh/utils/wit/file_utils.py | 10 - tafrigh/writer.py | 162 ------------ 22 files changed, 1285 deletions(-) delete mode 100644 tafrigh/__init__.py delete mode 100644 tafrigh/audio_splitter.py delete mode 100644 tafrigh/cli.py delete mode 100644 tafrigh/config.py delete mode 100644 tafrigh/downloader.py delete mode 100644 tafrigh/recognizers/__init__.py delete mode 100644 tafrigh/recognizers/whisper_recognizer.py delete mode 100644 tafrigh/recognizers/wit_calling_throttle.py delete mode 100644 tafrigh/recognizers/wit_recognizer.py delete mode 100644 tafrigh/types/__init__.py delete mode 100644 tafrigh/types/transcript_type.py delete mode 100644 tafrigh/types/whisper/__init__.py delete mode 100644 tafrigh/types/whisper/type_hints.py delete mode 100644 tafrigh/utils/__init__.py delete mode 100644 tafrigh/utils/cli_utils.py delete mode 100644 tafrigh/utils/file_utils.py delete mode 100644 tafrigh/utils/time_utils.py delete mode 100644 tafrigh/utils/whisper/__init__.py delete mode 100644 tafrigh/utils/whisper/whisper_utils.py delete mode 100644 tafrigh/utils/wit/__init__.py delete mode 100644 tafrigh/utils/wit/file_utils.py delete mode 100644 tafrigh/writer.py diff --git a/tafrigh/__init__.py b/tafrigh/__init__.py deleted file mode 100644 index c47c637..0000000 --- a/tafrigh/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from tafrigh.cli import farrigh -from tafrigh.config import Config -from tafrigh.downloader import Downloader -from tafrigh.types.transcript_type import TranscriptType -from tafrigh.writer import Writer - - -try: - from tafrigh.recognizers.whisper_recognizer import WhisperRecognizer -except ModuleNotFoundError: - pass - -try: - from tafrigh.audio_splitter import AudioSplitter - from tafrigh.recognizers.wit_recognizer import WitRecognizer -except ModuleNotFoundError: - pass diff --git a/tafrigh/audio_splitter.py b/tafrigh/audio_splitter.py deleted file mode 100644 index c98c282..0000000 --- a/tafrigh/audio_splitter.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import tempfile - -import numpy as np - -from auditok import AudioRegion -from auditok.core import split -from pydub import AudioSegment -from pydub.generators import WhiteNoise - - -class AudioSplitter: - def split( - self, - file_path: str, - output_dir: str, - min_dur: float = 0.5, - max_dur: float = 15, - max_silence: float = 0.5, - energy_threshold: float = 50, - expand_segments_with_noise: bool = False, - noise_seconds: int = 1, - noise_amplitude: int = 0, - ) -> list[tuple[str, float, float]]: - segments = split( - file_path, - min_dur=min_dur, - max_dur=max_dur, - max_silence=max_silence, - energy_threshold=energy_threshold, - ) - - if expand_segments_with_noise: - segments = [ - ( - self._expand_segment_with_noise(segment, noise_seconds, noise_amplitude), - segment.meta.start, - segment.meta.end, - ) for segment in segments - ] - - return self._save_segments(output_dir, segments) - - def _expand_segment_with_noise( - self, - segment: AudioRegion, - noise_seconds: int, - noise_amplitude: int, - ) -> AudioSegment: - - audio_segment = AudioSegment( - segment._data, - frame_rate=segment.sampling_rate, - sample_width=segment.sample_width, - channels=segment.channels, - ) - - pre_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude) - post_noise = WhiteNoise().to_audio_segment(duration=noise_seconds * 1000, volume=noise_amplitude) - - return pre_noise + audio_segment + post_noise - - def _save_segments( - self, - output_dir: str, - segments: list[AudioSegment | tuple[AudioSegment, float, float]], - ) -> list[tuple[str, float, float]]: - segment_paths = [] - - for i, segment in enumerate(segments): - output_file = os.path.join(output_dir, f'segment_{i + 1}.mp3') - - if isinstance(segment, tuple): - segment[0].export(output_file, format='mp3') - segment_paths.append((output_file, segment[1], segment[2])) - else: - segment.save(output_file) - segment_paths.append((output_file, segment.meta.start, segment.meta.end)) - - return segment_paths diff --git a/tafrigh/cli.py b/tafrigh/cli.py deleted file mode 100644 index 61beb6b..0000000 --- a/tafrigh/cli.py +++ /dev/null @@ -1,264 +0,0 @@ -import csv -import logging -import os -import random -import re -import sys - -from collections import deque -from pathlib import Path -from typing import Any, Generator, Union - -from tqdm import tqdm - -from tafrigh.config import Config -from tafrigh.downloader import Downloader -from tafrigh.utils import cli_utils, file_utils, time_utils -from tafrigh.writer import Writer - - -try: - import requests - - from tafrigh.recognizers.wit_recognizer import WitRecognizer - from tafrigh.utils.wit import file_utils as wit_file_utils -except ModuleNotFoundError: - pass - -try: - from tafrigh.recognizers.whisper_recognizer import WhisperRecognizer - from tafrigh.types.whisper.type_hints import WhisperModel - from tafrigh.utils.whisper import whisper_utils -except ModuleNotFoundError: - pass - - -def main(): - args = cli_utils.parse_args(sys.argv[1:]) - - config = Config( - urls_or_paths=args.urls_or_paths, - skip_if_output_exist=args.skip_if_output_exist, - playlist_items=args.playlist_items, - verbose=args.verbose, - # - model_name_or_path=args.model_name_or_path, - task=args.task, - language=args.language, - use_faster_whisper=args.use_faster_whisper, - beam_size=args.beam_size, - ct2_compute_type=args.ct2_compute_type, - # - wit_client_access_tokens=args.wit_client_access_tokens, - max_cutting_duration=args.max_cutting_duration, - min_words_per_segment=args.min_words_per_segment, - # - save_files_before_compact=args.save_files_before_compact, - save_yt_dlp_responses=args.save_yt_dlp_responses, - output_sample=args.output_sample, - output_formats=args.output_formats, - output_dir=args.output_dir, - ) - - if config.use_wit() and config.input.skip_if_output_exist: - retries = 3 - - while retries > 0: - try: - deque(farrigh(config), maxlen=0) - break - except requests.exceptions.RetryError: - retries -= 1 - else: - deque(farrigh(config), maxlen=0) - - -def farrigh(config: Config) -> Generator[dict[str, int], None, None]: - prepare_output_dir(config.output.output_dir) - - model = None - if not config.use_wit(): - model = whisper_utils.load_model(config.whisper) - - segments = [] - - for idx, item in enumerate(tqdm(config.input.urls_or_paths, desc='URLs or local paths')): - progress_info = { - 'outer_total': len(config.input.urls_or_paths), - 'outer_current': idx + 1, - 'outer_status': 'processing', - } - - if Path(item).exists(): - file_or_folder = Path(item) - for progress_info, local_elements_segments in process_local(file_or_folder, model, config, progress_info): - segments.extend(local_elements_segments) - yield progress_info - elif re.match('(https?://)', item): - for progress_info, url_elements_segments in process_url(item, model, config, progress_info): - segments.extend(url_elements_segments) - yield progress_info - else: - logging.error(f'Path {item} does not exist and is not a URL either.') - - progress_info['outer_status'] = 'completed' - yield progress_info - - continue - - progress_info['outer_status'] = 'completed' - yield progress_info - - write_output_sample(segments, config.output) - - -def prepare_output_dir(output_dir: str) -> None: - os.makedirs(output_dir, exist_ok=True) - - -def process_local( - path: Path, - model: 'WhisperModel', - config: Config, - progress_info: dict, -) -> Generator[tuple[dict[str, int], list[list[dict[str, Union[str, float]]]]], None, None]: - filtered_media_files: list[Path] = file_utils.filter_media_files([path] if path.is_file() else path.iterdir()) - files: list[dict[str, Any]] = [{'file_name': file.name, 'file_path': file} for file in filtered_media_files] - - for idx, file in enumerate(tqdm(files, desc='Local files')): - new_progress_info = progress_info.copy() - new_progress_info.update( - { - 'inner_total': len(files), - 'inner_current': idx + 1, - 'inner_status': 'processing', - 'progress': 0.0, - 'remaining_time': None, - } - ) - yield new_progress_info, [] - - writer = Writer() - if config.input.skip_if_output_exist and writer.is_output_exist(Path(file['file_name']).stem, config.output): - new_progress_info['inner_status'] = 'completed' - yield new_progress_info, [] - - continue - - file_path = str(file['file_path'].absolute()) - - if config.use_wit(): - mp3_file_path = str(wit_file_utils.convert_to_mp3(file['file_path']).absolute()) - recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(mp3_file_path, config.wit) - else: - recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( - file_path, - model, - config.whisper, - ) - - while True: - try: - new_progress_info.update(next(recognize_generator)) - yield new_progress_info, [] - except StopIteration as exception: - segments = exception.value - break - - if config.use_wit() and file['file_path'].suffix != '.mp3': - Path(mp3_file_path).unlink(missing_ok=True) - - writer.write_all(Path(file['file_name']).stem, segments, config.output) - - for segment in segments: - segment['url'] = f"file://{file_path}&t={int(segment['start'])}" - segment['file_path'] = file_path - - new_progress_info['inner_status'] = 'completed' - new_progress_info['progress'] = 100.0 - yield new_progress_info, writer.compact_segments(segments, config.output.min_words_per_segment) - - -def process_url( - url: str, - model: 'WhisperModel', - config: Config, - progress_info: dict, -) -> Generator[tuple[dict[str, int], list[list[dict[str, Union[str, float]]]]], None, None]: - url_data = Downloader(playlist_items=config.input.playlist_items, output_dir=config.output.output_dir).download( - url, - save_response=config.output.save_yt_dlp_responses, - ) - - if '_type' in url_data and url_data['_type'] == 'playlist': - url_data = url_data['entries'] - else: - url_data = [url_data] - - for idx, element in enumerate(tqdm(url_data, desc='URL elements')): - if not element: - continue - - new_progress_info = progress_info.copy() - new_progress_info.update( - { - 'inner_total': len(url_data), - 'inner_current': idx + 1, - 'inner_status': 'processing', - 'progress': 0.0, - 'remaining_time': None, - } - ) - yield new_progress_info, [] - - writer = Writer() - if config.input.skip_if_output_exist and writer.is_output_exist(element['id'], config.output): - new_progress_info['inner_status'] = 'completed' - yield new_progress_info, [] - - continue - - file_path = os.path.join(config.output.output_dir, f"{element['id']}.mp3") - - if config.use_wit(): - recognize_generator = WitRecognizer(verbose=config.input.verbose).recognize(file_path, config.wit) - else: - recognize_generator = WhisperRecognizer(verbose=config.input.verbose).recognize( - file_path, - model, - config.whisper, - ) - - while True: - try: - new_progress_info.update(next(recognize_generator)) - yield new_progress_info, [] - except StopIteration as exception: - segments = exception.value - break - - writer.write_all(element['id'], segments, config.output) - - for segment in segments: - segment['url'] = f"https://youtube.com/watch?v={element['id']}&t={int(segment['start'])}" - segment['file_path'] = file_path - - new_progress_info['inner_status'] = 'completed' - new_progress_info['progress'] = 100.0 - yield new_progress_info, writer.compact_segments(segments, config.output.min_words_per_segment) - - -def write_output_sample(segments: list[dict[str, Union[str, float]]], output: Config.Output) -> None: - if output.output_sample == 0: - return - - random.shuffle(segments) - - with open(os.path.join(output.output_dir, 'sample.csv'), 'w') as fp: - writer = csv.DictWriter(fp, fieldnames=['start', 'end', 'text', 'url', 'file_path']) - writer.writeheader() - - for segment in segments[: output.output_sample]: - segment['start'] = time_utils.format_timestamp(segment['start'], include_hours=True, decimal_marker=',') - segment['end'] = time_utils.format_timestamp(segment['end'], include_hours=True, decimal_marker=',') - writer.writerow(segment) diff --git a/tafrigh/config.py b/tafrigh/config.py deleted file mode 100644 index 9b1ed0c..0000000 --- a/tafrigh/config.py +++ /dev/null @@ -1,118 +0,0 @@ -import logging - -from tafrigh.types.transcript_type import TranscriptType - - -class Config: - def __init__( - self, - urls_or_paths: list[str], - skip_if_output_exist: bool, - playlist_items: str, - verbose: bool, - model_name_or_path: str, - task: str, - language: str, - use_faster_whisper: bool, - beam_size: int, - ct2_compute_type: str, - wit_client_access_tokens: list[str], - max_cutting_duration: int, - min_words_per_segment: int, - save_files_before_compact: bool, - save_yt_dlp_responses: bool, - output_sample: int, - output_formats: list[str], - output_dir: str, - ): - self.input = self.Input(urls_or_paths, skip_if_output_exist, playlist_items, verbose) - - self.whisper = self.Whisper( - model_name_or_path, - task, - language, - use_faster_whisper, - beam_size, - ct2_compute_type, - ) - - self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration) - - self.output = self.Output( - min_words_per_segment, - save_files_before_compact, - save_yt_dlp_responses, - output_sample, - output_formats, - output_dir, - ) - - def use_wit(self) -> bool: - return self.wit.wit_client_access_tokens is not None and self.wit.wit_client_access_tokens != [] - - class Input: - def __init__(self, urls_or_paths: list[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool): - self.urls_or_paths = urls_or_paths - self.skip_if_output_exist = skip_if_output_exist - self.playlist_items = playlist_items - self.verbose = verbose - - class Whisper: - def __init__( - self, - model_name_or_path: str, - task: str, - language: str, - use_faster_whisper: bool, - beam_size: int, - ct2_compute_type: str, - ): - if model_name_or_path.endswith('.en'): - logging.warn(f'{model_name_or_path} is an English-only model, setting language to English.') - language = 'en' - - self.model_name_or_path = model_name_or_path - self.task = task - self.language = language - self.use_faster_whisper = use_faster_whisper - self.beam_size = beam_size - self.ct2_compute_type = ct2_compute_type - - class Wit: - def __init__(self, wit_client_access_tokens: list[str], max_cutting_duration: int): - if wit_client_access_tokens is None: - self.wit_client_access_tokens = None - else: - self.wit_client_access_tokens = [ - key for key in wit_client_access_tokens if key is not None and key != '' - ] - - self.max_cutting_duration = max_cutting_duration - - class Output: - def __init__( - self, - min_words_per_segment: int, - save_files_before_compact: bool, - save_yt_dlp_responses: bool, - output_sample: int, - output_formats: list[str], - output_dir: str, - ): - if 'all' in output_formats: - output_formats = list(TranscriptType) - else: - output_formats = [TranscriptType(output_format) for output_format in output_formats] - - if TranscriptType.ALL in output_formats: - output_formats.remove(TranscriptType.ALL) - - if TranscriptType.NONE in output_formats: - output_formats.remove(TranscriptType.NONE) - - self.min_words_per_segment = min_words_per_segment - self.save_files_before_compact = save_files_before_compact - self.save_yt_dlp_responses = save_yt_dlp_responses - self.output_sample = output_sample - self.output_formats = output_formats - self.output_dir = output_dir diff --git a/tafrigh/downloader.py b/tafrigh/downloader.py deleted file mode 100644 index a57e290..0000000 --- a/tafrigh/downloader.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -import os - -from typing import Any, Union - -import yt_dlp - - -class Downloader: - def __init__(self, playlist_items: str, output_dir: str): - self.playlist_items = playlist_items - self.output_dir = output_dir - self.youtube_dl_with_archive = yt_dlp.YoutubeDL(self._config(os.path.join(self.output_dir, 'archive.txt'))) - self.youtube_dl_without_archive = yt_dlp.YoutubeDL(self._config(False)) - - def _config(self, download_archive: Union[str, bool]) -> dict[str, Any]: - return { - 'quiet': True, - 'verbose': False, - 'format': 'bestaudio', - 'extract_audio': True, - 'outtmpl': os.path.join(self.output_dir, '%(id)s.%(ext)s'), - 'ignoreerrors': True, - 'download_archive': download_archive, - 'playlist_items': self.playlist_items, - 'postprocessors': [ - { - 'key': 'FFmpegExtractAudio', - 'preferredcodec': 'mp3', - }, - ], - } - - def download(self, url: str, save_response: bool = False) -> dict[str, Any]: - self.youtube_dl_with_archive.download(url) - url_data = self.youtube_dl_without_archive.extract_info(url, download=False) - - if save_response: - self._save_response(url_data) - - return url_data - - def _save_response(self, url_data: dict[str, Any]) -> None: - if '_type' in url_data and url_data['_type'] == 'playlist': - for entry in url_data['entries']: - if entry and 'requested_downloads' in entry: - self._remove_postprocessors(entry['requested_downloads']) - elif 'requested_downloads' in url_data: - self._remove_postprocessors(url_data['requested_downloads']) - - file_path = os.path.join(self.output_dir, f"{url_data['id']}.json") - - with open(file_path, 'w', encoding='utf-8') as fp: - json.dump(url_data, fp, indent=2, ensure_ascii=False) - - def _remove_postprocessors(self, requested_downloads: list[dict[str, Any]]) -> None: - for requested_download in requested_downloads: - requested_download.pop('__postprocessors') diff --git a/tafrigh/recognizers/__init__.py b/tafrigh/recognizers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tafrigh/recognizers/whisper_recognizer.py b/tafrigh/recognizers/whisper_recognizer.py deleted file mode 100644 index 197ab21..0000000 --- a/tafrigh/recognizers/whisper_recognizer.py +++ /dev/null @@ -1,104 +0,0 @@ -import warnings - -from typing import Generator, Union - -import faster_whisper -import whisper - -from tqdm import tqdm - -from tafrigh.config import Config -from tafrigh.types.whisper.type_hints import WhisperModel - - -class WhisperRecognizer: - def __init__(self, verbose: bool): - self.verbose = verbose - - def recognize( - self, - file_path: str, - model: WhisperModel, - whisper_config: Config.Whisper, - ) -> Generator[dict[str, float], None, list[dict[str, Union[str, float]]]]: - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - - if isinstance(model, whisper.Whisper): - whisper_generator = self._recognize_stable_whisper(file_path, model, whisper_config) - elif isinstance(model, faster_whisper.WhisperModel): - whisper_generator = self._recognize_faster_whisper(file_path, model, whisper_config) - - while True: - try: - yield next(whisper_generator) - except StopIteration as e: - return e.value - - def _recognize_stable_whisper( - self, - audio_file_path: str, - model: whisper.Whisper, - whisper_config: Config.Whisper, - ) -> Generator[dict[str, float], None, list[dict[str, Union[str, float]]]]: - yield {'progress': 0.0, 'remaining_time': None} - - segments = model.transcribe( - audio=audio_file_path, - verbose=self.verbose, - task=whisper_config.task, - language=whisper_config.language, - beam_size=whisper_config.beam_size, - ).segments - - return [ - { - 'start': segment.start, - 'end': segment.end, - 'text': segment.text.strip(), - } - for segment in segments - ] - - def _recognize_faster_whisper( - self, - audio_file_path: str, - model: faster_whisper.WhisperModel, - whisper_config: Config.Whisper, - ) -> Generator[dict[str, float], None, list[dict[str, Union[str, float]]]]: - segments, info = model.transcribe( - audio=audio_file_path, - task=whisper_config.task, - language=whisper_config.language, - beam_size=whisper_config.beam_size, - ) - - converted_segments = [] - last_end = 0 - with tqdm( - total=round(info.duration, 2), - unit='sec', - bar_format='{desc}: {percentage:.2f}%|{bar}| {n:.2f}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]', - disable=self.verbose is not False, - ) as pbar: - for segment in segments: - converted_segments.append( - { - 'start': segment.start, - 'end': segment.end, - 'text': segment.text.strip(), - } - ) - - pbar_update = min(segment.end - last_end, info.duration - pbar.n) - pbar.update(pbar_update) - last_end = segment.end - - yield { - 'progress': round(pbar.n / pbar.total * 100, 2), - 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] - if pbar.format_dict['rate'] and pbar.total - else None, - } - - return converted_segments diff --git a/tafrigh/recognizers/wit_calling_throttle.py b/tafrigh/recognizers/wit_calling_throttle.py deleted file mode 100644 index 18dde0c..0000000 --- a/tafrigh/recognizers/wit_calling_throttle.py +++ /dev/null @@ -1,38 +0,0 @@ -import time - -from multiprocessing.managers import BaseManager -from threading import Lock - - -class WitCallingThrottle: - def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1): - self.wit_client_access_tokens_count = wit_client_access_tokens_count - self.call_times_limit = call_times_limit - self.expired_time = expired_time - self.call_timestamps = [[] for _ in range(self.wit_client_access_tokens_count)] - self.locks = [Lock() for _ in range(self.wit_client_access_tokens_count)] - - def throttle(self, wit_client_access_token_index: int) -> None: - with self.locks[wit_client_access_token_index]: - while len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: - now = time.time() - - self.call_timestamps[wit_client_access_token_index] = list( - filter( - lambda call_timestamp, now=now: now - call_timestamp < self.expired_time, - self.call_timestamps[wit_client_access_token_index], - ) - ) - - if len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: - time_to_sleep = self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now - time.sleep(time_to_sleep) - - self.call_timestamps[wit_client_access_token_index].append(time.time()) - - -class WitCallingThrottleManager(BaseManager): - pass - - -WitCallingThrottleManager.register('WitCallingThrottle', WitCallingThrottle) diff --git a/tafrigh/recognizers/wit_recognizer.py b/tafrigh/recognizers/wit_recognizer.py deleted file mode 100644 index 2009891..0000000 --- a/tafrigh/recognizers/wit_recognizer.py +++ /dev/null @@ -1,158 +0,0 @@ -import json -import logging -import multiprocessing -import os -import shutil -import tempfile -import time - -from typing import Generator, Union - -import requests - -from requests.adapters import HTTPAdapter -from tqdm import tqdm -from urllib3.util.retry import Retry - -from tafrigh.audio_splitter import AudioSplitter -from tafrigh.config import Config -from tafrigh.recognizers.wit_calling_throttle import WitCallingThrottle, WitCallingThrottleManager - - -def init_pool(throttle: WitCallingThrottle) -> None: - global wit_calling_throttle - - wit_calling_throttle = throttle - - -class WitRecognizer: - def __init__(self, verbose: bool): - self.verbose = verbose - self.processes_per_wit_client_access_token = min(4, multiprocessing.cpu_count()) - - def recognize( - self, - file_path: str, - wit_config: Config.Wit, - ) -> Generator[dict[str, float], None, list[dict[str, Union[str, float]]]]: - temp_directory = tempfile.mkdtemp() - - segments = AudioSplitter().split( - file_path, - temp_directory, - max_dur=wit_config.max_cutting_duration, - expand_segments_with_noise=True, - ) - - retry_strategy = Retry( - total=5, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=['POST'], - backoff_factor=1, - ) - - adapter = HTTPAdapter(max_retries=retry_strategy) - - session = requests.Session() - session.mount('https://', adapter) - - pool_processes_count = min( - self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens), - multiprocessing.cpu_count(), - ) - - with WitCallingThrottleManager() as manager: - wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens)) - - with multiprocessing.Pool( - processes=pool_processes_count, - initializer=init_pool, - initargs=(wit_calling_throttle,), - ) as pool: - async_results = [ - pool.apply_async( - self._process_segment, - ( - segment, - file_path, - wit_config, - session, - index % len(wit_config.wit_client_access_tokens), - ), - ) - for index, segment in enumerate(segments) - ] - - transcriptions = [] - - with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: - for async_result in async_results: - async_result.wait() - pbar.update(1) - - transcriptions.append(async_result.get()) - - yield { - 'progress': round(len(transcriptions) / len(segments) * 100, 2), - 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] - if pbar.format_dict['rate'] and pbar.total - else None, - } - - shutil.rmtree(temp_directory) - - return transcriptions - - def _process_segment( - self, - segment: tuple[str, float, float], - file_path: str, - wit_config: Config.Wit, - session: requests.Session, - wit_client_access_token_index: int, - ) -> dict[str, Union[str, float]]: - wit_calling_throttle.throttle(wit_client_access_token_index) - - segment_file_path, start, end = segment - - with open(segment_file_path, 'rb') as mp3_file: - audio_content = mp3_file.read() - - retries = 5 - - text = '' - while retries > 0: - try: - response = session.post( - 'https://api.wit.ai/speech', - headers={ - 'Accept': 'application/vnd.wit.20200513+json', - 'Content-Type': 'audio/mpeg3', - 'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}', - }, - data=audio_content, - ) - - if response.status_code == 200: - text = json.loads(response.text)['text'] - break - else: - retries -= 1 - time.sleep(self.processes_per_wit_client_access_token + 1) - except Exception: - retries -= 1 - time.sleep(self.processes_per_wit_client_access_token + 1) - - if retries == 0: - logging.warn( - f"The segment from `{file_path}` file that starts at {start} and ends at {end}" - " didn't transcribed successfully." - ) - - os.remove(segment_file_path) - - return { - 'start': start, - 'end': end, - 'text': text.strip(), - } diff --git a/tafrigh/types/__init__.py b/tafrigh/types/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tafrigh/types/transcript_type.py b/tafrigh/types/transcript_type.py deleted file mode 100644 index 3cfa105..0000000 --- a/tafrigh/types/transcript_type.py +++ /dev/null @@ -1,15 +0,0 @@ -from enum import Enum - - -class TranscriptType(Enum): - ALL = 'all' - TXT = 'txt' - SRT = 'srt' - VTT = 'vtt' - CSV = 'csv' - TSV = 'tsv' - JSON = 'json' - NONE = 'none' - - def __str__(self): - return self.value diff --git a/tafrigh/types/whisper/__init__.py b/tafrigh/types/whisper/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tafrigh/types/whisper/type_hints.py b/tafrigh/types/whisper/type_hints.py deleted file mode 100644 index cd13ed9..0000000 --- a/tafrigh/types/whisper/type_hints.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import TypeVar - -import faster_whisper -import whisper - - -WhisperModel = TypeVar( - 'WhisperModel', - whisper.Whisper, - faster_whisper.WhisperModel, -) diff --git a/tafrigh/utils/__init__.py b/tafrigh/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tafrigh/utils/cli_utils.py b/tafrigh/utils/cli_utils.py deleted file mode 100644 index 74e0e24..0000000 --- a/tafrigh/utils/cli_utils.py +++ /dev/null @@ -1,200 +0,0 @@ -import argparse -import importlib.metadata -import re - -from tafrigh.types.transcript_type import TranscriptType - - -PLAYLIST_ITEMS_RE = re.compile( - r'''(?x) - (?P[+-]?\d+)? - (?P[:-] - (?P[+-]?\d+|inf(?:inite)?)? - (?::(?P[+-]?\d+))? - )?''' -) - - -def parse_args(argv: list[str]) -> argparse.Namespace: - parser = argparse.ArgumentParser() - - parser.add_argument( - '--version', - action='version', - version=importlib.metadata.version('tafrigh'), - ) - - input_group = parser.add_argument_group('Input') - - input_group.add_argument( - 'urls_or_paths', - nargs='+', - help='Video/Playlist URLs or local folder/file(s) to transcribe.', - ) - - input_group.add_argument( - '--skip_if_output_exist', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to skip generating the output if the output file already exists.', - ) - - input_group.add_argument( - '--playlist_items', - type=parse_playlist_items, - help='Comma separated playlist_index of the items to download. You can specify a range using "[START]:[STOP][:STEP]".', - ) - - input_group.add_argument( - '--verbose', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to print out the progress and debug messages.', - ) - - whisper_group = parser.add_argument_group('Whisper') - - whisper_group.add_argument( - '-m', - '--model_name_or_path', - default='small', - help='Name or path of the Whisper model to use.', - ) - - whisper_group.add_argument( - '-t', - '--task', - default='transcribe', - choices=[ - 'transcribe', - 'translate', - ], - help="Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate').", - ) - - whisper_group.add_argument( - '-l', - '--language', - default=None, - choices=['af', 'am', 'ar', 'as', 'az', 'ba', 'be', 'bg', 'bn', 'bo', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de'] - + ['el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'gl', 'gu', 'ha', 'haw', 'he', 'hi', 'hr', 'ht', 'hu'] - + ['hy', 'id', 'is', 'it', 'ja', 'jw', 'ka', 'kk', 'km', 'kn', 'ko', 'la', 'lb', 'ln', 'lo', 'lt', 'lv', 'mg'] - + ['mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl', 'nn', 'no', 'oc', 'pa', 'pl', 'ps', 'pt', 'ro'] - + ['ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk'] - + ['tl', 'tr', 'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'], - help='Language spoken in the audio, skip to perform language detection.', - ) - - whisper_group.add_argument( - '--use_faster_whisper', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to use Faster Whisper implementation.', - ) - - whisper_group.add_argument( - '--beam_size', - type=int, - default=5, - help='Number of beams in beam search, only applicable when temperature is zero.', - ) - - whisper_group.add_argument( - '--ct2_compute_type', - default='default', - choices=[ - 'default', - 'int8', - 'int8_float16', - 'int16', - 'float16', - ], - help='Quantization type applied while converting the model to CTranslate2 format.', - ) - - wit_group = parser.add_argument_group('Wit') - - wit_group.add_argument( - '-w', - '--wit_client_access_tokens', - nargs='+', - help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.', - ) - - wit_group.add_argument( - '--max_cutting_duration', - type=int, - default=15, - choices=range(1, 17), - metavar='[1-17]', - help='The maximum allowed cutting duration. It should be between 1 and 17.', - ) - - output_group = parser.add_argument_group('Output') - - output_group.add_argument( - '--min_words_per_segment', - type=int, - default=1, - help='The minimum number of words should appear in each transcript segment. Any segment have words count less than this threshold will be merged with the next one. Pass 0 to disable this behavior.', - ) - - output_group.add_argument( - '--save_files_before_compact', - action=argparse.BooleanOptionalAction, - default=False, - help='Saves the output files before applying the compact logic that is based on --min_words_per_segment.', - ) - - output_group.add_argument( - '--save_yt_dlp_responses', - action=argparse.BooleanOptionalAction, - default=False, - help='Whether to save the yt-dlp library JSON responses or not.', - ) - - output_group.add_argument( - '--output_sample', - type=int, - default=0, - help='Samples random compacted segments from the output and generates a CSV file contains the sampled data. Pass 0 to disable this behavior.', - ) - - output_group.add_argument( - '-f', - '--output_formats', - nargs='+', - default='all', - choices=[transcript_type.value for transcript_type in TranscriptType], - help='Format of the output file; if not specified, all available formats will be produced.', - ) - - output_group.add_argument('-o', '--output_dir', default='.', help='Directory to save the outputs.') - - return parser.parse_args(argv) - - -def parse_playlist_items(arg_value: str) -> str: - for segment in arg_value.split(','): - if not segment: - raise ValueError('There is two or more consecutive commas.') - - mobj = PLAYLIST_ITEMS_RE.fullmatch(segment) - if not mobj: - raise ValueError(f'{segment!r} is not a valid specification.') - - _, _, step, _ = mobj.group('start', 'end', 'step', 'range') - if int_or_none(step) == 0: - raise ValueError(f'Step in {segment!r} cannot be zero.') - - return arg_value - - -def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1): - if get_attr and v is not None: - v = getattr(v, get_attr, None) - - try: - return int(v) * invscale // scale - except (ValueError, TypeError, OverflowError): - return default diff --git a/tafrigh/utils/file_utils.py b/tafrigh/utils/file_utils.py deleted file mode 100644 index 00f0078..0000000 --- a/tafrigh/utils/file_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import mimetypes - -from pathlib import Path - - -mimetypes.init() - - -def filter_media_files(paths: list[Path]) -> list[Path]: - # Filter out non audio or video files - filtered_media_files: list[str] = [] - for path in paths: - mime = mimetypes.guess_type(path)[0] - if mime is None: - continue - mime_type = mime.split('/')[0] - if mime_type not in ('audio', 'video'): - continue - filtered_media_files.append(path) - return filtered_media_files diff --git a/tafrigh/utils/time_utils.py b/tafrigh/utils/time_utils.py deleted file mode 100644 index 4f113a4..0000000 --- a/tafrigh/utils/time_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -def format_timestamp(seconds: float, include_hours: bool = False, decimal_marker: str = '.') -> str: - assert seconds >= 0, 'Non-negative timestamp expected' - - total_milliseconds = int(round(seconds * 1_000)) - - hours, total_milliseconds = divmod(total_milliseconds, 3_600_000) - minutes, total_milliseconds = divmod(total_milliseconds, 60_000) - seconds, milliseconds = divmod(total_milliseconds, 1_000) - - if include_hours or hours > 0: - time_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" - else: - time_str = f"{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" - - return time_str diff --git a/tafrigh/utils/whisper/__init__.py b/tafrigh/utils/whisper/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tafrigh/utils/whisper/whisper_utils.py b/tafrigh/utils/whisper/whisper_utils.py deleted file mode 100644 index 316b568..0000000 --- a/tafrigh/utils/whisper/whisper_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import faster_whisper -import stable_whisper - -from tafrigh.config import Config -from tafrigh.types.whisper.type_hints import WhisperModel - - -def load_model(whisper_config: Config.Whisper) -> WhisperModel: - if whisper_config.use_faster_whisper: - return faster_whisper.WhisperModel( - whisper_config.model_name_or_path, - compute_type=whisper_config.ct2_compute_type, - ) - else: - return stable_whisper.load_model(whisper_config.model_name_or_path) diff --git a/tafrigh/utils/wit/__init__.py b/tafrigh/utils/wit/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tafrigh/utils/wit/file_utils.py b/tafrigh/utils/wit/file_utils.py deleted file mode 100644 index d087d3d..0000000 --- a/tafrigh/utils/wit/file_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from pathlib import Path - -from pydub import AudioSegment - - -def convert_to_mp3(file: Path) -> Path: - audio_file = AudioSegment.from_file(str(file)) - converted_file_path = file.with_suffix('.mp3') - audio_file.export(str(converted_file_path), format='mp3') - return converted_file_path diff --git a/tafrigh/writer.py b/tafrigh/writer.py deleted file mode 100644 index 47ed889..0000000 --- a/tafrigh/writer.py +++ /dev/null @@ -1,162 +0,0 @@ -import csv -import json -import os - -from pathlib import Path -from typing import Union - -from tafrigh.config import Config -from tafrigh.types.transcript_type import TranscriptType -from tafrigh.utils import time_utils - - -class Writer: - def write_all( - self, - file_name: str, - segments: list[dict[str, Union[str, float]]], - output_config: Config.Output, - ) -> None: - if output_config.save_files_before_compact: - for output_format in output_config.output_formats: - self.write( - output_format, - os.path.join(output_config.output_dir, f'{file_name}-original.{output_format}'), - segments, - ) - - if not output_config.save_files_before_compact or output_config.min_words_per_segment != 0: - compacted_segments = self.compact_segments(segments, output_config.min_words_per_segment) - - for output_format in output_config.output_formats: - self.write( - output_format, - os.path.join(output_config.output_dir, f'{file_name}.{output_format}'), - compacted_segments, - ) - - def write( - self, - format: TranscriptType, - file_path: str, - segments: list[dict[str, Union[str, float]]], - ) -> None: - if format == TranscriptType.TXT: - self.write_txt(file_path, segments) - elif format == TranscriptType.SRT: - self.write_srt(file_path, segments) - elif format == TranscriptType.VTT: - self.write_vtt(file_path, segments) - elif format == TranscriptType.CSV: - self.write_csv(file_path, segments) - elif format == TranscriptType.TSV: - self.write_csv(file_path, segments, '\t') - elif format == TranscriptType.JSON: - self.write_json(file_path, segments) - - def write_txt( - self, - file_path: str, - segments: list[dict[str, Union[str, float]]], - ) -> None: - self._write_to_file(file_path, self.generate_txt(segments)) - - def write_srt( - self, - file_path: str, - segments: list[dict[str, Union[str, float]]], - ) -> None: - self._write_to_file(file_path, self.generate_srt(segments)) - - def write_vtt( - self, - file_path: str, - segments: list[dict[str, Union[str, float]]], - ) -> None: - self._write_to_file(file_path, self.generate_vtt(segments)) - - def write_csv( - self, - file_path: str, - segments: list[dict[str, Union[str, float]]], - delimiter=',', - ) -> None: - with open(file_path, 'w', encoding='utf-8') as fp: - writer = csv.DictWriter(fp, fieldnames=['text', 'start', 'end'], delimiter=delimiter) - writer.writeheader() - writer.writerows(segments) - - def write_json( - self, - file_path: str, - segments: list[dict[str, Union[str, float]]], - ) -> None: - with open(file_path, 'w', encoding='utf-8') as fp: - json.dump(segments, fp, ensure_ascii=False, indent=2) - - def generate_txt(self, segments: list[dict[str, Union[str, float]]]) -> str: - return '\n'.join(list(map(lambda segment: segment['text'].strip(), segments))) + '\n' - - def generate_srt(self, segments: list[dict[str, Union[str, float]]]) -> str: - return ''.join( - f"{i}\n" - f"{time_utils.format_timestamp(segment['start'], include_hours=True, decimal_marker=',')} --> " - f"{time_utils.format_timestamp(segment['end'], include_hours=True, decimal_marker=',')}\n" - f"{segment['text'].strip()}\n\n" - for i, segment in enumerate(segments, start=1) - ) - - def generate_vtt(self, segments: list[dict[str, Union[str, float]]]) -> str: - return 'WEBVTT\n\n' + ''.join( - f"{time_utils.format_timestamp(segment['start'])} --> {time_utils.format_timestamp(segment['end'])}\n" - f"{segment['text'].strip()}\n\n" - for segment in segments - ) - - def compact_segments( - self, - segments: list[dict[str, Union[str, float]]], - min_words_per_segment: int, - ) -> list[dict[str, Union[str, float]]]: - if min_words_per_segment == 0: - return segments - - compacted_segments = [] - tmp_segment = None - - for segment in segments: - if tmp_segment: - tmp_segment['text'] += f" {segment['text'].strip()}" - tmp_segment['end'] = segment['end'] - - if len(tmp_segment['text'].split()) >= min_words_per_segment: - compacted_segments.append(tmp_segment) - tmp_segment = None - elif len(segment['text'].split()) < min_words_per_segment: - tmp_segment = dict(segment) - elif len(segment['text'].split()) >= min_words_per_segment: - compacted_segments.append(dict(segment)) - - if tmp_segment: - compacted_segments.append(tmp_segment) - - return compacted_segments - - def is_output_exist(self, file_name: str, output_config: Config.Output): - if output_config.save_files_before_compact and not all( - Path(output_config.output_dir, f'{file_name}-original.{output_format}').is_file() - for output_format in output_config.output_formats - ): - return False - - if (not output_config.save_files_before_compact or output_config.min_words_per_segment != 0) and not all( - Path(output_config.output_dir, f'{file_name}.{output_format}').is_file() - for output_format in output_config.output_formats - ): - return False - - return True - - def _write_to_file(self, file_path: str, content: str) -> None: - with open(file_path, 'w', encoding='utf-8') as fp: - fp.write(content)