Skip to content

Commit

Permalink
Bring back Realtime SRT generation, check if WAV provided, minor code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Sirozha1337 committed Feb 16, 2024
1 parent ac311a0 commit 8ea16d3
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 52 deletions.
15 changes: 9 additions & 6 deletions auto_subtitle/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import warnings
import logging
from typing import Optional
from .models.Subtitles import Subtitles
from .models.Subtitles import Subtitles, SegmentsIterable
from .utils.files import filename, write_srt
from .utils.ffmpeg import get_audio, add_subtitles
from .utils.ffmpeg import get_audio, add_subtitles, preprocess_audio
from .utils.whisper import WhisperAI
from .translation.easynmt_utils import EasyNMTWrapper

Expand Down Expand Up @@ -52,7 +52,10 @@ def process(args: dict):

os.makedirs(output_args["output_dir"], exist_ok=True)
for video in videos:
audio = get_audio(video, audio_channel, sample_interval)
if video.endswith('.wav'):
audio = preprocess_audio(video, audio_channel, sample_interval)
else:
audio = get_audio(video, audio_channel, sample_interval)

transcribed, translated = perform_task(video, audio, language, target_language,
transcribe_model, translate_model)
Expand Down Expand Up @@ -97,9 +100,9 @@ def translate_subtitles(subtitles: Subtitles, source_lang: str, target_lang: str
src_lang = subtitles.language

translated_segments = model.translate(
subtitles.segments, src_lang, target_lang)
list(subtitles.segments), src_lang, target_lang)

return Subtitles(translated_segments, target_lang)
return Subtitles(SegmentsIterable(translated_segments), target_lang)


def save_subtitles(path: str, subtitles: Subtitles, output_dir: str,
Expand All @@ -122,4 +125,4 @@ def get_subtitles(source_path: str, audio_path: str, model: WhisperAI) -> Subtit

segments, info = model.transcribe(audio_path)

return Subtitles(segments=list(segments), language=info.language)
return Subtitles(segments=SegmentsIterable(segments), language=info.language)
46 changes: 43 additions & 3 deletions auto_subtitle/models/Subtitles.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,52 @@
from typing import Optional
from typing import Optional, Iterable, Iterator
from faster_whisper.transcribe import Segment


class SegmentsIterable(Iterable):
    """Lazily cache segments pulled from a one-shot iterator so the
    collection can be iterated multiple times.

    faster-whisper yields transcription segments from a generator that can
    be consumed only once; this wrapper lets callers stream segments on the
    first pass (e.g. realtime SRT writing) and replay the cached list on
    later passes (e.g. translation).
    """

    # Segments consumed from the source iterator so far (the replay cache).
    segments: list = None
    # Read position of the current iteration within the cache.
    index: int = 0
    # Number of cached segments (== len(self.segments)).
    length: int = 0
    # True once the underlying iterator has been fully drained.
    exhausted: bool = False
    __segments_iterable: "Iterator[Segment]"

    def __init__(self, segments: "Iterable[Segment]"):
        self.__segments_iterable = iter(segments)
        self.segments = []
        self.index = 0
        self.length = 0
        self.exhausted = False

    def __iter__(self) -> "SegmentsIterable":
        # Every iteration restarts from the beginning: cached segments are
        # replayed first, then fresh ones are pulled from the source.
        # (Previously, restarting after a partially consumed iteration
        # silently skipped the segments that were already cached.)
        self.index = 0
        return self

    def __next__(self) -> "Segment":
        if self.index < self.length:
            # Replay from the cache.
            item = self.segments[self.index]
            self.index += 1
            return item

        if self.exhausted:
            raise StopIteration

        try:
            item = next(self.__segments_iterable)
        except StopIteration:
            self.exhausted = True
            raise

        self.segments.append(item)
        self.length += 1
        self.index += 1
        return item


class Subtitles:
    """Container for a transcription result: segments plus their metadata."""

    # Transcribed (or translated) segments; SegmentsIterable allows
    # streaming on the first pass and replaying on later passes.
    segments: "SegmentsIterable[Segment]"
    # Language code of the segments' text.
    language: str
    # Path of the written subtitle file; None until it is saved.
    output_path: Optional[str] = None

    def __init__(self, segments: "SegmentsIterable[Segment]", language: str):
        self.language = language
        self.segments = segments
12 changes: 8 additions & 4 deletions auto_subtitle/translation/easynmt_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from easynmt import EasyNMT
from faster_whisper.transcribe import Segment
from .opusmt_utils import OpusMT
Expand All @@ -11,12 +12,15 @@ def __init__(self, device: str):
device=device if device != 'auto' else None)

