diff --git a/train_flert_model.py b/train_flert_model.py index 601f5dd..4447cd4 100644 --- a/train_flert_model.py +++ b/train_flert_model.py @@ -6,6 +6,7 @@ from flair.embeddings import TransformerWordEmbeddings from flair.models import SequenceTagger from flair.trainers import ModelTrainer +from torch.optim.lr_scheduler import OneCycleLR if __name__ == "__main__": @@ -74,8 +75,7 @@ output_folder = f"flert-{args.dataset}-{hf_model}-{seed}" # train with XLM parameters (AdamW, 20 epochs, small LR) - from torch.optim.lr_scheduler import OneCycleLR - + trainer.train( output_folder, learning_rate=5.0e-5,