train.py
"""Standalone training script: loads NLU training data from a local file,
directory, URL, or endpoint, trains a model, and persists it to disk or to
a remote storage backend."""

import argparse
import logging
from typing import Any, Optional, Text, Tuple, Union

from rasa.nlu import config, utils
from rasa.nlu.components import ComponentBuilder
from rasa.nlu.config import RasaNLUModelConfig
from rasa.nlu.training_data import load_data
from rasa.nlu.training_data.loading import load_data_from_endpoint
from rasa.nlu.utils import EndpointConfig, read_endpoints
from litemind.nlu.model import Trainer, Interpreter

logger = logging.getLogger(__name__)


def create_argument_parser():
    """Create the command line argument parser for the training script."""
    parser = argparse.ArgumentParser(
        description='Train a custom language parser.')

    parser.add_argument('-o', '--path',
                        default="models/nlu/",
                        help="Path where model files will be saved")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-d', '--data',
                       default=None,
                       help="Location of the training data. For JSON and "
                            "markdown data, this can either be a single file "
                            "or a directory containing multiple training "
                            "data files.")
    group.add_argument('-u', '--url',
                       default=None,
                       help="URL from which to retrieve training data.")
    group.add_argument('--endpoints',
                       default=None,
                       help="EndpointConfig defining the server from which "
                            "to pull training data.")

    parser.add_argument('-c', '--config',
                        required=True,
                        help="Rasa NLU configuration file")
    parser.add_argument('-t', '--num_threads',
                        default=1,
                        type=int,
                        help="Number of threads to use during model training")
    parser.add_argument('--project',
                        default=None,
                        help="Project this model belongs to.")
    parser.add_argument('--fixed_model_name',
                        help="If present, a model will always be persisted "
                             "in the specified directory instead of creating "
                             "a folder like 'model_20171020-160213'")
    parser.add_argument('--storage',
                        help='Set the remote location where models are stored. '
                             'E.g. on AWS. If nothing is configured, the '
                             'server will only serve the models that are '
                             'on disk in the configured `path`.')

    utils.add_logging_option_arguments(parser)
    return parser


class TrainingException(Exception):
    """Exception wrapping lower level exceptions that may happen while training.

    Attributes:
        failed_target_project -- name of the failed project
        message -- explanation of why the request is invalid
    """

    def __init__(self, failed_target_project=None, exception=None):
        self.failed_target_project = failed_target_project
        # Default to an empty message so __str__ is safe when no
        # wrapped exception is passed in.
        self.message = exception.args[0] if exception else ""

    def __str__(self):
        return self.message


def create_persistor(persistor: Optional[Text]):
    """Create a remote persistor to store the model if configured."""
    if persistor is not None:
        from rasa.nlu.persistor import get_persistor

        return get_persistor(persistor)
    else:
        return None
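
# Note: the value of `--storage` is forwarded here unchanged; see
# `rasa.nlu.persistor.get_persistor` for the storage backends it accepts.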


def do_train_in_worker(cfg: RasaNLUModelConfig,
                       data: Text,
                       path: Text,
                       project: Optional[Text] = None,
                       fixed_model_name: Optional[Text] = None,
                       storage: Optional[Text] = None,
                       component_builder: Optional[ComponentBuilder] = None
                       ):
    """Loads the trainer and the data and runs the training in a worker."""
    try:
        _, _, persisted_path = train(cfg, data, path, project,
                                     fixed_model_name, storage,
                                     component_builder)
        return persisted_path
    except BaseException as e:
        logger.exception("Failed to train project '{}'.".format(project))
        raise TrainingException(project, e)


def train(nlu_config: Union[Text, RasaNLUModelConfig],
          data: Text,
          path: Optional[Text] = None,
          project: Optional[Text] = None,
          fixed_model_name: Optional[Text] = None,
          storage: Optional[Text] = None,
          component_builder: Optional[ComponentBuilder] = None,
          training_data_endpoint: Optional[EndpointConfig] = None,
          **kwargs: Any
          ) -> Tuple[Trainer, Interpreter, Optional[Text]]:
    """Loads the trainer and the data and runs the training of the model."""
    if isinstance(nlu_config, str):
        nlu_config = config.load(nlu_config)

    # Ensure we are training a model that we can save in the end
    # WARN: there is still a race condition if a model with the same name is
    # trained in another subprocess
    trainer = Trainer(nlu_config, component_builder)
    persistor = create_persistor(storage)

    if training_data_endpoint is not None:
        training_data = load_data_from_endpoint(training_data_endpoint,
                                                nlu_config.language)
    else:
        training_data = load_data(data, nlu_config.language)

    interpreter = trainer.train(training_data, **kwargs)

    if path:
        persisted_path = trainer.persist(path,
                                         persistor,
                                         project,
                                         fixed_model_name)
    else:
        persisted_path = None

    return trainer, interpreter, persisted_path
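
# Example programmatic use (a sketch with hypothetical file paths; assumes
# litemind's Interpreter mirrors the Rasa NLU `parse` API):
#   trainer, interpreter, persisted_path = train("config.yml", "data/nlu.md",
#                                                path="models/nlu/")
#   interpreter.parse("hello there")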


if __name__ == '__main__':
    cmdline_args = create_argument_parser().parse_args()
    utils.configure_colored_logging(cmdline_args.loglevel)

    if cmdline_args.url:
        data_endpoint = EndpointConfig(cmdline_args.url)
    else:
        data_endpoint = read_endpoints(cmdline_args.endpoints).data

    train(cmdline_args.config,
          cmdline_args.data,
          cmdline_args.path,
          cmdline_args.project,
          cmdline_args.fixed_model_name,
          cmdline_args.storage,
          training_data_endpoint=data_endpoint,
          num_threads=cmdline_args.num_threads)
    logger.info("Finished training")
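
# Example invocations (hypothetical file names):
#   python train.py -c config.yml -d data/nlu.md -o models/nlu/
# or, pulling training data from a server described in an endpoint config:
#   python train.py -c config.yml --endpoints endpoints.yml -o models/nlu/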