Skip to content

Commit

Permalink
Add support for ctranslate2
Browse files Browse the repository at this point in the history
  • Loading branch information
gregtatum committed Nov 13, 2024
1 parent 69ca812 commit 0edc3c6
Show file tree
Hide file tree
Showing 14 changed files with 370 additions and 3 deletions.
50 changes: 50 additions & 0 deletions pipeline/translate/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import argparse
from enum import Enum
from glob import glob
import os
from pathlib import Path
Expand All @@ -13,12 +14,23 @@
from pipeline.common.downloads import count_lines, is_file_empty, write_lines
from pipeline.common.logging import get_logger
from pipeline.common.marian import get_combined_config
from pipeline.translate.translate_ctranslate2 import translate_with_ctranslate2

logger = get_logger(__file__)

DECODER_CONFIG_PATH = Path(__file__).parent / "decoder.yml"


class Decoder(Enum):
    """The translation backends that can be selected with the ``--decoder`` flag."""

    # Translate with the Marian decoder (the default for ``--decoder``).
    marian = "marian"
    # Translate with CTranslate2 via ``translate_with_ctranslate2``.
    ctranslate2 = "ctranslate2"


class Device(Enum):
    """The device kinds that can be selected with the ``--device`` flag."""

    # Run on the CPU.
    cpu = "cpu"
    # Run on the GPU (the default for ``--device``).
    gpu = "gpu"


def get_beam_size(extra_marian_args: list[str]):
    """
    Look up the "beam-size" entry of the combined decoder configuration, which is built
    by get_combined_config from DECODER_CONFIG_PATH and the extra Marian arguments.
    """
    combined_config = get_combined_config(DECODER_CONFIG_PATH, extra_marian_args)
    return combined_config["beam-size"]

