-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
44 lines (40 loc) · 1.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
from pathlib import Path
from src.models.t5.trainer import T5Trainer
def main(argv: "list[str] | None" = None):
    """Parse command-line arguments and train the requested model.

    Args:
        argv: Arguments to parse instead of the real command line. The
            default, ``None``, makes argparse read ``sys.argv[1:]`` as before.
            Passing an explicit list lets this entry point be driven from
            code or tests.

    Returns:
        The value returned by ``trainer.train()``. The original script
        computed this value but discarded it.

    Raises:
        ValueError: If ``--model_type`` names a model type with no trainer.
        SystemExit: Raised by argparse when a required argument is missing,
            an argument is malformed, or ``-h/--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="Train T5 Model")
    parser.add_argument(
        "-c",
        "--config_file",
        type=Path,  # argparse converts directly; no manual Path() wrapping later
        help="Path to the run_config file for current run",
        required=True,
    )
    parser.add_argument(
        "-m",
        "--model_type",
        type=str,
        # The old help text advertised "t5 or other", but every other value
        # is rejected below.
        help="Type of model to train (currently only 't5' is supported)",
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=Path,
        help="Output directory path.",
        required=True,
    )
    args = parser.parse_args(argv)

    # Guard clause: "t5" is the only model type with a trainer implementation.
    if args.model_type != "t5":
        raise ValueError(f"Unsupported model type: {args.model_type}")

    trainer = T5Trainer(
        run_config_file=args.config_file,
        output_dir=args.output_dir,
    )
    # Train the model and hand the result back to programmatic callers.
    return trainer.train()
# Script entry point: start training only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()