diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml index 6522fcd..90c9231 100644 --- a/.github/workflows/formatter.yml +++ b/.github/workflows/formatter.yml @@ -22,4 +22,4 @@ jobs: - name: isort formatter uses: isort/isort-action@v1 with: - configuration: "--profile black --check-only --diff --line-length 120" + configuration: "--profile black --src tafrigh --line-length 120 --lines-between-types 1 --lines-after-imports 2 --case-sensitive --trailing-comma --check-only --diff" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c9147a6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,42 @@ +--- +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-added-large-files + - id: check-ast + - id: check-builtin-literals + - id: check-case-conflict + - id: check-docstring-first + - id: check-json + - id: check-merge-conflict + - id: check-shebang-scripts-are-executable + - id: check-symlinks + - id: check-toml + - id: check-vcs-permalinks + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: destroyed-symlinks + - id: detect-private-key + - id: end-of-file-fixer + types: [ python ] + - id: fix-byte-order-marker + - id: mixed-line-ending + - id: name-tests-test + args: [ --pytest-test-first ] + - id: trailing-whitespace + types: [ python ] + + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + additional_dependencies: [click==8.0.4] + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort diff --git a/README.md b/README.md index 53ad5d8..6408d2a 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@
--wit_client_access_token
. إذا تم تمرير هذا الاختيار، سيتم استخدام wit.ai لتفريغ المواد إلى نصوص. غير ذلك، سيتم استخدام نماذج Whisper--wit_client_access_tokens
. إذا تم تمرير هذا الاختيار، سيتم استخدام wit.ai لتفريغ المواد إلى نصوص. غير ذلك، سيتم استخدام نماذج Whisper--max_cutting_duration
. القيمة الافتراضية هي 15
مرحبًا بك في تفريغ لتفريغ المواد الصوتية والمرئية باستخدام تقنيات الذكاء الاصطناعي. لاستخدام تفريغ:
\n", - "\n", - "يمكنك تجربة التفريغ باستخدام نماذج Whisper وتقنية wit.ai واستخدام التفريغ الأفضل لحالتك. كملاحظة عامة، نماذج Whisper تقوم بتفريغ الهمزات والتشكيلات وعلامات الترقيم بشكل أفضل من wit.ai، لكن wit.ai يُنتج أخطاء إملائية أقل.
\n", - "\n", - "عندما ينتهي التحويل سيتم تنزيل الملفات النصية بشكل تلقائي بصيغة txt
و srt
وسيكون اسم الملف هو مُعرّف المادة على منصة YouTube الذي يكون في آخر رابط المادة: https://youtu.be/4h5P7jXvW98.
يمكنك متابعة مشروع الكتب المُيسّرة والتواصل معنا من خلال:
\n", - "\n", - "مرحبًا بك في تفريغ لتفريغ المواد الصوتية والمرئية باستخدام تقنيات الذكاء الاصطناعي. لاستخدام تفريغ:
\n", + "\n", + "يمكنك تجربة التفريغ باستخدام نماذج Whisper وتقنية wit.ai واستخدام التفريغ الأفضل لحالتك. كملاحظة عامة، نماذج Whisper تقوم بتفريغ الهمزات والتشكيلات وعلامات الترقيم بشكل أفضل من wit.ai، لكن wit.ai يُنتج أخطاء إملائية أقل.
\n", + "\n", + "عندما ينتهي التحويل سيتم تنزيل الملفات النصية بشكل تلقائي بصيغة txt
و srt
وسيكون اسم الملف هو مُعرّف المادة على منصة YouTube الذي يكون في آخر رابط المادة: https://youtu.be/4h5P7jXvW98.
يمكنك متابعة مشروع الكتب المُيسّرة والتواصل معنا من خلال:
\n", + "\n", + "روابط المواد المطلوب تفريغها وتأكد من فصلها بمسافة، أو أترك الحقل فارغًا لتفريغ المواد التي قمت برفعها
\n", + "urls = 'https://youtu.be/4h5P7jXvW98 https://youtu.be/jpfndVSROpw' # @param { type: \"string\" }\n", + "\n", + "# @markdownأقل عدد من الكلمات في كل جزء من أجزاء التفريغ
\n", + "min_words_per_segment = 1 # @param {type:\"slider\", min:1, max:100, step:1}\n", + "\n", + "# @markdown ---\n", + "\n", + "# @markdownالنموذج المُراد استخدامه للتفريغ
\n", + "model = 'large-v2 (\\u0623\\u0641\\u0636\\u0644 \\u062F\\u0642\\u0629)' # @param [\"large-v2 (أفضل دقة)\", \"medium\", \"base\", \"small\", \"tiny (أقل دقة)\"]\n", + "\n", + "# @markdown(اختياري) لغة المادة
\n", + "language = 'ar' # @param [\"ar\", \"af\", \"am\", \"as\", \"az\", \"ba\", \"be\", \"bg\", \"bn\", \"bo\", \"br\", \"bs\", \"ca\", \"cs\", \"cy\", \"da\", \"de\", \"el\", \"en\", \"es\", \"et\", \"eu\", \"fa\", \"fi\", \"fo\", \"fr\", \"gl\", \"gu\", \"ha\", \"haw\", \"he\", \"hi\", \"hr\", \"ht\", \"hu\", \"hy\", \"id\", \"is\", \"it\", \"ja\", \"jw\", \"ka\", \"kk\", \"km\", \"kn\", \"ko\", \"la\", \"lb\", \"ln\", \"lo\", \"lt\", \"lv\", \"mg\", \"mi\", \"mk\", \"ml\", \"mn\", \"mr\", \"ms\", \"mt\", \"my\", \"ne\", \"nl\", \"nn\", \"no\", \"oc\", \"pa\", \"pl\", \"ps\", \"pt\", \"ro\", \"ru\", \"sa\", \"sd\", \"si\", \"sk\", \"sl\", \"sn\", \"so\", \"sq\", \"sr\", \"su\", \"sv\", \"sw\", \"ta\", \"te\", \"tg\", \"th\", \"tk\", \"tl\", \"tr\", \"tt\", \"uk\", \"ur\", \"uz\", \"vi\", \"yi\", \"yo\", \"zh\"]\n", + "\n", + "# @markdown ---\n", + "\n", + "# @markdownالمفتاح الخاص بك على موقع wit.ai
\n", + "wit_api_key = '' # @param { type: \"string\" }\n", + "\n", + "# @markdown(اختياري) أقصى مدة للتقطيع والتي ستؤثر على طول الجمل في ملف SRT
\n", + "max_cutting_duration = 15 # @param {type:\"slider\", min:1, max:17, step:1}\n", + "\n", + "if model == 'large-v2 (\\u0623\\u0641\\u0636\\u0644 \\u062F\\u0642\\u0629)':\n", + " model = 'large-v2'\n", + "elif model == 'tiny (\\u0623\\u0642\\u0644 \\u062F\\u0642\\u0629)':\n", + " model = 'tiny'\n", + "\n", + "# Imports.\n", + "import glob\n", + "import os\n", + "\n", + "from collections import deque\n", + "\n", + "from google.colab import files\n", + "from tafrigh import farrigh, Config\n", + "\n", + "# Setup directories.\n", + "output_dir = 'output'\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + "# Start Tafrigh.\n", + "if wit_api_key:\n", + " print('جارٍ تحويل المواد إلى نصوص باستخدام تقنيات wit.ai.')\n", + "else:\n", + " print('جارٍ تحويل المواد إلى نصوص باستخدام نماذج Whisper.')\n", + "\n", + "config = Config(\n", + " urls_or_paths=list(map(str.strip, urls.split(' '))) if len(urls.strip()) else ['.'],\n", + " skip_if_output_exist=False,\n", + " playlist_items='',\n", + " verbose=False,\n", + " model_name_or_path=model,\n", + " task='transcribe',\n", + " language=language,\n", + " use_whisper_jax=False,\n", + " use_faster_whisper=True,\n", + " beam_size=5,\n", + " ct2_compute_type='default',\n", + " wit_client_access_tokens=[wit_api_key],\n", + " max_cutting_duration=max_cutting_duration,\n", + " min_words_per_segment=min_words_per_segment,\n", + " save_files_before_compact=False,\n", + " save_yt_dlp_responses=False,\n", + " output_sample=0,\n", + " output_formats=['txt', 'srt'],\n", + " output_dir=output_dir,\n", + ")\n", + "\n", + "deque(farrigh(config), maxlen=0)\n", + "\n", + "# Download all txt and srt files.\n", + "print('جارٍ تنزيل الملفات النصية.')\n", + "\n", + "txt_files = glob.glob(f\"{output_dir}/*.txt\")\n", + "srt_files = glob.glob(f\"{output_dir}/*.srt\")\n", + "\n", + "try:\n", + " txt_files.remove('output/archive.txt')\n", + "except ValueError:\n", + " pass\n", + "\n", + "for file in txt_files + srt_files:\n", + " files.download(file)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.12" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "uY05i198xi3D" - }, - "outputs": [], - "source": [ - "# @titleروابط المواد المطلوب تفريغها وتأكد من فصلها بمسافة، أو أترك الحقل فارغًا لتفريغ المواد التي قمت برفعها
\n", - "urls = 'https://youtu.be/4h5P7jXvW98 https://youtu.be/jpfndVSROpw' # @param { type: \"string\" }\n", - "\n", - "# @markdownأقل عدد من الكلمات في كل جزء من أجزاء التفريغ
\n", - "min_words_per_segment = 1 # @param {type:\"slider\", min:1, max:100, step:1}\n", - "\n", - "# @markdown ---\n", - "\n", - "# @markdownالنموذج المُراد استخدامه للتفريغ
\n", - "model = 'large-v2 (\\u0623\\u0641\\u0636\\u0644 \\u062F\\u0642\\u0629)' # @param [\"large-v2 (أفضل دقة)\", \"medium\", \"base\", \"small\", \"tiny (أقل دقة)\"]\n", - "\n", - "# @markdown(اختياري) لغة المادة
\n", - "language = 'ar' # @param [\"ar\", \"af\", \"am\", \"as\", \"az\", \"ba\", \"be\", \"bg\", \"bn\", \"bo\", \"br\", \"bs\", \"ca\", \"cs\", \"cy\", \"da\", \"de\", \"el\", \"en\", \"es\", \"et\", \"eu\", \"fa\", \"fi\", \"fo\", \"fr\", \"gl\", \"gu\", \"ha\", \"haw\", \"he\", \"hi\", \"hr\", \"ht\", \"hu\", \"hy\", \"id\", \"is\", \"it\", \"ja\", \"jw\", \"ka\", \"kk\", \"km\", \"kn\", \"ko\", \"la\", \"lb\", \"ln\", \"lo\", \"lt\", \"lv\", \"mg\", \"mi\", \"mk\", \"ml\", \"mn\", \"mr\", \"ms\", \"mt\", \"my\", \"ne\", \"nl\", \"nn\", \"no\", \"oc\", \"pa\", \"pl\", \"ps\", \"pt\", \"ro\", \"ru\", \"sa\", \"sd\", \"si\", \"sk\", \"sl\", \"sn\", \"so\", \"sq\", \"sr\", \"su\", \"sv\", \"sw\", \"ta\", \"te\", \"tg\", \"th\", \"tk\", \"tl\", \"tr\", \"tt\", \"uk\", \"ur\", \"uz\", \"vi\", \"yi\", \"yo\", \"zh\"]\n", - "\n", - "# @markdown ---\n", - "\n", - "# @markdownالمفتاح الخاص بك على موقع wit.ai
\n", - "wit_api_key = '' # @param { type: \"string\" }\n", - "\n", - "# @markdown(اختياري) أقصى مدة للتقطيع والتي ستؤثر على طول الجمل في ملف SRT
\n", - "max_cutting_duration = 15 # @param {type:\"slider\", min:1, max:17, step:1}\n", - "\n", - "if model == 'large-v2 (\\u0623\\u0641\\u0636\\u0644 \\u062F\\u0642\\u0629)':\n", - " model = 'large-v2'\n", - "elif model == 'tiny (\\u0623\\u0642\\u0644 \\u062F\\u0642\\u0629)':\n", - " model = 'tiny'\n", - "\n", - "# Imports.\n", - "import glob\n", - "import os\n", - "\n", - "from collections import deque\n", - "\n", - "from google.colab import files\n", - "from tafrigh import farrigh, Config\n", - "\n", - "# Setup directories.\n", - "output_dir = 'output'\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "\n", - "# Start Tafrigh.\n", - "if wit_api_key:\n", - " print('جارٍ تحويل المواد إلى نصوص باستخدام تقنيات wit.ai.')\n", - "else:\n", - " print('جارٍ تحويل المواد إلى نصوص باستخدام نماذج Whisper.')\n", - "\n", - "config = Config(\n", - " urls_or_paths=list(map(str.strip, urls.split(' '))) if len(urls.strip()) else ['.'],\n", - " skip_if_output_exist=False,\n", - " playlist_items='',\n", - " verbose=False,\n", - " model_name_or_path=model,\n", - " task='transcribe',\n", - " language=language,\n", - " use_whisper_jax=False,\n", - " use_faster_whisper=True,\n", - " beam_size=5,\n", - " ct2_compute_type='default',\n", - " wit_client_access_token=wit_api_key,\n", - " max_cutting_duration=max_cutting_duration,\n", - " min_words_per_segment=min_words_per_segment,\n", - " save_files_before_compact=False,\n", - " save_yt_dlp_responses=False,\n", - " output_sample=0,\n", - " output_formats=['txt', 'srt'],\n", - " output_dir=output_dir,\n", - ")\n", - "\n", - "deque(farrigh(config), maxlen=0)\n", - "\n", - "# Download all txt and srt files.\n", - "print('جارٍ تنزيل الملفات النصية.')\n", - "\n", - "txt_files = glob.glob(f\"{output_dir}/*.txt\")\n", - "srt_files = glob.glob(f\"{output_dir}/*.srt\")\n", - "\n", - "try:\n", - " txt_files.remove('output/archive.txt')\n", - "except ValueError:\n", - " pass\n", - "\n", - "for file in txt_files + srt_files:\n", - " files.download(file)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/pyproject.toml b/pyproject.toml index 315006e..3fac9e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,16 @@ repository = "https://github.com/ieasybooks/tafrigh" [project.scripts] tafrigh = "tafrigh.cli:main" + +[tool.black] +line-length = 120 +skip-string-normalization = true + +[tool.isort] +profile = "black" +src_paths = "tafrigh" +line_length = 120 +lines_between_types = 1 +lines_after_imports = 2 +case_sensitive = true +include_trailing_comma = true diff --git a/tafrigh/__init__.py b/tafrigh/__init__.py index eafe85d..c47c637 100644 --- a/tafrigh/__init__.py +++ b/tafrigh/__init__.py @@ -4,6 +4,7 @@ from tafrigh.types.transcript_type import TranscriptType from tafrigh.writer import Writer + try: from tafrigh.recognizers.whisper_recognizer import WhisperRecognizer except ModuleNotFoundError: diff --git a/tafrigh/audio_splitter.py b/tafrigh/audio_splitter.py index 5a1c950..be3bc1d 100644 --- a/tafrigh/audio_splitter.py +++ b/tafrigh/audio_splitter.py @@ -2,6 +2,7 @@ import tempfile import numpy as np + from auditok.core import split from scipy.io import wavfile diff --git a/tafrigh/cli.py b/tafrigh/cli.py index 8437b73..05adfb2 100644 --- a/tafrigh/cli.py +++ b/tafrigh/cli.py @@ -4,6 +4,7 @@ import random import re import sys + from collections import deque from pathlib import Path from typing import Any, Generator, Union @@ -15,6 +16,7 @@ from tafrigh.utils import cli_utils, file_utils, time_utils from tafrigh.writer import Writer + try: import requests @@ -48,7 +50,7 @@ def main(): beam_size=args.beam_size, ct2_compute_type=args.ct2_compute_type, # - wit_client_access_token=args.wit_client_access_token, + wit_client_access_tokens=args.wit_client_access_tokens, max_cutting_duration=args.max_cutting_duration, min_words_per_segment=args.min_words_per_segment, # diff --git a/tafrigh/config.py b/tafrigh/config.py index af12216..7a931ea 100644 --- a/tafrigh/config.py +++ b/tafrigh/config.py @@ -17,7 +17,7 @@ def __init__( use_whisper_jax: bool, beam_size: int, ct2_compute_type: str, - wit_client_access_token: str, + wit_client_access_tokens: list[str], max_cutting_duration: int, min_words_per_segment: int, save_files_before_compact: bool, @@ -38,7 +38,7 @@ def __init__( ct2_compute_type, ) - self.wit = self.Wit(wit_client_access_token, max_cutting_duration) + self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration) self.output = self.Output( min_words_per_segment, @@ -50,7 +50,7 @@ def __init__( ) def use_wit(self) -> bool: - return self.wit.wit_client_access_token != '' + return self.wit.wit_client_access_tokens != [] class Input: def __init__(self, urls_or_paths: list[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool): @@ -83,8 +83,8 @@ def __init__( self.ct2_compute_type = ct2_compute_type class Wit: - def __init__(self, wit_client_access_token: str, max_cutting_duration: int): - self.wit_client_access_token = wit_client_access_token + def __init__(self, wit_client_access_tokens: list[str], max_cutting_duration: int): + self.wit_client_access_tokens = wit_client_access_tokens self.max_cutting_duration = max_cutting_duration class Output: diff --git a/tafrigh/downloader.py b/tafrigh/downloader.py index cb6a966..8ea80f0 100644 --- a/tafrigh/downloader.py +++ b/tafrigh/downloader.py @@ -1,5 +1,6 @@ import json import os + from typing import Any, Union import yt_dlp diff --git a/tafrigh/recognizers/whisper_recognizer.py b/tafrigh/recognizers/whisper_recognizer.py index 320a561..dd73828 100644 --- a/tafrigh/recognizers/whisper_recognizer.py +++ b/tafrigh/recognizers/whisper_recognizer.py @@ -1,9 +1,11 @@ import warnings + from typing import Generator, Union import faster_whisper import whisper import whisper_jax + from tqdm import tqdm from tafrigh.config import Config diff --git a/tafrigh/recognizers/wit_calling_throttle.py b/tafrigh/recognizers/wit_calling_throttle.py new file mode 100644 index 0000000..0e83304 --- /dev/null +++ b/tafrigh/recognizers/wit_calling_throttle.py @@ -0,0 +1,38 @@ +import time + +from multiprocessing.managers import BaseManager +from threading import Lock + + +class WitCallingThrottle: + def __init__(self, wit_client_access_tokens_count: int, call_times_limit: int = 1, expired_time: int = 1): + self.wit_client_access_tokens_count = wit_client_access_tokens_count + self.call_times_limit = call_times_limit + self.expired_time = expired_time + self.call_timestamps = [[] for _ in range(self.wit_client_access_tokens_count)] + self.locks = [Lock() for _ in range(self.wit_client_access_tokens_count)] + + def throttle(self, wit_client_access_token_index: int) -> None: + with self.locks[wit_client_access_token_index]: + while len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: + now = time.time() + + self.call_timestamps[wit_client_access_token_index] = list( + filter( + lambda x: now - x < self.expired_time, + self.call_timestamps[wit_client_access_token_index], + ) + ) + + if len(self.call_timestamps[wit_client_access_token_index]) == self.call_times_limit: + time_to_sleep = self.call_timestamps[wit_client_access_token_index][0] + self.expired_time - now + time.sleep(time_to_sleep) + + self.call_timestamps[wit_client_access_token_index].append(time.time()) + + +class WitCallingThrottleManager(BaseManager): + pass + + +WitCallingThrottleManager.register('WitCallingThrottle', WitCallingThrottle) diff --git a/tafrigh/recognizers/wit_recognizer.py b/tafrigh/recognizers/wit_recognizer.py index 643c609..866f33e 100644 --- a/tafrigh/recognizers/wit_recognizer.py +++ b/tafrigh/recognizers/wit_recognizer.py @@ -5,21 +5,30 @@ import shutil import tempfile import time + from typing import Generator, Union import requests + from requests.adapters import HTTPAdapter from tqdm import tqdm from urllib3.util.retry import Retry from tafrigh.audio_splitter import AudioSplitter from tafrigh.config import Config -from tafrigh.utils.decorators import minimum_execution_time +from tafrigh.recognizers.wit_calling_throttle import WitCallingThrottle, WitCallingThrottleManager + + +def init_pool(throttle: WitCallingThrottle) -> None: + global wit_calling_throttle + + wit_calling_throttle = throttle class WitRecognizer: def __init__(self, verbose: bool): self.verbose = verbose + self.processes_per_wit_client_access_token = min(4, multiprocessing.cpu_count()) def recognize( self, @@ -47,41 +56,62 @@ def recognize( session = requests.Session() session.mount('https://', adapter) - with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1)) as pool: - async_results = [ - pool.apply_async(self._process_segment, (segment, file_path, wit_config, session)) - for segment in segments - ] - - transcriptions = [] - - with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: - while async_results: - if async_results[0].ready(): - transcriptions.append(async_results.pop(0).get()) - pbar.update(1) - - yield { - 'progress': round(len(transcriptions) / len(segments) * 100, 2), - 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] - if pbar.format_dict['rate'] and pbar.total - else None, - } + pool_processes_count = min( + self.processes_per_wit_client_access_token * len(wit_config.wit_client_access_tokens), + multiprocessing.cpu_count(), + ) - time.sleep(0.5) + with WitCallingThrottleManager() as manager: + wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens)) + + with multiprocessing.Pool( + processes=pool_processes_count, + initializer=init_pool, + initargs=(wit_calling_throttle,), + ) as pool: + async_results = [ + pool.apply_async( + self._process_segment, + ( + segment, + file_path, + wit_config, + session, + index % len(wit_config.wit_client_access_tokens), + ), + ) + for index, segment in enumerate(segments) + ] + + transcriptions = [] + + with tqdm(total=len(segments), disable=self.verbose is not False) as pbar: + while async_results: + if async_results[0].ready(): + transcriptions.append(async_results.pop(0).get()) + pbar.update(1) + + yield { + 'progress': round(len(transcriptions) / len(segments) * 100, 2), + 'remaining_time': (pbar.total - pbar.n) / pbar.format_dict['rate'] + if pbar.format_dict['rate'] and pbar.total + else None, + } shutil.rmtree(temp_directory) return transcriptions - @minimum_execution_time(min(4, multiprocessing.cpu_count() - 1) + 1) def _process_segment( self, segment: tuple[str, float, float], file_path: str, wit_config: Config.Wit, session: requests.Session, + wit_client_access_token_index: int, ) -> dict[str, Union[str, float]]: + wit_calling_throttle.throttle(wit_client_access_token_index) + segment_file_path, start, end = segment with open(segment_file_path, 'rb') as wav_file: @@ -91,25 +121,26 @@ def _process_segment( text = '' while retries > 0: - response = session.post( - 'https://api.wit.ai/speech', - headers={ - 'Accept': 'application/vnd.wit.20200513+json', - 'Content-Type': 'audio/wav', - 'Authorization': f'Bearer {wit_config.wit_client_access_token}', - }, - data=audio_content, - ) - - if response.status_code == 200: - try: + try: + response = session.post( + 'https://api.wit.ai/speech', + headers={ + 'Accept': 'application/vnd.wit.20200513+json', + 'Content-Type': 'audio/wav', + 'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_client_access_token_index]}', + }, + data=audio_content, + ) + + if response.status_code == 200: text = json.loads(response.text)['text'] break - except KeyError: + else: retries -= 1 - else: + time.sleep(self.processes_per_wit_client_access_token + 1) + except: retries -= 1 - time.sleep(min(4, multiprocessing.cpu_count() - 1) + 1) + time.sleep(self.processes_per_wit_client_access_token + 1) if retries == 0: logging.warn( diff --git a/tafrigh/types/whisper/type_hints.py b/tafrigh/types/whisper/type_hints.py index 17bbaa0..989f55d 100644 --- a/tafrigh/types/whisper/type_hints.py +++ b/tafrigh/types/whisper/type_hints.py @@ -4,6 +4,7 @@ import whisper import whisper_jax + WhisperModel = TypeVar( 'WhisperModel', whisper.Whisper, diff --git a/tafrigh/utils/cli_utils.py b/tafrigh/utils/cli_utils.py index 2af426c..33c53d4 100644 --- a/tafrigh/utils/cli_utils.py +++ b/tafrigh/utils/cli_utils.py @@ -3,6 +3,7 @@ from tafrigh.types.transcript_type import TranscriptType + PLAYLIST_ITEMS_RE = re.compile( r'''(?x) (?P