Skip to content

Commit 35f7b97

Browse files
committed
Add CTranslate2
1 parent 42ee041 commit 35f7b97

File tree

21 files changed

+791
-76
lines changed

21 files changed

+791
-76
lines changed

pipeline/common/command_runner.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import os
12
import re
23
from shlex import join
34
import shlex
45
import subprocess
5-
from typing import Union
66

77

88
def _get_indented_command_string(command_parts: list[str]) -> str:
@@ -97,7 +97,9 @@ def run_command_pipeline(
9797
subprocess.check_call(command_string, shell=True)
9898

9999

100-
def run_command(command: list[str], capture=False, logger=None) -> str | None:
100+
def run_command(
101+
command: list[str], capture=False, shell=False, logger=None, env=None
102+
) -> str | None:
101103
"""
102104
Runs a command and outputs a nice representation of the command to a logger, if supplied.
103105
@@ -107,13 +109,17 @@ def run_command(command: list[str], capture=False, logger=None) -> str | None:
107109
pipeline. If False, output is printed to stdout.
108110
logger: A logger instance used for logging the command execution. If provided,
109111
it will log the pipeline commands.
112+
env: The environment object.
110113
111114
Example:
112115
directory_listing = run_command(
113116
["ls", "-l"],
114117
capture=True
115118
)
116119
"""
120+
# Expand any environment variables.
121+
command = [os.path.expandvars(part) for part in command]
122+
117123
if logger:
118124
# Log out a nice representation of this command.
119125
logger.info("Running:")
@@ -123,30 +129,4 @@ def run_command(command: list[str], capture=False, logger=None) -> str | None:
123129
if capture:
124130
return subprocess.check_output(command).decode("utf-8")
125131

126-
subprocess.check_call(command)
127-
128-
129-
def marian_args_to_dict(extra_marian_args: list[str]) -> dict[str, Union[str, list[str]]]:
130-
"""
131-
Converts marian args, to the dict format.
132-
133-
e.g. `--precision float16` becomes {"precision": "float16"}
134-
"""
135-
decoder_config = {}
136-
key = None
137-
for arg in extra_marian_args:
138-
if arg.startswith("--"):
139-
key = arg[2:]
140-
elif key:
141-
existing_arg = decoder_config.get(key)
142-
if existing_arg is None:
143-
decoder_config[key] = arg
144-
elif isinstance(existing_arg, list):
145-
existing_arg.append(arg)
146-
else:
147-
# Convert these arguments into a list, since there are multiple
148-
decoder_config[key] = [existing_arg, arg]
149-
else:
150-
raise ValueError("Marian args should start with a --key")
151-
152-
return decoder_config
132+
subprocess.check_call(command, env=env)

pipeline/common/marian.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
Common utilities related to working with Marian.
3+
"""
4+
5+
from pathlib import Path
6+
7+
import yaml
8+
9+
10+
def get_combined_config(config_path: Path, extra_marian_args: list[str]) -> dict[str, any]:
11+
"""
12+
Frequently we combine a Marian yml config with extra marian args when running
13+
training. To get the final value, add both here.
14+
"""
15+
return {
16+
**yaml.safe_load(config_path.open()),
17+
**marian_args_to_dict(extra_marian_args),
18+
}
19+
20+
21+
def marian_args_to_dict(extra_marian_args: list[str]) -> dict:
22+
"""
23+
Converts marian args, to the dict format. This will combine a decoder.yml
24+
and extra marian args.
25+
26+
e.g. `--precision float16` becomes {"precision": "float16"}
27+
"""
28+
decoder_config = {}
29+
if extra_marian_args and extra_marian_args[0] == "--":
30+
extra_marian_args = extra_marian_args[1:]
31+
32+
previous_key = None
33+
for arg in extra_marian_args:
34+
if arg.startswith("--"):
35+
previous_key = arg[2:]
36+
decoder_config[previous_key] = True
37+
continue
38+
39+
if not previous_key:
40+
raise Exception(
41+
f"Expected to have a previous key when converting marian args to a dict: {extra_marian_args}"
42+
)
43+
44+
prev_value = decoder_config.get(previous_key)
45+
if prev_value is True:
46+
decoder_config[previous_key] = arg
47+
elif isinstance(prev_value, list):
48+
prev_value.append(arg)
49+
else:
50+
decoder_config[previous_key] = [prev_value, arg]
51+
52+
return decoder_config

pipeline/eval/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
list_existing_group_logs_metrics,
6666
)
6767

68-
WANDB_AVAILABLE = True
68+
WANDB_AVAILABLE = "TASKCLUSTER_PROXY_URL " in os.environ
6969
except ImportError as e:
7070
print(f"Failed to import tracking module: {e}")
7171
WANDB_AVAILABLE = False

pipeline/translate/translate.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,38 @@
33
"""
44

55
import argparse
6+
from enum import Enum
67
from glob import glob
8+
import os
79
from pathlib import Path
810
import tempfile
911

12+
1013
from pipeline.common.command_runner import apply_command_args, run_command
1114
from pipeline.common.datasets import compress, decompress
1215
from pipeline.common.downloads import count_lines, is_file_empty, write_lines
1316
from pipeline.common.logging import get_logger
17+
from pipeline.common.marian import get_combined_config
18+
from pipeline.translate.translate_ctranslate2 import translate_with_ctranslate2
1419

1520
logger = get_logger(__file__)
1621

22+
DECODER_CONFIG_PATH = Path(__file__).parent / "decoder.yml"
23+
24+
25+
class Decoder(Enum):
26+
marian = "marian"
27+
ctranslate2 = "ctranslate2"
28+
29+
30+
class Device(Enum):
31+
cpu = "cpu"
32+
gpu = "gpu"
33+
34+
35+
def get_beam_size(extra_marian_args: list[str]):
36+
return get_combined_config(DECODER_CONFIG_PATH, extra_marian_args)["beam-size"]
37+
1738

1839
def run_marian(
1940
marian_dir: Path,
@@ -30,7 +51,7 @@ def run_marian(
3051
marian_bin = str(marian_dir / "marian-decoder")
3152
log = input.parent / f"{input.name}.log"
3253
if is_nbest:
33-
extra_args = ["--nbest", *extra_args]
54+
extra_args = ["--n-best", *extra_args]
3455

3556
logger.info("Starting Marian to translate")
3657

@@ -52,6 +73,7 @@ def run_marian(
5273
*extra_args,
5374
],
5475
logger=logger,
76+
env={**os.environ},
5577
)
5678

5779

@@ -69,6 +91,7 @@ def main() -> None:
6991
"--models_glob",
7092
type=str,
7193
required=True,
94+
nargs="+",
7295
help="A glob pattern to the Marian model(s)",
7396
)
7497
parser.add_argument(
@@ -91,6 +114,18 @@ def main() -> None:
91114
required=True,
92115
help="The amount of Marian memory (in MB) to preallocate",
93116
)
117+
parser.add_argument(
118+
"--decoder",
119+
type=Decoder,
120+
default=Decoder.marian,
121+
help="Either use the normal marian decoder, or opt for CTranslate2.",
122+
)
123+
parser.add_argument(
124+
"--device",
125+
type=Device,
126+
default=Device.gpu,
127+
help="Either use the normal marian decoder, or opt for CTranslate2.",
128+
)
94129
parser.add_argument(
95130
"extra_marian_args",
96131
nargs=argparse.REMAINDER,
@@ -103,13 +138,19 @@ def main() -> None:
103138
marian_dir: Path = args.marian_dir
104139
input_zst: Path = args.input
105140
artifacts: Path = args.artifacts
106-
models_glob: str = args.models_glob
107-
models: list[Path] = [Path(path) for path in glob(models_glob)]
141+
models_globs: str = args.models_glob
142+
models: list[Path] = []
143+
for models_glob in models_globs:
144+
for path in glob(models_glob):
145+
models.append(Path(path))
108146
postfix = "nbest" if args.nbest else "out"
109147
output_zst = artifacts / f"{input_zst.stem}.{postfix}.zst"
110148
vocab: Path = args.vocab
111149
gpus: list[str] = args.gpus.split(" ")
112150
extra_marian_args: list[str] = args.extra_marian_args
151+
decoder: Decoder = args.decoder
152+
is_nbest: bool = args.nbest
153+
device: Device = args.device
113154

114155
# Do some light validation of the arguments.
115156
assert input_zst.exists(), f"The input file exists: {input_zst}"
@@ -118,6 +159,7 @@ def main() -> None:
118159
artifacts.mkdir()
119160
for gpu_index in gpus:
120161
assert gpu_index.isdigit(), f'GPUs must be list of numbers: "{gpu_index}"'
162+
assert models, "There must be at least one model"
121163
for model in models:
122164
assert model.exists(), f"The model file exists {model}"
123165
if extra_marian_args and extra_marian_args[0] != "--":
@@ -136,6 +178,29 @@ def main() -> None:
136178
pass
137179
return
138180

181+
if decoder == Decoder.ctranslate2:
182+
translate_with_ctranslate2(
183+
input_zst=input_zst,
184+
artifacts=artifacts,
185+
extra_marian_args=extra_marian_args,
186+
models_glob=models_glob,
187+
is_nbest=is_nbest,
188+
vocab=[str(vocab)],
189+
device=device.value,
190+
)
191+
return
192+
193+
# The device flag is for use with CTranslate, but add some assertions here so that
194+
# we can be consistent in usage.
195+
if device == Device.cpu:
196+
assert (
197+
"--cpu-threads" in extra_marian_args
198+
), "Marian's cpu should be controlled with the flag --cpu-threads"
199+
else:
200+
assert (
201+
"--cpu-threads" not in extra_marian_args
202+
), "Requested a GPU device, but --cpu-threads was provided"
203+
139204
# Run the training.
140205
with tempfile.TemporaryDirectory() as temp_dir_str:
141206
temp_dir = Path(temp_dir_str)
@@ -152,16 +217,26 @@ def main() -> None:
152217
output=output_txt,
153218
gpus=gpus,
154219
workspace=args.workspace,
155-
is_nbest=args.nbest,
220+
is_nbest=is_nbest,
156221
# Take off the initial "--"
157222
extra_args=extra_marian_args[1:],
158223
)
159-
assert count_lines(input_txt) == count_lines(
160-
output_txt
161-
), "The input and output had the same number of lines"
162224

163225
compress(output_txt, destination=output_zst, remove=True, logger=logger)
164226

227+
input_count = count_lines(input_txt)
228+
output_count = count_lines(output_zst)
229+
if is_nbest:
230+
beam_size = get_beam_size(extra_marian_args)
231+
expected_output = input_count * beam_size
232+
assert (
233+
expected_output == output_count
234+
), f"The nbest output had {beam_size}x as many lines ({expected_output} vs {output_count})"
235+
else:
236+
assert (
237+
input_count == output_count
238+
), f"The input ({input_count} and output ({output_count}) had the same number of lines"
239+
165240

166241
if __name__ == "__main__":
167242
main()

0 commit comments

Comments
 (0)