Skip to content

Commit

Permalink
model_family #40
Browse files Browse the repository at this point in the history
  • Loading branch information
abdeladim-s committed Jun 23, 2023
1 parent f8d64f1 commit 3391fbe
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/subsai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,23 @@ def available_translation_languages(model: Union[str, TranslationModel]) -> list
return langs

@staticmethod
def create_translation_model(model_name: str = "m2m100") -> TranslationModel:
def create_translation_model(model_name: str = "m2m100", model_family: str = None) -> TranslationModel:
"""
Creates and returns a translation model instance.
:param model_name: name of the model. To get available models use :func:`available_translation_models`
:return:
:param model_family: Either "mbart50" or "m2m100". By default, See `dl-translate` docs
:return: A translation model instance
"""
mt = TranslationModel(model_name)
mt = TranslationModel(model_or_path=model_name, model_family=model_family)
return mt

@staticmethod
def translate(subs: SSAFile,
source_language: str,
target_language: str,
model: Union[str, TranslationModel] = "m2m100",
model_family: str = None,
translation_configs: dict = {}) -> SSAFile:
"""
Translates a subtitles `SSAFile` object, what :func:`SubsAI.transcribe` is returning
Expand All @@ -172,12 +174,13 @@ def translate(subs: SSAFile,
:param target_language: the target language
:param model: the translation model, either an `str` or the model instance created by
:func:`create_translation_model`
:param model_family: Either "mbart50" or "m2m100". By default, See `dl-translate` docs
:param translation_configs: dict of translation configs (see :attr:`configs.ADVANCED_TOOLS_CONFIGS`)
:return: returns an `SSAFile` subtitles translated to the target language
"""
if type(model) == str:
translation_model = Tools.create_translation_model(model)
translation_model = Tools.create_translation_model(model_name=model, model_family=model_family)
else:
translation_model = model

Expand Down

0 comments on commit 3391fbe

Please sign in to comment.