use transformers Trainer when training #292

Open · CaptWake opened this issue Jul 28, 2023 · 0 comments
CaptWake commented Jul 28, 2023

Hi, I'm using the ag_news dataset from the Hugging Face Hub. I was trying to train the classifier with the Trainer class of the transformers library, using the following code:

training_args = TrainingArguments(
    output_dir='training_with_es',
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy='steps',
    report_to=None,  # don't report to wandb by default
    # required for early stopping
    load_best_model_at_end=True,
    eval_steps=100,
    metric_for_best_model='f1',
)

trainer = Trainer(
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    tokenizer=None,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)

train_dataloader and valid_dataloader are instances of PromptDataLoader, while prompt_model is an instance of PromptForClassification whose plm is a pretrained BertForMaskedLM.

But when I run trainer.train() I get the following error:

TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: PromptForClassification.forward() got an unexpected keyword argument 'labels'
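
From the traceback, my guess is that Trainer calls model(**inputs) and injects a labels keyword, while PromptForClassification.forward() only accepts the batch itself. Is something like the subclass below the intended way to handle that? This is just an untested sketch: the name PromptTrainer and the manual cross-entropy are mine, and I'm not sure Trainer's default collator even produces the batch format OpenPrompt expects.

import torch.nn.functional as F
from transformers import Trainer

class PromptTrainer(Trainer):
    # Untested sketch: drop the 'labels' key before it reaches
    # PromptForClassification.forward(), then compute the loss manually
    # from the class logits the OpenPrompt model returns.
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        logits = model(inputs)  # OpenPrompt models take the whole batch, not **kwargs
        loss = F.cross_entropy(logits, labels)
        return (loss, {"logits": logits}) if return_outputs else loss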

I don't know whether I have made a mistake somewhere or whether this simply isn't supposed to work. Thanks in advance.
Here is the full code I used:

from tqdm import tqdm
from datasets import load_dataset
from openprompt.data_utils import InputExample
from openprompt.plms.mlm import MLMTokenizerWrapper
from transformers import BertForMaskedLM, BertTokenizer, BertConfig
from openprompt.plms import load_plm
from openprompt.prompts import ManualVerbalizer, ManualTemplate
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
import torch
from transformers import Trainer, TrainingArguments
from torch.optim import AdamW
from transformers import EarlyStoppingCallback
import numpy as np
import evaluate

# function used for classification evaluation
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    acc_score = accuracy.compute(predictions=predictions, 
                                 references=labels)["accuracy"]
    f1_score = f1.compute(predictions=predictions, 
                          references=labels, 
                          average="micro")["f1"]
    return {"accuracy": acc_score, "F1 score": f1_score}

# load the dataset 
raw_dataset = load_dataset('ag_news', 
                           cache_dir="../datasets/.cache/huggingface_datasets")
print(raw_dataset['train'][0])  # peek at one raw example

dataset = {}
for split in ['train', 'test']:
    dataset[split] = []
    for idx, data in tqdm(enumerate(raw_dataset[split])):
        input_example = InputExample(text_a = data['text'], 
                                     label=int(data['label']), 
                                     guid=idx)
        dataset[split].append(input_example)
print(dataset['train'][0])

# load the model
model, tokenizer, model_config, WrapperClass = (
    BertForMaskedLM.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    BertConfig.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    MLMTokenizerWrapper
)

# define a Verbalizer
verbalizer = ManualVerbalizer(tokenizer=tokenizer, 
                              num_classes=4, 
                              label_words=[['World'], 
                                           ['Sports'], 
                                           ['Business'], 
                                           ['Sci/Tech']])

# define a Template
template = ManualTemplate(tokenizer=tokenizer, 
                          text='{"placeholder":"text_a"}. What topic is that? {"mask"}')

# view wrapped example
wrapped_example = template.wrap_one_example(dataset['train'][0])
print(wrapped_example)

train_dataloader = PromptDataLoader(dataset['train'], 
                                    template, 
                                    tokenizer=tokenizer, 
                                    tokenizer_wrapper_class=WrapperClass, 
                                    batch_size=64,
                                    decoder_max_length=384,
                                    max_seq_length=384, 
                                    shuffle=False, 
                                    teacher_forcing=False)

valid_dataloader = PromptDataLoader(dataset['test'], 
                                    template, 
                                    tokenizer=tokenizer, 
                                    tokenizer_wrapper_class=WrapperClass, 
                                    batch_size=64,
                                    decoder_max_length=384,
                                    max_seq_length=384, 
                                    shuffle=False, 
                                    teacher_forcing=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

prompt_model = PromptForClassification(plm=model, 
                                       template=template, 
                                       verbalizer=verbalizer, 
                                       freeze_plm=False)
prompt_model = prompt_model.to(device)

no_decay = ['bias', 'LayerNorm.weight']

# ===========================
# training / testing section
# ===========================

# it's good practice to apply no weight decay to bias and LayerNorm parameters
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() 
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() 
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)
 
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

training_args = TrainingArguments(
    output_dir='training_with_es',
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy='steps',
    report_to=None,  # don't report to wandb by default
    # required for early stopping
    load_best_model_at_end=True,
    eval_steps=100,
    metric_for_best_model='f1',
)

trainer = Trainer(
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    tokenizer=None,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)

trainer.train()
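
If a subclass along those lines is the right direction, I assume the last block would just swap in the hypothetical PromptTrainer from the sketch above (again untested):

trainer = PromptTrainer(  # hypothetical subclass from the sketch above
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)
trainer.train()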