def translate(self, segments: list[Segment], source_lang: str,
target_lang: str) -> list[Segment]:
target_lang: str) -> Optional[list[Segment]]:
source_text = [segment.text for segment in segments]
self.translator.load_available_models()
translation_available = self.translator.prepare_translation(source_lang, target_lang)
if not translation_available:
return None

translated_text = self.model.translate(source_text, target_lang,
source_lang, show_progress_bar=True)
translated_text = self.model.translate(source_text, target_lang, source_lang,
document_language_detection=False,
show_progress_bar=True)
translated_segments = []
for segment, translation in zip(segments, translated_text):
translated_segments.append(segment._replace(text=translation))
Expand Down
10 changes: 5 additions & 5 deletions auto_subtitle/translation/languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
from transformers.models.marian.convert_marian_tatoeba_to_pytorch import GROUP_MEMBERS


def to_alpha2_languages(languages):
def to_alpha2_languages(languages: list[str]) -> set[str]:
return set(item for sublist in [__to_alpha2_language(language) for language in languages] for item in sublist)


def __to_alpha2_language(language):
def __to_alpha2_language(language: str) -> set[str]:
if len(language) == 2:
return [language]
return {language}

if language in GROUP_MEMBERS:
return set([langcodes.Language.get(x).language for x in GROUP_MEMBERS[language][1]])

return [langcodes.Language.get(language).language]
return {langcodes.Language.get(language).language}


def to_alpha3_language(language):
def to_alpha3_language(language: str) -> str:
return langcodes.Language.get(language).to_alpha3()
82 changes: 51 additions & 31 deletions auto_subtitle/translation/opusmt_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import logging
from typing import List
from typing import List, Optional
import torch
from huggingface_hub import list_models, ModelFilter
from transformers import MarianMTModel, MarianTokenizer
Expand All @@ -13,14 +13,13 @@

class OpusMT:
def __init__(self, max_loaded_models: int = 10):
self.models = {}
self.max_loaded_models = max_loaded_models
self.max_length = None
self.models: dict = {}
self.max_loaded_models: int = max_loaded_models
self.max_length: Optional[int] = None
self.available_models: Optional[dict] = None
self.prepared_translations: dict = {}

self.available_models = None
self.translations_graph = None

def load_model(self, model_name):
def load_model(self, model_name: str) -> tuple:
if model_name in self.models:
self.models[model_name]['last_loaded'] = time.time()
return self.models[model_name]['tokenizer'], self.models[model_name]['model']
Expand All @@ -43,14 +42,15 @@ def load_model(self, model_name):
'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()}
return tokenizer, model

def load_available_models(self):
def load_available_models(self) -> None:
if self.available_models is not None:
return

logger.info('Loading a list of available language models from OPUS-NT')
logger.info('Loading a list of available language models from OPUS-MT')
model_list = list_models(
filter=ModelFilter(
author=NLP_ROOT
author=NLP_ROOT,
model_name='opus-mt'
)
)

Expand All @@ -64,14 +64,32 @@ def load_available_models(self):
for model in models:
for src in model.source_languages:
for tgt in model.target_languages:
key = f'{src}-{tgt}'
key = self.make_translation_key(src, tgt)
if key not in self.available_models:
self.available_models[key] = model
elif self.available_models[key].language_count > model.language_count:
self.available_models[key] = model

def determine_required_translations(self, source_lang, target_lang):
direct_key = f'{source_lang}-{target_lang}'
@staticmethod
def make_translation_key(source_lang: str, target_lang: str) -> str:
return f'{source_lang}-{target_lang}'

def prepare_translation(self, source_lang: str, target_lang: str) -> bool:
    """Resolve and cache the chain of models needed to translate
    source_lang to target_lang.

    Returns True when a translation path exists (direct or via English),
    False otherwise.  Successful lookups are stored in
    self.prepared_translations for translate_sentences to reuse.
    """
    self.load_available_models()

    translation_key = self.make_translation_key(source_lang, target_lang)
    if translation_key in self.prepared_translations:
        # Cache hit: a stored entry is always a usable translation chain.
        # Return a bool here, not the cached list, to honour the signature.
        return True

    translations = self.determine_required_translations(source_lang, target_lang)
    if not translations:
        # NOTE(review): negative results are deliberately not cached —
        # translate_sentences treats any cached entry as a usable chain.
        return False

    self.prepared_translations[translation_key] = translations
    return True

def determine_required_translations(self, source_lang: str, target_lang: str) -> List[tuple]:
direct_key = self.make_translation_key(source_lang, target_lang)
if direct_key in self.available_models:
logger.info(
'Found direct translation from %s to %s.', source_lang, target_lang)
Expand All @@ -81,46 +99,48 @@ def determine_required_translations(self, source_lang, target_lang):
'No direct translation from %s to %s. Trying to translate through en.', source_lang,
target_lang)

