diff --git a/README.md b/README.md
index fb159d1..5ab264a 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ bert-dataset -d data/corpus.small -v data/corpus.small.vocab -o data/dataset.sma
 ### 3. Train your own BERT model
 
 ```shell
-bert -d data/dataset.small -v data/corpus.small.vocab -o output/
+bert -d data/dataset.small -v data/corpus.small.vocab -o output/bert.model
 ```
 
 ## Language Model Pre-training
diff --git a/bert_pytorch/train.py b/bert_pytorch/train.py
index 618ce53..d6887de 100644
--- a/bert_pytorch/train.py
+++ b/bert_pytorch/train.py
@@ -13,7 +13,7 @@ def train():
     parser.add_argument("-d", "--train_dataset", required=True, type=str)
     parser.add_argument("-t", "--test_dataset", type=str, default=None)
     parser.add_argument("-v", "--vocab_path", required=True, type=str)
-    parser.add_argument("-o", "--output_dir", required=True, type=str)
+    parser.add_argument("-o", "--output_path", required=True, type=str)
 
     parser.add_argument("-hs", "--hidden", type=int, default=256)
     parser.add_argument("-n", "--layers", type=int, default=8)
@@ -61,7 +61,7 @@ def train():
     print("Training Start")
     for epoch in range(args.epochs):
         trainer.train(epoch)
-        trainer.save(args.output_dir, epoch)
+        trainer.save(args.output_path, epoch)
 
         if test_data_loader is not None:
             trainer.test(epoch)