Expand Down Expand Up @@ -101,6 +113,18 @@ def main() -> None:
required=True,
help="The amount of Marian memory (in MB) to preallocate",
)
parser.add_argument(
"--decoder",
type=Decoder,
default=Decoder.marian,
help="Either use the normal marian decoder, or opt for CTranslate2.",
)
parser.add_argument(
    "--device",
    type=Device,
    default=Device.gpu,
    # This help text previously duplicated the --decoder description.
    help="The device to run the translation on, either cpu or gpu.",
)
parser.add_argument(
"extra_marian_args",
nargs=argparse.REMAINDER,
Expand All @@ -123,7 +147,9 @@ def main() -> None:
vocab: Path = args.vocab
gpus: list[str] = args.gpus.split(" ")
extra_marian_args: list[str] = args.extra_marian_args
decoder: Decoder = args.decoder
is_nbest: bool = args.nbest
device: Device = args.device

# Do some light validation of the arguments.
assert input_zst.exists(), f"The input file exists: {input_zst}"
Expand Down Expand Up @@ -151,6 +177,30 @@ def main() -> None:
pass
return

if decoder == Decoder.ctranslate2:
translate_with_ctranslate2(
input_zst=input_zst,
artifacts=artifacts,
extra_marian_args=extra_marian_args,
models_globs=models_globs,
is_nbest=is_nbest,
vocab=[str(vocab)],
device=device.value,
device_index=[int(n) for n in gpus],
)
return

# The device flag is for use with CTranslate, but add some assertions here so that
# we can be consistent in usage.
if device == Device.cpu:
assert (
"--cpu-threads" in extra_marian_args
), "Marian's cpu should be controlled with the flag --cpu-threads"
else:
assert (
"--cpu-threads" not in extra_marian_args
), "Requested a GPU device, but --cpu-threads was provided"

# Run the training.
with tempfile.TemporaryDirectory() as temp_dir_str:
temp_dir = Path(temp_dir_str)
Expand Down
186 changes: 186 additions & 0 deletions pipeline/translate/translate_ctranslate2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
Translate a corpus with a teacher model (transformer-based) using CTranslate2. This is useful
to quickly synthesize training data for student distillation as CTranslate2 is X (TODO) times faster
than Marian.
https://github.com/OpenNMT/CTranslate2/
"""

from typing import Any, TextIO
from enum import Enum
from glob import glob
from pathlib import Path

import ctranslate2
import sentencepiece as spm
from ctranslate2.converters.marian import MarianConverter

from pipeline.common.downloads import read_lines, write_lines
from pipeline.common.logging import get_logger
from pipeline.common.marian import get_combined_config


def load_vocab(path: str):
    """
    Read a SentencePiece model file and return its pieces as a list ordered by piece id.

    This function is installed over ctranslate2.converters.marian.load_vocab by the
    module-level monkey patch in this file.
    """
    logger.info("Loading vocab:")
    logger.info(path)
    processor = spm.SentencePieceProcessor(path)

    # Piece ids run from 0 up to (but not including) the vocab size.
    piece_ids = range(processor.vocab_size())
    return list(map(processor.id_to_piece, piece_ids))


# The vocab expects a .yml file. Instead directly load the vocab .spm file via a monkey patch.
# Check with hasattr: reading a missing attribute directly would raise an AttributeError
# before the explanatory exception below could ever be raised.
if not hasattr(ctranslate2.converters.marian, "load_vocab"):
    raise Exception("Expected to be able to monkey patch the load_vocab function")
ctranslate2.converters.marian.load_vocab = load_vocab

logger = get_logger(__file__)


class Device(Enum):
    """
    Device kinds by name. The values are the same strings that translate_with_ctranslate2
    compares its ``device`` argument against.
    """

    # Selects the "cuda" CTranslate2 device in translate_with_ctranslate2.
    gpu = "gpu"
    # Selects the "cpu" CTranslate2 device in translate_with_ctranslate2.
    cpu = "cpu"


class MaxiBatchSort(Enum):
    """
    Sorting modes named "src" and "none".

    NOTE(review): presumably mirrors Marian's maxi-batch-sort option; nothing in the
    visible code references this enum yet - confirm the intended use.
    """

    src = "src"
    none = "none"


def get_model(models_globs: list[str]) -> Path:
    """
    Resolve the model globs to the single model file to translate with.

    Args:
        models_globs: Glob patterns, each of which is expanded with glob().

    Returns:
        The path of the one model file matched across all of the globs.

    Raises:
        ValueError: If nothing matched (including when no globs were provided), or if
            more than one model matched, since only a single model is supported here.
    """
    models: list[Path] = []
    for models_glob in models_globs:
        models.extend(Path(path) for path in glob(models_glob))

    if not models:
        # Report every glob rather than the loop variable: the loop variable is unbound
        # when models_globs is empty, and otherwise only names the last glob.
        raise ValueError(f"No model was found with the globs {models_globs}")
    if len(models) != 1:
        logger.info(f"Found models {models}")
        raise ValueError("Ensemble training is not supported in CTranslate2")
    return models[0]


class DecoderConfig:
    """
    The decoder settings that the CTranslate2 translation reads.

    The settings are looked up in the config that get_combined_config builds from the
    decoder.yml next to this file and the extra Marian arguments.
    """

    def __init__(self, extra_marian_args: list[str]) -> None:
        """
        Build the combined config and read the required settings from it.

        Raises:
            ValueError: If a required setting is missing or has an unexpected type.
        """
        super().__init__()
        # Combine the two configs.
        self.config = get_combined_config(Path(__file__).parent / "decoder.yml", extra_marian_args)

        self.mini_batch_words: int = self.get_from_config("mini-batch-words", int)
        self.beam_size: int = self.get_from_config("beam-size", int)
        self.precision: str = self.get_from_config("precision", str)

    def get_from_config(self, key: str, type: type[Any]) -> Any:
        """
        Look up a required config value and check that it has the expected type.

        Args:
            key: The config key to look up.
            type: The class the value must be an instance of. String values are
                converted when an int is requested.

        Returns:
            The config value, converted to an int when needed.

        Raises:
            ValueError: If the key is missing (or None), the value has the wrong type,
                or a string value cannot be converted to the requested int.
        """
        value = self.config.get(key, None)
        if value is None:
            raise ValueError(f'"{key}" could not be found in the decoder.yml config')
        if isinstance(value, type):
            return value
        if type is int and isinstance(value, str):
            try:
                return int(value)
            except ValueError as error:
                # Name the offending key; the bare int() error only shows the value.
                raise ValueError(
                    f'Expected "{key}" to be an integer in the decoder config, got "{value}"'
                ) from error
        raise ValueError(f'Expected "{key}" to be of a type "{type}" in the decoder.yml config')


def write_single_translation(
    _index: int, tokenizer_trg: spm.SentencePieceProcessor, result: Any, outfile: TextIO
):
    """
    Decode only the first hypothesis of the result and write it out on its own line.
    Any further hypotheses (e.g. from a beam search) are ignored, and the index is unused.
    """
    best_hypothesis = result.hypotheses[0]
    outfile.write(tokenizer_trg.decode(best_hypothesis))
    outfile.write("\n")


def write_nbest_translations(
    index: int, tokenizer_trg: spm.SentencePieceProcessor, result: Any, outfile: TextIO
):
    """
    Match Marian's way of writing out nbest translations. For example, with a beam-size of 2 and
    collecting nbest translations:
    0 ||| Translation attempt
    0 ||| An attempt at translation
    1 ||| The quick brown fox jumped
    1 ||| The brown fox quickly jumped
    ...

    Args:
        index: The index of the source sentence, written as the prefix of every line.
        tokenizer_trg: The tokenizer used to decode each hypothesis into text.
        result: The translation result; every entry in its "hypotheses" is written.
        outfile: The text stream to write the lines to.
    """
    # The index is only written as the line prefix. Writing the bare integer first
    # would raise a TypeError on a text stream and would not match the format above.
    for hypothesis in result.hypotheses:
        line = tokenizer_trg.decode(hypothesis)
        outfile.write(f"{index} ||| {line}\n")


def translate_with_ctranslate2(
    input_zst: Path,
    artifacts: Path,
    extra_marian_args: list[str],
    models_globs: list[str],
    is_nbest: bool,
    vocab: list[str],
    device: str,
    device_index: list[int],
) -> None:
    """
    Translate the lines of a compressed input file with CTranslate2.

    The single Marian model found via the globs is converted to a CTranslate2 model in a
    directory next to the model file (named after the model's stem). The translations are
    written to "{input stem}.{out|nbest}.zst" in the artifacts directory.

    Args:
        input_zst: The file with the lines to translate.
        artifacts: The directory that the output file is written to.
        extra_marian_args: The extra Marian arguments, which must start with "--" when
            provided. Everything after the first entry is combined with decoder.yml.
        models_globs: Globs that must resolve to exactly one model.
        is_nbest: Whether to write nbest translations. Not supported yet, asserts if set.
        vocab: One vocab path (shared by source and target), or the source vocab path
            followed by the target vocab path.
        device: "gpu" selects the "cuda" device, anything else selects the "cpu" device.
        device_index: The device indexes, only passed along for the "gpu" device.
    """
    model = get_model(models_globs)
    postfix = "nbest" if is_nbest else "out"
    assert not is_nbest, "TODO - nbest is not supported yet"

    # A single vocab is shared between the source and the target.
    tokenizer_src = spm.SentencePieceProcessor(vocab[0])
    tokenizer_trg = tokenizer_src if len(vocab) == 1 else spm.SentencePieceProcessor(vocab[1])

    if extra_marian_args and extra_marian_args[0] != "--":
        logger.error(" ".join(extra_marian_args))
        raise Exception("Expected the extra marian args to be after a --")

    # Skip the first entry, which is the "--" separator when arguments are provided.
    decoder_config = DecoderConfig(extra_marian_args[1:])

    ctranslate2_model_dir = model.parent / model.stem
    logger.info("Converting the Marian model to Ctranslate2:")
    logger.info(model)
    logger.info("Outputing model to:")
    logger.info(ctranslate2_model_dir)

    MarianConverter(model, vocab).convert(
        ctranslate2_model_dir, quantization=decoder_config.precision
    )

    # Only the GPU device is given explicit device indexes.
    translator_options: dict[str, Any] = {"device": "cpu"}
    if device == "gpu":
        translator_options = {"device": "cuda", "device_index": device_index}
    translator = ctranslate2.Translator(str(ctranslate2_model_dir), **translator_options)

    logger.info("Loading model")
    translator.load_model()
    logger.info("Model loaded")

    output_zst = artifacts / f"{input_zst.stem}.{postfix}.zst"

    # Without nbest only a single hypothesis is requested and written per line.
    num_hypotheses = decoder_config.beam_size if is_nbest else 1
    write_translation = write_nbest_translations if is_nbest else write_single_translation

    def encode_source(source_line):
        return tokenizer_src.Encode(source_line.strip(), out_type=str)

    with write_lines(output_zst) as outfile, read_lines(input_zst) as lines:
        results = translator.translate_iterable(
            # Options for "translate_iterable":
            # https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_iterable
            map(encode_source, lines),
            max_batch_size=decoder_config.mini_batch_words,
            batch_type="tokens",
            # Options for "translate_batch":
            # https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch
            beam_size=decoder_config.beam_size,
            return_scores=False,
            num_hypotheses=num_hypotheses,
        )
        # The position of each result is the index passed to the writer.
        for index, result in enumerate(results):
            write_translation(index, tokenizer_trg, result, outfile)
1 change: 1 addition & 0 deletions taskcluster/configs/config.ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ experiment:
use-opuscleaner: "true"
opuscleaner-mode: "custom"
teacher-mode: "two-stage"
teacher-decoder: marian
corpus-max-sentences: 1000
student-model: "tiny"

Expand Down
2 changes: 2 additions & 0 deletions taskcluster/configs/config.prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ experiment:
# Switch to "one-stage" training if back-translations are produced by a high quality model or
# the model stops too early on the fine-tuning stage
teacher-mode: "two-stage"
# Translate with either Marian, or CTranslate2.
teacher-decoder: marian
# Two student training configurations from Bergamot are supported: "tiny" and "base"
# "base" model is twice slower and larger but adds ~2 COMET points in quality (see https://github.com/mozilla/translations/issues/174)
student-model: "tiny"
Expand Down
1 change: 1 addition & 0 deletions taskcluster/kinds/translate-corpus/kind.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ tasks:
# double curly braces are used for the chunk substitutions because
# this must first be formatted by task-context to get src and trg locale
- >-
pip3 install -r $VCS_PATH/pipeline/translate/requirements/translate-ctranslate2.txt &&
export PYTHONPATH=$PYTHONPATH:$VCS_PATH &&
python3 $VCS_PATH/pipeline/translate/translate.py
--input "$MOZ_FETCHES_DIR/file.{{this_chunk}}.zst"
Expand Down
6 changes: 6 additions & 0 deletions taskcluster/kinds/translate-mono-src/kind.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ tasks:
type: translate-mono-src
resources:
- pipeline/translate/translate.py
- pipeline/translate/translate_ctranslate2.py
- pipeline/translate/requirements/translate-ctranslate2.txt
from-parameters:
split_chunks: training_config.taskcluster.split-chunks
marian_args: training_config.marian-args.decoding-teacher
teacher_decoder: training_config.experiment.teacher-decoder

task-context:
from-parameters:
Expand All @@ -50,6 +53,7 @@ tasks:
best_model: training_config.experiment.best-model
locale: training_config.experiment.src
split_chunks: training_config.taskcluster.split-chunks
teacher_decoder: training_config.experiment.teacher-decoder
substitution-fields:
- chunk.total-chunks
- description
Expand Down Expand Up @@ -113,6 +117,7 @@ tasks:
- bash
- -xc
- >-
pip3 install -r $VCS_PATH/pipeline/translate/requirements/translate-ctranslate2.txt &&
export PYTHONPATH=$PYTHONPATH:$VCS_PATH &&
python3 $VCS_PATH/pipeline/translate/translate.py
--input "$MOZ_FETCHES_DIR/file.{{this_chunk}}.zst"
Expand All @@ -122,6 +127,7 @@ tasks:
--marian_dir "$MARIAN"
--gpus "$GPUS"
--workspace "$WORKSPACE"
--decoder "{teacher_decoder}"
--
{marian_args}
Expand Down
6 changes: 6 additions & 0 deletions taskcluster/kinds/translate-mono-trg/kind.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ tasks:
type: translate-mono-trg
resources:
- pipeline/translate/translate.py
- pipeline/translate/translate_ctranslate2.py
- pipeline/translate/requirements/translate-ctranslate2.txt
from-parameters:
split_chunks: training_config.taskcluster.split-chunks
marian_args: training_config.marian-args.decoding-backward
teacher_decoder: training_config.experiment.teacher-decoder

task-context:
from-parameters:
Expand All @@ -48,6 +51,7 @@ tasks:
best_model: training_config.experiment.best-model
locale: training_config.experiment.trg
split_chunks: training_config.taskcluster.split-chunks
teacher_decoder: training_config.experiment.teacher-decoder
substitution-fields:
- chunk.total-chunks
- description
Expand Down Expand Up @@ -113,6 +117,7 @@ tasks:
# double curly braces are used for the chunk substitutions because
# this must first be formatted by task-context to get src and trg locale
- >-
pip3 install -r $VCS_PATH/pipeline/translate/requirements/translate-ctranslate2.txt &&
export PYTHONPATH=$PYTHONPATH:$VCS_PATH &&
python3 $VCS_PATH/pipeline/translate/translate.py
--input "$MOZ_FETCHES_DIR/file.{{this_chunk}}.zst"
Expand All @@ -122,5 +127,6 @@ tasks:
--marian_dir "$MARIAN"
--gpus "$GPUS"
--workspace "$WORKSPACE"
--decoder "marian"
--
{marian_args}
1 change: 1 addition & 0 deletions taskcluster/test/params/large-lt-en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ training_config:
src: lt
teacher-ensemble: 2
teacher-mode: 'two-stage'
teacher-decoder: marian
student-model: 'base'
trg: en
use-opuscleaner: 'false'
Expand Down
1 change: 1 addition & 0 deletions taskcluster/test/params/small-ru-en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ training_config:
src: ru
teacher-ensemble: 1
teacher-mode: 'two-stage'
teacher-decoder: marian
student-model: 'tiny'
trg: en
use-opuscleaner: 'true'
Expand Down
Loading

0 comments on commit 0edc3c6

Please sign in to comment.