to_en_key = f'{source_lang}-en'
to_en_key = self.make_translation_key(source_lang, 'en')
if to_en_key not in self.available_models:
logger.info('No translation from %s to en.', source_lang)
logger.warning('No translation from %s to en.', source_lang)
return []

from_en_key = f'en-{target_lang}'
from_en_key = self.make_translation_key('en', target_lang)
if from_en_key not in self.available_models:
logger.info('No translation from en to %s.', target_lang)
logger.warning('No translation from en to %s.', target_lang)
return []

return [(source_lang, 'en', to_en_key), ('en', target_lang, from_en_key)]

def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str,
device: str, beam_size: int = 5, **kwargs):
self.load_available_models()

translations = self.determine_required_translations(
source_lang, target_lang)
device: str, beam_size: int = 5, **kwargs) -> List[str]:

if len(translations) == 0:
return sentences
translations = []
translation_key = self.make_translation_key(source_lang, target_lang)
if translation_key in self.prepared_translations:
translations = self.prepared_translations[translation_key]
else:
logger.warning('prepare_translation method should be called prior to '
'translate_sentences')

intermediate = sentences
for _, tgt_lang, key in translations:
for _, intermediate_target_language, key in translations:
model_data = self.available_models[key]
model_name = model_data.name
tokenizer, model = self.load_model(model_name)
model.to(device)

# MultiLanguage model requires prepending each line with target language
if model_data.multilanguage:
alpha3 = to_alpha3_language(tgt_lang)
alpha3 = to_alpha3_language(intermediate_target_language)
prefix = next(
x for x in tokenizer.supported_language_codes if alpha3 in x)
intermediate = [f'{prefix} {x}' for x in intermediate]

inputs = tokenizer(intermediate, truncation=True, padding=True,
max_length=self.max_length, return_tensors="pt")

for key in inputs:
inputs[key] = inputs[key].to(device)
for token in inputs:
inputs[token] = inputs[token].to(device)

with torch.no_grad():
translated = model.generate(
Expand All @@ -132,7 +152,7 @@ def translate_sentences(self, sentences: List[str], source_lang: str, target_lan


class DownloadableModel:
def __init__(self, name):
def __init__(self, name: str):
self.name = name
source_languages, target_languages = self.parse_languages(name)
self.source_languages = source_languages
Expand All @@ -142,7 +162,7 @@ def __init__(self, name):
self.source_languages) + len(self.target_languages)

@staticmethod
def parse_languages(name):
def parse_languages(name: str) -> tuple[set, set]:
parts = name.split('-')
if len(parts) > 5:
return set(), set()
Expand Down
18 changes: 17 additions & 1 deletion auto_subtitle/utils/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


def get_audio(path: str, audio_channel_index: int, sample_interval: list) -> str:
def get_audio(path: str, audio_channel_index: int, sample_interval: Optional[list] = None) -> str:
temp_dir = tempfile.gettempdir()

file_name = filename(path)
Expand Down Expand Up @@ -39,6 +39,22 @@ def get_audio(path: str, audio_channel_index: int, sample_interval: list) -> str
return output_path


def preprocess_audio(path: str, audio_channel_index: int, sample_interval: Optional[list]) -> str:
    """Return a path to audio ready for transcription.

    The input file is used as-is when no trimming or channel selection was
    requested and it already is a single-stream 16 kHz 16-bit PCM WAV;
    otherwise the audio is extracted/converted via get_audio.
    """
    needs_extraction = sample_interval is not None or audio_channel_index != 0
    if needs_extraction:
        return get_audio(path, audio_channel_index, sample_interval)

    probe = ffmpeg.probe(path, select_streams='a')
    streams = probe['streams']
    single_stream_wav = (probe['format']['format_name'] == 'wav'
                         and streams is not None and len(streams) == 1)
    if single_stream_wav:
        stream = streams[0]
        # NOTE(review): assumes 16 kHz pcm_s16le matches what get_audio
        # would produce, so re-encoding can be skipped — verify.
        if stream['codec_name'] == 'pcm_s16le' and stream['sample_rate'] == '16000':
            return path

    return get_audio(path, audio_channel_index)


def add_subtitles(path: str, transcribed: Subtitles, translated: Optional[Subtitles],
sample_interval: list, output_args: dict[str, str]) -> None:
file_name = filename(path)
Expand Down
4 changes: 2 additions & 2 deletions auto_subtitle/utils/files.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import TextIO
from typing import TextIO, Iterator
from faster_whisper.transcribe import Segment
from .convert import format_timestamp


def write_srt(transcript: list[Segment], file: TextIO) -> None:
def write_srt(transcript: Iterator[Segment], file: TextIO) -> None:
for i, segment in enumerate(transcript, start=1):
print(
f"{i}\n"
Expand Down

0 comments on commit 8ea16d3

Please sign in to comment.