Skip to content

Commit

Permalink
whisperx word level support #37
Browse files Browse the repository at this point in the history
  • Loading branch information
abdeladim-s committed Jun 26, 2023
1 parent 3625e1d commit 57ebd6b
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions src/subsai/models/whisperX_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import whisperx
from subsai.utils import _load_config, get_available_devices
import gc

from pysubs2 import SSAFile, SSAEvent

class WhisperXModel(AbstractModel):
model_name = 'm-bain/whisperX'
Expand Down Expand Up @@ -54,6 +54,13 @@ class WhisperXModel(AbstractModel):
'options': None,
'default': None
},
'segment_type': {
'type': list,
'description': "Word-level timestamps, "
"Choose here between sentence-level and word-level",
'options': ['sentence', 'word'],
'default': 'sentence'
},
# transcribe config
'batch_size': {
'type': int,
Expand Down Expand Up @@ -103,6 +110,7 @@ def __init__(self, model_config):
self.compute_type = _load_config('compute_type', model_config, self.config_schema)
self.download_root = _load_config('download_root', model_config, self.config_schema)
self.language = _load_config('language', model_config, self.config_schema)
self.segment_type = _load_config('segment_type', model_config, self.config_schema)
# transcribe config
self.batch_size = _load_config('batch_size', model_config, self.config_schema)
self.return_char_alignments = _load_config('return_char_alignments', model_config, self.config_schema)
Expand Down Expand Up @@ -130,7 +138,22 @@ def transcribe(self, media_file) -> str:
result = whisperx.assign_word_speakers(diarize_segments, result)
self._clear_gpu()
del diarize_model
subs = pysubs2.load_from_whisper(result)

subs = SSAFile()
if self.segment_type == 'word': # word level timestamps
for segment in result['segments']:
for word in segment['words']:
event = SSAEvent(start=pysubs2.make_time(s=word["start"]), end=pysubs2.make_time(s=word["end"]))
event.plaintext = word["word"].strip()
subs.append(event)
elif self.segment_type == 'sentence':
for segment in result['segments']:
event = SSAEvent(start=pysubs2.make_time(s=segment["start"]), end=pysubs2.make_time(s=segment["end"]))
event.plaintext = segment["text"].strip()
subs.append(event)
else:
raise Exception(f'Unknown `segment_type` value, it should be one of the following: '
f' {self.config_schema["segment_type"]["options"]}')
return subs

def _clear_gpu(self):
Expand Down

0 comments on commit 57ebd6b

Please sign in to comment.