diff --git a/src/subsai/models/whisperX_model.py b/src/subsai/models/whisperX_model.py index 8c03c0c..4b910ae 100644 --- a/src/subsai/models/whisperX_model.py +++ b/src/subsai/models/whisperX_model.py @@ -16,7 +16,7 @@ import whisperx from subsai.utils import _load_config, get_available_devices import gc - +from pysubs2 import SSAFile, SSAEvent class WhisperXModel(AbstractModel): model_name = 'm-bain/whisperX' @@ -54,6 +54,13 @@ class WhisperXModel(AbstractModel): 'options': None, 'default': None }, + 'segment_type': { + 'type': list, + 'description': "Word-level timestamps, " + "Choose here between sentence-level and word-level", + 'options': ['sentence', 'word'], + 'default': 'sentence' + }, # transcribe config 'batch_size': { 'type': int, @@ -103,6 +110,7 @@ def __init__(self, model_config): self.compute_type = _load_config('compute_type', model_config, self.config_schema) self.download_root = _load_config('download_root', model_config, self.config_schema) self.language = _load_config('language', model_config, self.config_schema) + self.segment_type = _load_config('segment_type', model_config, self.config_schema) # transcribe config self.batch_size = _load_config('batch_size', model_config, self.config_schema) self.return_char_alignments = _load_config('return_char_alignments', model_config, self.config_schema) @@ -130,7 +138,22 @@ def transcribe(self, media_file) -> str: result = whisperx.assign_word_speakers(diarize_segments, result) self._clear_gpu() del diarize_model - subs = pysubs2.load_from_whisper(result) + + subs = SSAFile() + if self.segment_type == 'word': # word level timestamps + for segment in result['segments']: + for word in segment['words']: + event = SSAEvent(start=pysubs2.make_time(s=word["start"]), end=pysubs2.make_time(s=word["end"])) + event.plaintext = word["word"].strip() + subs.append(event) + elif self.segment_type == 'sentence': + for segment in result['segments']: + event = SSAEvent(start=pysubs2.make_time(s=segment["start"]), end=pysubs2.make_time(s=segment["end"])) + event.plaintext = segment["text"].strip() + subs.append(event) + else: + raise Exception(f'Unknown `segment_type` value, it should be one of the following: ' + f' {self.config_schema["segment_type"]["options"]}') return subs def _clear_gpu(self):