Skip to content

Commit

Permalink
refactor: place default options into base_options
Browse files: browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Sep 17, 2023
1 parent 439c7e7 commit 2143486
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
16 changes: 11 additions & 5 deletions capgen/transcriber/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import BinaryIO, Literal
from typing import BinaryIO, Literal, TypedDict

from faster_whisper import WhisperModel

from capgen.transcriber.converter import Converter
from capgen.types import TranscriberOptions


class Transcriber:
Expand All @@ -16,7 +17,13 @@ class Transcriber:
transcribe(file: str | BinaryIO, caption_format: Literal['srt']) -> str:
converts transcription segments into a SRT file
"""
model = WhisperModel('large-v2', device="cpu", compute_type='auto', num_workers=4)
base_options = TranscriberOptions(
model_size_or_path='guillaumekln/faster-whisper-large-v2',
compute_type='auto',
num_workers=4,
)

model = WhisperModel(**base_options, device='cpu')

@classmethod
def toggle_device(cls):
Expand All @@ -26,9 +33,8 @@ def toggle_device(cls):
toggles the device between CPU and GPU
"""
cls.model = WhisperModel(
'guillaumekln/faster-whisper-large-v2',
device="cpu" if cls.model.model.device == "cuda" else "cuda",
compute_type='auto'
**cls.base_options,
device='cpu' if cls.model.model.device == 'cuda' else 'cuda',
)


Expand Down
1 change: 1 addition & 0 deletions capgen/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from capgen.types.options import TranscriberOptions as TranscriberOptions
19 changes: 19 additions & 0 deletions capgen/types/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import TypedDict


class TranscriberOptions(TypedDict):
    """
    Summary
    -------
    the options for the transcriber, typed as a plain dictionary so that they
    can be unpacked as keyword arguments (`**options`)

    Attributes
    ----------
    model_size_or_path (str) : the model size or path
    compute_type (str) : the compute type
    num_workers (int) : the number of workers
    """

    # all three keys are required; TypedDict is total by default
    model_size_or_path: str
    compute_type: str
    num_workers: int

0 comments on commit 2143486

Please sign in to comment.