3
3
"""
4
4
5
5
import argparse
6
+ from enum import Enum
6
7
from glob import glob
8
+ import os
7
9
from pathlib import Path
8
10
import tempfile
9
11
12
+
10
13
from pipeline .common .command_runner import apply_command_args , run_command
11
14
from pipeline .common .datasets import compress , decompress
12
15
from pipeline .common .downloads import count_lines , is_file_empty , write_lines
13
16
from pipeline .common .logging import get_logger
17
+ from pipeline .common .marian import get_combined_config
18
+ from pipeline .translate .translate_ctranslate2 import translate_with_ctranslate2
14
19
15
20
# Module-level logger, obtained from the pipeline's get_logger helper using this file's path.
logger = get_logger(__file__)

# Location of the "decoder.yml" file in the same directory as this script. It is passed to
# get_combined_config when the beam size is looked up.
DECODER_CONFIG_PATH = Path(__file__).parent / "decoder.yml"
23
+
24
+
25
class Decoder(Enum):
    """Translation backends that can be selected with the ``--decoder`` argument."""

    # Translate with the marian-decoder binary.
    marian = "marian"
    # Translate through the CTranslate2 code path instead of Marian.
    ctranslate2 = "ctranslate2"
28
+
29
+
30
class Device(Enum):
    """Hardware targets that can be selected with the ``--device`` argument."""

    cpu = "cpu"
    gpu = "gpu"
33
+
34
+
35
def get_beam_size(extra_marian_args: list[str]):
    """Return the "beam-size" entry of the combined decoder configuration.

    The configuration is whatever ``get_combined_config`` builds from
    ``DECODER_CONFIG_PATH`` and the extra Marian arguments.
    NOTE(review): presumably command-line arguments take precedence over the
    values in decoder.yml -- confirm against ``get_combined_config``.

    Args:
        extra_marian_args: Extra command-line arguments intended for Marian.

    Returns:
        The value stored under the "beam-size" key of the combined config.
    """
    combined_config = get_combined_config(DECODER_CONFIG_PATH, extra_marian_args)
    return combined_config["beam-size"]
37
+
17
38
18
39
def run_marian (
19
40
marian_dir : Path ,
@@ -30,7 +51,7 @@ def run_marian(
30
51
marian_bin = str (marian_dir / "marian-decoder" )
31
52
log = input .parent / f"{ input .name } .log"
32
53
if is_nbest :
33
- extra_args = ["--nbest " , * extra_args ]
54
+ extra_args = ["--n-best " , * extra_args ]
34
55
35
56
logger .info ("Starting Marian to translate" )
36
57
@@ -52,6 +73,7 @@ def run_marian(
52
73
* extra_args ,
53
74
],
54
75
logger = logger ,
76
+ env = {** os .environ },
55
77
)
56
78
57
79
@@ -69,6 +91,7 @@ def main() -> None:
69
91
"--models_glob" ,
70
92
type = str ,
71
93
required = True ,
94
+ nargs = "+" ,
72
95
help = "A glob pattern to the Marian model(s)" ,
73
96
)
74
97
parser .add_argument (
@@ -91,6 +114,18 @@ def main() -> None:
91
114
required = True ,
92
115
help = "The amount of Marian memory (in MB) to preallocate" ,
93
116
)
117
+ parser .add_argument (
118
+ "--decoder" ,
119
+ type = Decoder ,
120
+ default = Decoder .marian ,
121
+ help = "Either use the normal marian decoder, or opt for CTranslate2." ,
122
+ )
123
+ parser .add_argument (
124
+ "--device" ,
125
+ type = Device ,
126
+ default = Device .gpu ,
127
+ help = "Either use the normal marian decoder, or opt for CTranslate2." ,
128
+ )
94
129
parser .add_argument (
95
130
"extra_marian_args" ,
96
131
nargs = argparse .REMAINDER ,
@@ -103,13 +138,19 @@ def main() -> None:
103
138
marian_dir : Path = args .marian_dir
104
139
input_zst : Path = args .input
105
140
artifacts : Path = args .artifacts
106
- models_glob : str = args .models_glob
107
- models : list [Path ] = [Path (path ) for path in glob (models_glob )]
141
+ models_globs : str = args .models_glob
142
+ models : list [Path ] = []
143
+ for models_glob in models_globs :
144
+ for path in glob (models_glob ):
145
+ models .append (Path (path ))
108
146
postfix = "nbest" if args .nbest else "out"
109
147
output_zst = artifacts / f"{ input_zst .stem } .{ postfix } .zst"
110
148
vocab : Path = args .vocab
111
149
gpus : list [str ] = args .gpus .split (" " )
112
150
extra_marian_args : list [str ] = args .extra_marian_args
151
+ decoder : Decoder = args .decoder
152
+ is_nbest : bool = args .nbest
153
+ device : Device = args .device
113
154
114
155
# Do some light validation of the arguments.
115
156
assert input_zst .exists (), f"The input file exists: { input_zst } "
@@ -118,6 +159,7 @@ def main() -> None:
118
159
artifacts .mkdir ()
119
160
for gpu_index in gpus :
120
161
assert gpu_index .isdigit (), f'GPUs must be list of numbers: "{ gpu_index } "'
162
+ assert models , "There must be at least one model"
121
163
for model in models :
122
164
assert model .exists (), f"The model file exists { model } "
123
165
if extra_marian_args and extra_marian_args [0 ] != "--" :
@@ -136,6 +178,29 @@ def main() -> None:
136
178
pass
137
179
return
138
180
181
+ if decoder == Decoder .ctranslate2 :
182
+ translate_with_ctranslate2 (
183
+ input_zst = input_zst ,
184
+ artifacts = artifacts ,
185
+ extra_marian_args = extra_marian_args ,
186
+ models_glob = models_glob ,
187
+ is_nbest = is_nbest ,
188
+ vocab = [str (vocab )],
189
+ device = device .value ,
190
+ )
191
+ return
192
+
193
+ # The device flag is for use with CTranslate, but add some assertions here so that
194
+ # we can be consistent in usage.
195
+ if device == Device .cpu :
196
+ assert (
197
+ "--cpu-threads" in extra_marian_args
198
+ ), "Marian's cpu should be controlled with the flag --cpu-threads"
199
+ else :
200
+ assert (
201
+ "--cpu-threads" not in extra_marian_args
202
+ ), "Requested a GPU device, but --cpu-threads was provided"
203
+
139
204
# Run the translation.
140
205
with tempfile .TemporaryDirectory () as temp_dir_str :
141
206
temp_dir = Path (temp_dir_str )
@@ -152,16 +217,26 @@ def main() -> None:
152
217
output = output_txt ,
153
218
gpus = gpus ,
154
219
workspace = args .workspace ,
155
- is_nbest = args . nbest ,
220
+ is_nbest = is_nbest ,
156
221
# Take off the initial "--"
157
222
extra_args = extra_marian_args [1 :],
158
223
)
159
- assert count_lines (input_txt ) == count_lines (
160
- output_txt
161
- ), "The input and output had the same number of lines"
162
224
163
225
compress (output_txt , destination = output_zst , remove = True , logger = logger )
164
226
227
+ input_count = count_lines (input_txt )
228
+ output_count = count_lines (output_zst )
229
+ if is_nbest :
230
+ beam_size = get_beam_size (extra_marian_args )
231
+ expected_output = input_count * beam_size
232
+ assert (
233
+ expected_output == output_count
234
+ ), f"The nbest output had { beam_size } x as many lines ({ expected_output } vs { output_count } )"
235
+ else :
236
+ assert (
237
+ input_count == output_count
238
+ ), f"The input ({ input_count } and output ({ output_count } ) had the same number of lines"
239
+
165
240
166
241
if __name__ == "__main__" :
167
242
main ()
0 commit comments