diff --git a/bert_pytorch/__main__.py b/bert_pytorch/__main__.py
index d4193f2..725a42a 100644
--- a/bert_pytorch/__main__.py
+++ b/bert_pytorch/__main__.py
@@ -14,11 +14,13 @@ def train():
     parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set")
     parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab")
     parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model")
+    parser.add_argument("-log", "--log_dir", type=str, default=None, help="dir for saving log")
 
     parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
     parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers")
     parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
-    parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len")
+    parser.add_argument("-s", "--seq_len", type=int, default=32, help="maximum sequence len")
+    parser.add_argument("-d", "--dropout", type=float, default=0.01, help="dropout rate")
 
     parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size")
     parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
@@ -29,9 +31,10 @@ def train():
     parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
     parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
     parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
+    parser.add_argument("--load_model", type=str, default=None, help="Load model weights from saved model")
 
-    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
-    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
+    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate of adam")
+    parser.add_argument("--adam_weight_decay", type=float, default=0.00, help="weight_decay of adam")
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")
 
@@ -55,15 +58,20 @@ def train():
         if test_dataset is not None else None
 
     print("Building BERT model")
-    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads, dropout=args.dropout)
 
     print("Creating BERT Trainer")
     trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                           lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
-                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq)
+                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, log_dir=args.log_dir)
 
     print("Training Start")
-    for epoch in range(args.epochs):
+    if args.load_model:
+        e = trainer.load(args.load_model)
+    else:
+        e = 0
+
+    for epoch in range(e, args.epochs):
         trainer.train(epoch)
         trainer.save(epoch, args.output_path)
 
diff --git a/bert_pytorch/dataset/vocab.py b/bert_pytorch/dataset/vocab.py
index f7346a7..f74a23c 100644
--- a/bert_pytorch/dataset/vocab.py
+++ b/bert_pytorch/dataset/vocab.py
@@ -117,10 +117,10 @@ def save_vocab(self, vocab_path):
 # Building Vocab with text files
 class WordVocab(Vocab):
-    def __init__(self, texts, max_size=None, min_freq=1):
+    def __init__(self, texts, total=None, max_size=None, min_freq=1):
         print("Building Vocab")
         counter = Counter()
-        for line in tqdm.tqdm(texts):
+        for line in tqdm.tqdm(texts, total=total):
             if isinstance(line, list):
                 words = line
             else:
@@ -166,7 +166,6 @@ def load_vocab(vocab_path: str) -> 'WordVocab':
         with open(vocab_path, "rb") as f:
             return pickle.load(f)
 
-
 def build():
     import argparse
 
@@ -178,8 +177,11 @@ def build():
     parser.add_argument("-m", "--min_freq", type=int, default=1)
     args = parser.parse_args()
 
+
+    total = sum(1 for line in open(args.corpus_path, "r", encoding=args.encoding))
+
     with open(args.corpus_path, "r", encoding=args.encoding) as f:
-        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
+        vocab = WordVocab(f, total=total, max_size=args.vocab_size, min_freq=args.min_freq)
 
     print("VOCAB SIZE:", len(vocab))
     vocab.save_vocab(args.output_path)
diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py
index 0b882dd..df96bdf 100644
--- a/bert_pytorch/trainer/pretrain.py
+++ b/bert_pytorch/trainer/pretrain.py
@@ -7,6 +7,7 @@
 from .optim_schedule import ScheduledOptim
 
 import tqdm
+import os
 
 
 class BERTTrainer:
@@ -23,7 +24,7 @@ class BERTTrainer:
     def __init__(self, bert: BERT, vocab_size: int,
                  train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                  lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
-                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10):
+                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, log_dir: str = None):
         """
         :param bert: BERT model which you want to train
         :param vocab_size: total word vocab size
@@ -59,9 +60,10 @@ def __init__(self, bert: BERT, vocab_size: int,
         self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
 
         # Using Negative Log Likelihood Loss function for predicting the masked_token
-        self.criterion = nn.NLLLoss(ignore_index=0)
+        self.criterion = nn.NLLLoss()  # nn.NLLLoss(ignore_index=0)
 
         self.log_freq = log_freq
+        self.log_dir = log_dir
 
         print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
 
@@ -132,9 +134,13 @@ def iteration(self, epoch, data_loader, train=True):
             if i % self.log_freq == 0:
                 data_iter.write(str(post_fix))
-
-        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=",
-              total_correct * 100.0 / total_element)
+        to_print = "EP{}_{}, avg_loss={}, total_acc={}\n".format(epoch, str_code, avg_loss / len(data_iter),
+                                                                 total_correct * 100.0 / total_element)
+        print(to_print)
+        if self.log_dir:
+            os.makedirs(os.path.dirname(self.log_dir) or ".", exist_ok=True)
+            with open(self.log_dir, 'a', encoding='utf8') as log_file:
+                log_file.write(to_print)
 
     def save(self, epoch, file_path="output/bert_trained.model"):
         """
@@ -145,7 +151,24 @@ def save(self, epoch, file_path="output/bert_trained.model"):
         :return: final_output_path
         """
         output_path = file_path + ".ep%d" % epoch
-        torch.save(self.bert.cpu(), output_path)
+        torch.save({'epoch': epoch,
+                    'model_state_dict': self.bert.cpu().state_dict(),
+                    'optimizer_state_dict': self.optim.state_dict(),
+                    }, output_path)
         self.bert.to(self.device)
         print("EP:%d Model Saved on:" % epoch, output_path)
         return output_path
+
+    def load(self, model_path):
+        """
+        Load BERT model and optimizer state from a saved checkpoint
+        """
+        checkpoint = torch.load(model_path)
+
+        self.bert.load_state_dict(checkpoint['model_state_dict'])
+        self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+
+        self.bert.eval()
+
+        return epoch
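
Note (outside the diff): the new save()/load() pair relies on the checkpoint being a plain dict whose keys match exactly between writer and reader, with the stored epoch returned so training can resume from it. A minimal sketch of that round-trip, with a toy nn.Linear and an illustrative file name standing in for the real BERT model and output path:

    import torch
    import torch.nn as nn
    from torch.optim import Adam

    model = nn.Linear(8, 8)                    # stand-in for the BERT module
    optim = Adam(model.parameters(), lr=1e-4)  # stand-in for the trainer's optimizer

    # save: one dict holding the last epoch plus both state_dicts
    torch.save({'epoch': 3,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict()}, "checkpoint.ep3")

    # load: read the dict back, restore both objects in place,
    # and recover the epoch to resume from
    checkpoint = torch.load("checkpoint.ep3")
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    resume_epoch = checkpoint['epoch']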