Skip to content

Commit

Permalink
Use multiple wit.ai API keys to transcript the same video
Browse files Browse the repository at this point in the history
  • Loading branch information
AliOsm committed Jul 30, 2023
1 parent 5a4ac3b commit 9fd1299
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion colab_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
" beam_size=5,\n",
" ct2_compute_type='default',\n",
"\n",
" wit_client_access_token=wit_api_key,\n",
" wit_client_access_tokens=[wit_api_key],\n",
" max_cutting_duration=max_cutting_duration,\n",
" min_words_per_segment=min_words_per_segment,\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion tafrigh/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main():
beam_size=args.beam_size,
ct2_compute_type=args.ct2_compute_type,

wit_client_access_token=args.wit_client_access_token,
wit_client_access_tokens=args.wit_client_access_tokens,
max_cutting_duration=args.max_cutting_duration,
min_words_per_segment=args.min_words_per_segment,

Expand Down
10 changes: 5 additions & 5 deletions tafrigh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
use_whisper_jax: bool,
beam_size: int,
ct2_compute_type: str,
wit_client_access_token: str,
wit_client_access_tokens: List[str],
max_cutting_duration: int,
min_words_per_segment: int,
save_files_before_compact: bool,
Expand All @@ -40,7 +40,7 @@ def __init__(
ct2_compute_type,
)

self.wit = self.Wit(wit_client_access_token, max_cutting_duration)
self.wit = self.Wit(wit_client_access_tokens, max_cutting_duration)

self.output = self.Output(
min_words_per_segment,
Expand All @@ -52,7 +52,7 @@ def __init__(
)

def use_wit(self) -> bool:
return self.wit.wit_client_access_token != ''
return self.wit.wit_client_access_tokens != []

class Input:
def __init__(self, urls_or_paths: List[str], skip_if_output_exist: bool, playlist_items: str, verbose: bool):
Expand Down Expand Up @@ -85,8 +85,8 @@ def __init__(
self.ct2_compute_type = ct2_compute_type

class Wit:
def __init__(self, wit_client_access_token: str, max_cutting_duration: int):
self.wit_client_access_token = wit_client_access_token
def __init__(self, wit_client_access_tokens: List[str], max_cutting_duration: int):
self.wit_client_access_tokens = wit_client_access_tokens
self.max_cutting_duration = max_cutting_duration

class Output:
Expand Down
18 changes: 14 additions & 4 deletions tafrigh/recognizers/wit_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,19 @@ def recognize(
session = requests.Session()
session.mount('https://', adapter)

with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1)) as pool:
with multiprocessing.Pool(processes=min(4, multiprocessing.cpu_count() - 1) * len(wit_config.wit_client_access_tokens)) as pool:
async_results = [
pool.apply_async(self._process_segment, (segment, file_path, wit_config, session))
for segment in segments
pool.apply_async(
self._process_segment,
(
segment,
file_path,
wit_config,
session,
((index + 1) / (len(segments) / len(wit_config.wit_client_access_tokens))) - 1,
),
)
for index, segment in enumerate(segments)
]

transcriptions = []
Expand Down Expand Up @@ -82,6 +91,7 @@ def _process_segment(
file_path: str,
wit_config: Config.Wit,
session: requests.Session,
wit_api_key_index: int,
) -> Dict[str, Union[str, float]]:
segment_file_path, start, end = segment

Expand All @@ -97,7 +107,7 @@ def _process_segment(
headers={
'Accept': 'application/vnd.wit.20200513+json',
'Content-Type': 'audio/wav',
'Authorization': f'Bearer {wit_config.wit_client_access_token}',
'Authorization': f'Bearer {wit_config.wit_client_access_tokens[wit_api_key_index]}',
},
data=audio_content,
)
Expand Down
8 changes: 4 additions & 4 deletions tafrigh/utils/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def parse_args(argv: List[str]) -> argparse.Namespace:

wit_group = parser.add_argument_group('Wit')

wit_group.add_argument(
input_group.add_argument(
'-w',
'--wit_client_access_token',
default='',
help='wit.ai client access token. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
'--wit_client_access_tokens',
nargs='+',
help='List of wit.ai client access tokens. If provided, wit.ai APIs will be used to do the transcription, otherwise whisper will be used.',
)

wit_group.add_argument(
Expand Down

0 comments on commit 9fd1299

Please sign in to comment.