From 8ea16d3eb907f58b826fc37c0e2067e36ff6dabe Mon Sep 17 00:00:00 2001
From: Sergey Chernyaev
Date: Fri, 16 Feb 2024 19:28:28 +0100
Subject: [PATCH] Bring back Realtime SRT generation, check if WAV provided, minor code improvements

Iterate transcription segments lazily again so SRT entries are written
while transcription is still running, skip audio extraction when the
input is already a 16 kHz mono pcm_s16le WAV file, and cache resolved
translation paths so unavailable language pairs fail fast.
---
 auto_subtitle/main.py                      | 18 ++--
 auto_subtitle/models/Subtitles.py          | 51 +++++++++++-
 auto_subtitle/translation/easynmt_utils.py | 12 ++--
 auto_subtitle/translation/languages.py     | 10 +--
 auto_subtitle/translation/opusmt_utils.py  | 82 ++++++++++++++--------
 auto_subtitle/utils/ffmpeg.py              | 20 +++++-
 auto_subtitle/utils/files.py               |  4 +-
 7 files changed, 145 insertions(+), 52 deletions(-)

diff --git a/auto_subtitle/main.py b/auto_subtitle/main.py
index ae18a8b..171dd27 100644
--- a/auto_subtitle/main.py
+++ b/auto_subtitle/main.py
@@ -2,9 +2,9 @@
 import warnings
 import logging
 from typing import Optional
-from .models.Subtitles import Subtitles
+from .models.Subtitles import Subtitles, SegmentsIterable
 from .utils.files import filename, write_srt
-from .utils.ffmpeg import get_audio, add_subtitles
+from .utils.ffmpeg import get_audio, add_subtitles, preprocess_audio
 from .utils.whisper import WhisperAI
 from .translation.easynmt_utils import EasyNMTWrapper
 
@@ -52,7 +52,10 @@ def process(args: dict):
     os.makedirs(output_args["output_dir"], exist_ok=True)
 
     for video in videos:
-        audio = get_audio(video, audio_channel, sample_interval)
+        if video.lower().endswith('.wav'):
+            audio = preprocess_audio(video, audio_channel, sample_interval)
+        else:
+            audio = get_audio(video, audio_channel, sample_interval)
 
         transcribed, translated = perform_task(video, audio, language, target_language,
                                                transcribe_model, translate_model)
@@ -97,9 +100,12 @@ def translate_subtitles(subtitles: Subtitles, source_lang: str, target_lang: str
     src_lang = subtitles.language
 
     translated_segments = model.translate(
-        subtitles.segments, src_lang, target_lang)
+        list(subtitles.segments), src_lang, target_lang)
 
-    return Subtitles(translated_segments, target_lang)
+    if translated_segments is None:
+        return None
+
+    return Subtitles(SegmentsIterable(translated_segments), target_lang)
 
 
 def save_subtitles(path: str, subtitles: Subtitles, output_dir: str,
@@ -122,4 +128,4 @@ def get_subtitles(source_path: str, audio_path: str, model: WhisperAI) -> Subtit
 
     segments, info = model.transcribe(audio_path)
 
-    return Subtitles(segments=list(segments), language=info.language)
+    return Subtitles(segments=SegmentsIterable(segments), language=info.language)
diff --git a/auto_subtitle/models/Subtitles.py b/auto_subtitle/models/Subtitles.py
index 32d6f83..79cbddf 100644
--- a/auto_subtitle/models/Subtitles.py
+++ b/auto_subtitle/models/Subtitles.py
@@ -1,12 +1,57 @@
-from typing import Optional
+from typing import Optional, Iterable, Iterator
 from faster_whisper.transcribe import Segment
 
 
+class SegmentsIterable(Iterable[Segment]):
+    """Streams segments from the source iterator on the first pass, caching
+    them as they arrive, and replays the cached list on later passes."""
+
+    segments: list
+    index: Optional[int] = None
+    length: int = 0
+    __segments_iterable: Iterator[Segment]
+
+    def __init__(self, segments: Iterable[Segment]):
+        self.__segments_iterable = iter(segments)
+        self.segments = []
+
+    def __iter__(self):
+        # index stays None until the source iterator is exhausted; after
+        # that, every new iteration replays the cache from the start.
+        if self.index is not None:
+            self.index = 0
+        return self
+
+    def __next_list(self):
+        if self.index < self.length:
+            item = self.segments[self.index]
+            self.index += 1
+            return item
+        raise StopIteration
+
+    def __next_iter(self):
+        try:
+            item = next(self.__segments_iterable)
+            self.segments.append(item)
+            self.length += 1
+            return item
+        except StopIteration:
+            # Source exhausted: switch to replaying the cache.
+            self.index = 0
+            raise
+
+    def __next__(self):
+        if self.index is not None:
+            return self.__next_list()
+
+        return self.__next_iter()
+
+
 class Subtitles:
-    segments: list
+    segments: SegmentsIterable[Segment]
     language: str
     output_path: Optional[str] = None
 
-    def __init__(self, segments: list[Segment], language: str):
+    def __init__(self, segments: SegmentsIterable[Segment], language: str):
         self.language = language
         self.segments = segments
diff --git a/auto_subtitle/translation/easynmt_utils.py b/auto_subtitle/translation/easynmt_utils.py
index b38e55c..9eff464 100644
--- a/auto_subtitle/translation/easynmt_utils.py
+++ b/auto_subtitle/translation/easynmt_utils.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from easynmt import EasyNMT
 from faster_whisper.transcribe import Segment
 from .opusmt_utils import OpusMT
@@ -11,12 +12,15 @@ def __init__(self, device: str):
             device=device if device != 'auto' else None)
 
     def translate(self, segments: list[Segment], source_lang: str,
-                  target_lang: str) -> list[Segment]:
+                  target_lang: str) -> Optional[list[Segment]]:
         source_text = [segment.text for segment in segments]
-        self.translator.load_available_models()
+        translation_available = self.translator.prepare_translation(source_lang, target_lang)
+        if not translation_available:
+            return None
 
-        translated_text = self.model.translate(source_text, target_lang,
-                                               source_lang, show_progress_bar=True)
+        translated_text = self.model.translate(source_text, target_lang, source_lang,
+                                               document_language_detection=False,
+                                               show_progress_bar=True)
         translated_segments = []
         for segment, translation in zip(segments, translated_text):
             translated_segments.append(segment._replace(text=translation))
diff --git a/auto_subtitle/translation/languages.py b/auto_subtitle/translation/languages.py
index 1e4eec5..f883edb 100644
--- a/auto_subtitle/translation/languages.py
+++ b/auto_subtitle/translation/languages.py
@@ -2,19 +2,19 @@
 from transformers.models.marian.convert_marian_tatoeba_to_pytorch import GROUP_MEMBERS
 
 
-def to_alpha2_languages(languages):
+def to_alpha2_languages(languages: list[str]) -> set[str]:
     return set(item for sublist in [__to_alpha2_language(language)
                                     for language in languages] for item in sublist)
 
 
-def __to_alpha2_language(language):
+def __to_alpha2_language(language: str) -> set[str]:
     if len(language) == 2:
-        return [language]
+        return {language}
 
     if language in GROUP_MEMBERS:
         return set([langcodes.Language.get(x).language
                     for x in GROUP_MEMBERS[language][1]])
 
-    return [langcodes.Language.get(language).language]
+    return {langcodes.Language.get(language).language}
 
 
-def to_alpha3_language(language):
+def to_alpha3_language(language: str) -> str:
     return langcodes.Language.get(language).to_alpha3()
diff --git a/auto_subtitle/translation/opusmt_utils.py b/auto_subtitle/translation/opusmt_utils.py
index 35b7621..ac165f3 100644
--- a/auto_subtitle/translation/opusmt_utils.py
+++ b/auto_subtitle/translation/opusmt_utils.py
@@ -1,6 +1,6 @@
 import time
 import logging
-from typing import List
+from typing import List, Optional
 import torch
 from huggingface_hub import list_models, ModelFilter
 from transformers import MarianMTModel, MarianTokenizer
@@ -13,14 +13,13 @@ class OpusMT:
     def __init__(self, max_loaded_models: int = 10):
-        self.models = {}
-        self.max_loaded_models = max_loaded_models
-        self.max_length = None
+        self.models: dict = {}
+        self.max_loaded_models: int = max_loaded_models
+        self.max_length: Optional[int] = None
+        self.available_models: Optional[dict] = None
+        self.prepared_translations: dict = {}
 
-        self.available_models = None
-        self.translations_graph = None
-
-    def load_model(self, model_name):
+    def load_model(self, model_name: str) -> tuple:
         if model_name in self.models:
             self.models[model_name]['last_loaded'] = time.time()
             return self.models[model_name]['tokenizer'], self.models[model_name]['model']
@@ -43,14 +42,15 @@ def load_model(self, model_name):
                 'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()}
         return tokenizer, model
 
-    def load_available_models(self):
+    def load_available_models(self) -> None:
         if self.available_models is not None:
             return
 
-        logger.info('Loading a list of available language models from OPUS-NT')
+        logger.info('Loading a list of available language models from OPUS-MT')
         model_list = list_models(
             filter=ModelFilter(
-                author=NLP_ROOT
+                author=NLP_ROOT,
+                model_name='opus-mt'
             )
         )
 
@@ -64,14 +64,32 @@ def load_available_models(self):
         for model in models:
             for src in model.source_languages:
                 for tgt in model.target_languages:
-                    key = f'{src}-{tgt}'
+                    key = self.make_translation_key(src, tgt)
                     if key not in self.available_models:
                         self.available_models[key] = model
                     elif self.available_models[key].language_count > model.language_count:
                         self.available_models[key] = model
 
-    def determine_required_translations(self, source_lang, target_lang):
-        direct_key = f'{source_lang}-{target_lang}'
+    @staticmethod
+    def make_translation_key(source_lang: str, target_lang: str) -> str:
+        return f'{source_lang}-{target_lang}'
+
+    def prepare_translation(self, source_lang: str, target_lang: str) -> bool:
+        self.load_available_models()
+
+        translation_key = self.make_translation_key(source_lang, target_lang)
+        if translation_key in self.prepared_translations:
+            return len(self.prepared_translations[translation_key]) > 0
+
+        translations = self.determine_required_translations(source_lang, target_lang)
+
+        # Cache even an empty result so a missing pair is not re-resolved.
+        self.prepared_translations[translation_key] = translations
+
+        return len(translations) > 0
+
+    def determine_required_translations(self, source_lang: str, target_lang: str) -> List[tuple]:
+        direct_key = self.make_translation_key(source_lang, target_lang)
         if direct_key in self.available_models:
             logger.info(
                 'Found direct translation from %s to %s.', source_lang, target_lang)
@@ -81,37 +99,39 @@ def determine_required_translations(self, source_lang, targ
             'No direct translation from %s to %s. '
             'Trying to translate through en.', source_lang, target_lang)
 
-        to_en_key = f'{source_lang}-en'
+        to_en_key = self.make_translation_key(source_lang, 'en')
         if to_en_key not in self.available_models:
-            logger.info('No translation from %s to en.', source_lang)
+            logger.warning('No translation from %s to en.', source_lang)
             return []
 
-        from_en_key = f'en-{target_lang}'
+        from_en_key = self.make_translation_key('en', target_lang)
         if from_en_key not in self.available_models:
-            logger.info('No translation from en to %s.', target_lang)
+            logger.warning('No translation from en to %s.', target_lang)
             return []
 
         return [(source_lang, 'en', to_en_key), ('en', target_lang, from_en_key)]
 
     def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str,
-                            device: str, beam_size: int = 5, **kwargs):
-        self.load_available_models()
-
-        translations = self.determine_required_translations(
-            source_lang, target_lang)
+                            device: str, beam_size: int = 5, **kwargs) -> List[str]:
 
-        if len(translations) == 0:
-            return sentences
+        translations = []
+        translation_key = self.make_translation_key(source_lang, target_lang)
+        if translation_key in self.prepared_translations:
+            translations = self.prepared_translations[translation_key]
+        else:
+            logger.warning('prepare_translation() was not called for %s-%s; '
+                           'returning sentences untranslated', source_lang, target_lang)
 
         intermediate = sentences
-        for _, tgt_lang, key in translations:
+        for _, intermediate_target_language, key in translations:
             model_data = self.available_models[key]
             model_name = model_data.name
             tokenizer, model = self.load_model(model_name)
             model.to(device)
 
+            # A multilanguage model needs the target-language token prepended to each line.
             if model_data.multilanguage:
-                alpha3 = to_alpha3_language(tgt_lang)
+                alpha3 = to_alpha3_language(intermediate_target_language)
                 prefix = next(
                     x for x in tokenizer.supported_language_codes if alpha3 in x)
                 intermediate = [f'{prefix} {x}' for x in intermediate]
@@ -119,8 +139,8 @@ def translate_sentences(self, sentences: List[str], source_lang: str, target_lan
         inputs = tokenizer(intermediate, truncation=True, padding=True,
                            max_length=self.max_length, return_tensors="pt")
 
-        for key in inputs:
-            inputs[key] = inputs[key].to(device)
+        for token in inputs:
+            inputs[token] = inputs[token].to(device)
 
         with torch.no_grad():
             translated = model.generate(
@@ -132,7 +152,7 @@
 
 
 class DownloadableModel:
-    def __init__(self, name):
+    def __init__(self, name: str):
         self.name = name
         source_languages, target_languages = self.parse_languages(name)
         self.source_languages = source_languages
@@ -142,7 +162,7 @@ class DownloadableModel:
             self.source_languages) + len(self.target_languages)
 
     @staticmethod
-    def parse_languages(name):
+    def parse_languages(name: str) -> tuple[set, set]:
         parts = name.split('-')
         if len(parts) > 5:
             return set(), set()
diff --git a/auto_subtitle/utils/ffmpeg.py b/auto_subtitle/utils/ffmpeg.py
index 0bdb433..c2f8893 100644
--- a/auto_subtitle/utils/ffmpeg.py
+++ b/auto_subtitle/utils/ffmpeg.py
@@ -10,7 +10,7 @@
 logger = logging.getLogger(__name__)
 
 
-def get_audio(path: str, audio_channel_index: int, sample_interval: list) -> str:
+def get_audio(path: str, audio_channel_index: int, sample_interval: Optional[list] = None) -> str:
     temp_dir = tempfile.gettempdir()
     file_name = filename(path)
 
@@ -39,6 +39,24 @@
     return output_path
 
 
+def preprocess_audio(path: str, audio_channel_index: int,
+                     sample_interval: Optional[list]) -> str:
+    if sample_interval is not None or audio_channel_index != 0:
+        return get_audio(path, audio_channel_index, sample_interval)
+
+    audio_info = ffmpeg.probe(path, select_streams='a')
+    audio_format = audio_info['format']
+    audio_streams = audio_info.get('streams')
+    if audio_format['format_name'] == 'wav' and \
+            audio_streams is not None and len(audio_streams) == 1:
+        audio_stream = audio_streams[0]
+        # Whisper reads 16 kHz mono s16le WAV directly, so skip re-encoding.
+        if audio_stream['codec_name'] == 'pcm_s16le' \
+                and audio_stream['sample_rate'] == '16000' and audio_stream.get('channels') == 1:
+            return path
+
+    return get_audio(path, audio_channel_index)
+
+
 def add_subtitles(path: str, transcribed: Subtitles, translated: Optional[Subtitles],
                   sample_interval: list, output_args: dict[str, str]) -> None:
     file_name = filename(path)
diff --git a/auto_subtitle/utils/files.py b/auto_subtitle/utils/files.py
index 1737002..c720ff6 100644
--- a/auto_subtitle/utils/files.py
+++ b/auto_subtitle/utils/files.py
@@ -1,10 +1,10 @@
 import os
-from typing import TextIO
+from typing import TextIO, Iterator
 from faster_whisper.transcribe import Segment
 from .convert import format_timestamp
 
 
-def write_srt(transcript: list[Segment], file: TextIO) -> None:
+def write_srt(transcript: Iterator[Segment], file: TextIO) -> None:
     for i, segment in enumerate(transcript, start=1):
         print(
             f"{i}\n"
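
Appendix (not part of the patch): a minimal sketch of the two-pass behavior
SegmentsIterable is meant to provide. The first pass streams items straight
from the source generator, which is what lets write_srt emit SRT entries in
realtime; every later pass replays the cached segments. Plain strings stand
in for faster_whisper Segment objects, and the snippet assumes the patched
package is importable.

# sketch_segments_iterable.py -- hypothetical demo, not part of the repo
from auto_subtitle.models.Subtitles import SegmentsIterable


def fake_segments():
    # Stands in for faster-whisper's lazy transcription generator.
    yield 'first segment'
    yield 'second segment'


segments = SegmentsIterable(fake_segments())

# First pass pulls one item at a time from the generator, so a consumer
# like write_srt can emit each SRT entry the moment it is produced.
for item in segments:
    print('pass 1:', item)

# Second pass replays the cached items without re-running transcription.
for item in segments:
    print('pass 2:', item)

This is the property main.py relies on when it writes SRT output and later
iterates subtitles.segments again (e.g. via list(subtitles.segments) in
translate_subtitles) without transcribing twice.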