Skip to content

Commit

Permalink
Implement different throttling for single wit.ai keys
Browse files Browse the repository at this point in the history
  • Loading branch information
AliOsm committed Aug 25, 2023
1 parent 3760bf5 commit 98dbe28
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 8 deletions.
48 changes: 40 additions & 8 deletions tafrigh/recognizers/wit_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tafrigh.audio_splitter import AudioSplitter
from tafrigh.config import Config
from tafrigh.recognizers.wit_calling_throttle import WitCallingThrottle, WitCallingThrottleManager
from tafrigh.utils.decorators import minimum_execution_time


def init_pool(throttle: WitCallingThrottle) -> None:
Expand Down Expand Up @@ -61,30 +62,39 @@ def recognize(
multiprocessing.cpu_count(),
)

transcriptions = []

if len(wit_config.wit_client_access_tokens) == 1:
process_segment_function = self._process_segment_single_key
extra_args = lambda _index: ()
pool_initializer = None
else:
process_segment_function = self._process_segment_multiple_keys
extra_args = lambda index: (index % len(wit_config.wit_client_access_tokens),)
pool_initializer = init_pool

with WitCallingThrottleManager() as manager:
wit_calling_throttle = manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens))
multiple_keys_pool_initargs = (manager.WitCallingThrottle(len(wit_config.wit_client_access_tokens)),)

with multiprocessing.Pool(
processes=pool_processes_count,
initializer=init_pool,
initargs=(wit_calling_throttle,),
initializer=pool_initializer,
initargs=multiple_keys_pool_initargs if pool_initializer else (),
) as pool:
async_results = [
pool.apply_async(
self._process_segment,
process_segment_function,
(
segment,
file_path,
wit_config,
session,
index % len(wit_config.wit_client_access_tokens),
*extra_args(index),
),
)
for index, segment in enumerate(segments)
]

transcriptions = []

with tqdm(total=len(segments), disable=self.verbose is not False) as pbar:
while async_results:
if async_results[0].ready():
Expand All @@ -98,11 +108,23 @@ def recognize(
else None,
}

time.sleep(0.5)

shutil.rmtree(temp_directory)

return transcriptions

def _process_segment(
@minimum_execution_time(min(4, multiprocessing.cpu_count()) + 0.5)
def _process_segment_single_key(
    self,
    segment: tuple[str, float, float],
    file_path: str,
    wit_config: Config.Wit,
    session: requests.Session,
) -> dict[str, Union[str, float]]:
    """Process one audio segment when only a single wit.ai access token is configured.

    The work is delegated to ``_process_segment`` with the token index fixed
    to ``0``. Throttling is done by the ``minimum_execution_time`` decorator
    instead of a shared throttle object: each call is padded so that it takes
    at least ``min(4, multiprocessing.cpu_count()) + 0.5`` seconds before it
    returns.

    NOTE(review): the padding duration presumably mirrors the number of pool
    worker processes so that the combined request rate of all workers stays
    under the wit.ai limit for one key -- confirm against the pool size
    computed in ``recognize``.

    Args:
        segment: ``(segment_file_path, start, end)`` tuple for the segment.
        file_path: Forwarded unchanged to ``_process_segment``.
        wit_config: Forwarded unchanged to ``_process_segment``.
        session: Forwarded unchanged to ``_process_segment``.

    Returns:
        Whatever ``_process_segment`` returns for this segment.
    """
    # A single key means there is only one token to pick, hence index 0.
    return self._process_segment(segment, file_path, wit_config, session, 0)

def _process_segment_multiple_keys(
self,
segment: tuple[str, float, float],
file_path: str,
Expand All @@ -112,6 +134,16 @@ def _process_segment(
) -> dict[str, Union[str, float]]:
wit_calling_throttle.throttle(wit_client_access_token_index)

return self._process_segment(segment, file_path, wit_config, session, wit_client_access_token_index)

def _process_segment(
self,
segment: tuple[str, float, float],
file_path: str,
wit_config: Config.Wit,
session: requests.Session,
wit_client_access_token_index: int,
) -> dict[str, Union[str, float]]:
segment_file_path, start, end = segment

with open(segment_file_path, 'rb') as wav_file:
Expand Down
26 changes: 26 additions & 0 deletions tafrigh/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import time

from functools import wraps
from typing import Callable, TypeVar


# Type variable standing for "the decorated callable"; it lets the decorator
# advertise that it returns the same kind of callable it was given.
T = TypeVar('T', bound=Callable)


def minimum_execution_time(minimum_time: float) -> Callable[[T], T]:
    """Build a decorator that makes every call last at least ``minimum_time`` seconds.

    The wrapped function runs normally; if it finishes sooner than
    ``minimum_time`` the wrapper sleeps for the remaining time before
    returning the function's result. Calls that already take longer than
    ``minimum_time`` are not delayed. If the wrapped function raises, the
    exception propagates immediately and no padding is applied.

    Args:
        minimum_time: Minimum duration of one call, in seconds. Values that
            are zero or negative never cause a sleep.

    Returns:
        A decorator that wraps a callable with the padding behaviour while
        preserving its metadata (name, docstring, ...) via ``functools.wraps``.
    """
    def decorator(func: T) -> T:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Use the monotonic clock: the wall clock (time.time) may jump
            # forwards or backwards (NTP sync, manual change), which would
            # either skip the padding or make the call sleep far too long.
            start_time = time.monotonic()
            result = func(*args, **kwargs)
            remaining_time = minimum_time - (time.monotonic() - start_time)

            if remaining_time > 0:
                time.sleep(remaining_time)

            return result

        return wrapper

    return decorator

0 comments on commit 98dbe28

Please sign in to comment.