whisperX diarization #65
abdeladim-s committed Aug 25, 2023
1 parent db9a35b commit 9ae401c
Showing 2 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "subsai"
-version = "1.1.2"
+version = "1.1.3"
authors = [
    {name = "abdeladim-s"},
]
15 changes: 10 additions & 5 deletions src/subsai/models/whisperX_model.py
@@ -18,6 +18,7 @@
import gc
from pysubs2 import SSAFile, SSAEvent

+
class WhisperXModel(AbstractModel):
    model_name = 'm-bain/whisperX'
    config_schema = {
@@ -103,7 +104,7 @@ class WhisperXModel(AbstractModel):

    def __init__(self, model_config):
        super(WhisperXModel, self).__init__(model_config=model_config,
-                                            model_name=self.model_name)
+                                            model_name=self.model_name)
        # config
        self.model_type = _load_config('model_type', model_config, self.config_schema)
        self.device = _load_config('device', model_config, self.config_schema)
@@ -129,7 +130,8 @@ def transcribe(self, media_file) -> str:
        audio = whisperx.load_audio(media_file)
        result = self.model.transcribe(audio, batch_size=self.batch_size)
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=self.device)
-        result = whisperx.align(result["segments"], model_a, metadata, audio, self.device, return_char_alignments=self.return_char_alignments)
+        result = whisperx.align(result["segments"], model_a, metadata, audio, self.device,
+                                return_char_alignments=self.return_char_alignments)
        self._clear_gpu()
        del model_a
        if self.speaker_labels:
@@ -140,11 +142,13 @@ def transcribe(self, media_file) -> str:
            del diarize_model

        subs = SSAFile()
+
        if self.segment_type == 'word':  # word level timestamps
            for segment in result['segments']:
                for word in segment['words']:
                    try:
-                        event = SSAEvent(start=pysubs2.make_time(s=word["start"]), end=pysubs2.make_time(s=word["end"]))
+                        event = SSAEvent(start=pysubs2.make_time(s=word["start"]), end=pysubs2.make_time(s=word["end"]),
+                                         name=segment["speaker"] if self.speaker_labels else "")
                        event.plaintext = word["word"].strip()
                        subs.append(event)
                    except Exception as e:
@@ -153,7 +157,8 @@ def transcribe(self, media_file) -> str:

        elif self.segment_type == 'sentence':
            for segment in result['segments']:
-                event = SSAEvent(start=pysubs2.make_time(s=segment["start"]), end=pysubs2.make_time(s=segment["end"]))
+                event = SSAEvent(start=pysubs2.make_time(s=segment["start"]), end=pysubs2.make_time(s=segment["end"]),
+                                 name=segment["speaker"] if self.speaker_labels else "")
                event.plaintext = segment["text"].strip()
                subs.append(event)
        else:
@@ -163,4 +168,4 @@ def transcribe(self, media_file) -> str:

    def _clear_gpu(self):
        gc.collect()
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
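
For context, this change wires whisperX's diarization output into the generated subtitles: when the `speaker_labels` config option is enabled, each `SSAEvent` now carries the segment's speaker in its `name` field, at word or sentence granularity depending on `segment_type`. Below is a minimal sketch of the same transcribe → align → diarize flow outside subsai, assuming the whisperX API of that era (`whisperx.DiarizationPipeline`, `whisperx.assign_word_speakers`); the model size, batch size, and the `HF_TOKEN` environment variable are illustrative placeholders, not values taken from this commit.

# Sketch: whisperX transcription + alignment + speaker diarization.
# Assumes whisperx is installed and HF_TOKEN grants access to the
# pyannote diarization models on Hugging Face (placeholder, not from the commit).
import os
import whisperx

device = "cuda"
audio = whisperx.load_audio("media.mp4")

# 1. transcribe with the batched Whisper model
model = whisperx.load_model("large-v2", device, compute_type="float16")
result = model.transcribe(audio, batch_size=16)

# 2. align to get word-level timestamps
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                        return_char_alignments=False)

# 3. diarize and attach a "speaker" key to segments and words
diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=device)
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)

for segment in result["segments"]:
    print(segment.get("speaker", ""), segment["text"].strip())

In subsai, the same flow is driven by the model config keys listed in `config_schema` (`batch_size`, `return_char_alignments`, `speaker_labels`, `segment_type`), and the resulting speaker label ends up in the `name` field of each pysubs2 `SSAEvent`, as shown in the diff above.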
