Skip to content

Commit

Permalink
feat: use async to run CapGen in background thread
Browse files Browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Jun 3, 2024
1 parent 6872951 commit 413ca25
Show file tree
Hide file tree
Showing 6 changed files with 682 additions and 641 deletions.
24 changes: 10 additions & 14 deletions capgen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from argparse import ArgumentParser
from asyncio import run
from ctypes import CDLL
from os import name
from os.path import join
Expand Down Expand Up @@ -29,29 +30,24 @@ def parse_args() -> Arguments | None:
"""
parser = ArgumentParser(description='transcribe a compatible audio/video file into a chosen caption file format')
parser.add_argument('file', nargs='?', type=str, help='the file path to a compatible audio/video')
parser.add_argument('-g', '--cuda', action='store_true', help='whether to use CUDA for inference')
parser.add_argument('-g', '--cuda', action='store_true', help='whether to use CUDA for inference')

cpu_group = parser.add_argument_group('cpu')
cpu_group.add_argument('-t', '--threads', metavar='', type=int, help='the number of CPU threads')
cpu_group.add_argument('-w', '--workers', metavar='', type=int, help='the number of CPU workers')

required_group = parser.add_argument_group('required')
required_group.add_argument('-c', '--caption', type=str, required=True, metavar='', help='the chosen caption file format')
required_group.add_argument('-o', '--output', type=str, required=True, metavar='', help='the output file path')
required_group.add_argument(
'-c', '--caption', type=str, required=True, metavar='', help='the chosen caption file format'
)
required_group.add_argument('-o', '--output', type=str, required=True, metavar='', help='the output file path')

args, unknown = parser.parse_known_args()

if unknown or not args.file and stdin.isatty():
return parser.print_help()

return Arguments(
args.file or stdin.buffer,
args.caption,
args.output,
args.cuda,
args.threads,
args.workers
)
return Arguments(args.file or stdin.buffer, args.caption, args.output, args.cuda, args.threads, args.workers)


def resolve_cuda_libraries():
Expand Down Expand Up @@ -83,7 +79,7 @@ def resolve_cuda_libraries():
print('Unable to find Python cuBLAS binaries, falling back to system binaries..')


def main():
async def main():
"""
Summary
-------
Expand All @@ -104,12 +100,12 @@ def main():
options['device'] = 'cuda'
resolve_cuda_libraries()

if not (transcription := Transcriber(**options).transcribe(args.file, args.caption)):
if not (transcription := await Transcriber(**options).transcribe(args.file, args.caption)):
raise InvalidFormatError(f'Invalid format: {args.caption}!')

with open(args.output, 'w', encoding='utf-8') as file:
file.write(transcription)


if __name__ == '__main__':
main()
run(main())
42 changes: 31 additions & 11 deletions capgen/transcriber/transcriber.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from typing import BinaryIO, Literal
from typing import BinaryIO, Literal, TypedDict

from faster_whisper import WhisperModel

from capgen.transcriber.converter import Converter


class WhisperParameters(TypedDict):
    """
    Summary
    -------
    a type hint for the keyword arguments unpacked into the WhisperModel constructor

    Attributes
    ----------
    model_size_or_path : the model name or path handed to WhisperModel
    device : the inference device, one of 'auto', 'cpu' or 'cuda'
    compute_type : the compute type forwarded to WhisperModel
    cpu_threads : the number of CPU threads
    num_workers : the number of workers
    """

    model_size_or_path: str
    # narrowed from `str` so it agrees with the `device` parameter of Transcriber.__init__
    device: Literal['auto', 'cpu', 'cuda']
    compute_type: str
    cpu_threads: int
    num_workers: int


class Transcriber:
"""
Summary
Expand All @@ -16,25 +30,30 @@ class Transcriber:
transcribe(file: str | BinaryIO, caption_format: str) -> str | None:
converts transcription segments into the chosen caption file format (e.g. an SRT file)
"""

__slots__ = ('model',)

def __init__(
self,
device: Literal['auto', 'cpu', 'cuda'],
number_of_threads: int = 0,
number_of_workers: int = 1
number_of_workers: int = 1,
):
model_parameters: WhisperParameters = {
'model_size_or_path': 'whisper-medium-en-til-ct2',
'device': device,
'compute_type': 'auto',
'cpu_threads': number_of_threads,
'num_workers': number_of_workers,
}

self.model = WhisperModel(
'distil-whisper/distil-large-v3-ct2',
device,
compute_type='auto',
cpu_threads=number_of_threads,
num_workers=number_of_workers
)
try:
self.model = WhisperModel(**model_parameters, flash_attention=True)

except ValueError:
self.model = WhisperModel(**model_parameters)

def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
async def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
"""
Summary
-------
Expand All @@ -51,8 +70,9 @@ def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
"""
segments, _ = self.model.transcribe(
file,
beam_size=1,
vad_filter=True,
vad_parameters={ 'min_silence_duration_ms': 500 }
vad_parameters={'min_silence_duration_ms': 500},
)

converter = Converter(segments)
Expand Down
Loading

0 comments on commit 413ca25

Please sign in to comment.