Hi, I'm using the `ag_news` dataset available from Hugging Face. I was trying to train a classifier with the `Trainer` class of the transformers library using the following code:
training_args = TrainingArguments(
    output_dir='training_with_es',
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy='steps',
    report_to='none',  # don't report to wandb by default
    load_best_model_at_end=True,  # required for early stopping
    eval_steps=100,
    metric_for_best_model='f1',
)
trainer = Trainer(
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    tokenizer=None,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)
`train_dataloader` and `valid_dataloader` are instances of `PromptDataLoader`, and `prompt_model` is an instance of `PromptForClassification` whose `plm` is a pretrained `BertForMaskedLM`.
But when I run `trainer.train()`, I get the following error:
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
TypeError: PromptForClassification.forward() got an unexpected keyword argument 'labels'
I don't know whether I have made a mistake somewhere or whether this combination simply isn't supported. Thanks in advance.
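For context, the OpenPrompt tutorials train `PromptForClassification` with a manual loop rather than with `Trainer`: `forward()` takes the whole batch as a single argument and returns class logits, and the loss is computed outside the model, which would explain why the `labels` keyword argument is rejected. A rough sketch of that tutorial pattern (untested in my setup; the epoch count is arbitrary):

import torch

# sketch of the usual OpenPrompt training loop: labels stay inside the
# batch and the loss is computed outside the model
loss_func = torch.nn.CrossEntropyLoss()
for epoch in range(3):  # arbitrary epoch count for illustration
    for batch in train_dataloader:
        batch = batch.to(device)
        logits = prompt_model(batch)  # forward(batch) -> logits, no labels kwarg
        loss = loss_func(logits, batch['label'])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()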
Here is the full code I used:
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from openprompt.data_utils import InputExample
from openprompt.plms.mlm import MLMTokenizerWrapper
from transformers import BertForMaskedLM, BertTokenizer, BertConfig
from openprompt.plms import load_plm
from openprompt.prompts import ManualVerbalizer, ManualTemplate
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
import torch
from transformers import Trainer, TrainingArguments
from torch.optim import AdamW
from transformers import EarlyStoppingCallback
import evaluate

# function used for classification evaluation
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    acc_score = accuracy.compute(predictions=predictions,
                                 references=labels)["accuracy"]
    f1_score = f1.compute(predictions=predictions,
                          references=labels,
                          average="micro")["f1"]
    # key must match metric_for_best_model='f1' in TrainingArguments
    return {"accuracy": acc_score, "f1": f1_score}
# load the dataset
raw_dataset = load_dataset('ag_news',
                           cache_dir="../datasets/.cache/huggingface_datasets")
raw_dataset['train'][0]
dataset = {}
for split in ['train', 'test']:
    dataset[split] = []
    for idx, data in tqdm(enumerate(raw_dataset[split])):
        input_example = InputExample(text_a=data['text'],
                                     label=int(data['label']),
                                     guid=idx)
        dataset[split].append(input_example)
print(dataset['train'][0])
# load the model
model, tokenizer, model_config, WrapperClass = (
    BertForMaskedLM.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    BertConfig.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    MLMTokenizerWrapper
)
# define a Verbalizer
verbalizer = ManualVerbalizer(tokenizer=tokenizer,
                              num_classes=4,
                              label_words=[['World'],
                                           ['Sports'],
                                           ['Business'],
                                           ['Sci/Tech']])
# define a Template
template = ManualTemplate(tokenizer=tokenizer,
                          text='{"placeholder":"text_a"}. What topic is that? {"mask"}')
# view a wrapped example
wrapped_example = template.wrap_one_example(dataset['train'][0])
print(wrapped_example)
train_dataloader = PromptDataLoader(dataset['train'],
                                    template,
                                    tokenizer=tokenizer,
                                    tokenizer_wrapper_class=WrapperClass,
                                    batch_size=64,
                                    decoder_max_length=384,
                                    max_seq_length=384,
                                    shuffle=False,
                                    teacher_forcing=False)
valid_dataloader = PromptDataLoader(dataset['test'],
                                    template,
                                    tokenizer=tokenizer,
                                    tokenizer_wrapper_class=WrapperClass,
                                    batch_size=64,
                                    decoder_max_length=384,
                                    max_seq_length=384,
                                    shuffle=False,
                                    teacher_forcing=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prompt_model = PromptForClassification(plm=model,
                                       template=template,
                                       verbalizer=verbalizer,
                                       freeze_plm=False)
prompt_model = prompt_model.to(device)
# ===========================
# training / testing section
# ===========================
# it's good practice to exclude bias and LayerNorm parameters from weight decay
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
training_args = TrainingArguments(
    output_dir='training_with_es',
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy='steps',
    report_to='none',  # don't report to wandb by default
    load_best_model_at_end=True,  # required for early stopping
    eval_steps=100,
    metric_for_best_model='f1',
)
trainer = Trainer(
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    tokenizer=None,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)
trainer.train()
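One workaround I'm considering, in case the interface mismatch is the whole problem: a thin `nn.Module` adapter that absorbs `Trainer`'s `labels` keyword, forwards the remaining fields as one batch, and returns the `loss`/`logits` dict that `Trainer` looks for. This is a hypothetical sketch (`PromptModelForTrainer` is a name I made up, and it assumes `PromptForClassification.forward` accepts a plain dict of tokenized fields):

import torch.nn as nn

class PromptModelForTrainer(nn.Module):
    # hypothetical adapter between Trainer's model(**inputs, labels=...)
    # convention and PromptForClassification's model(batch) convention
    def __init__(self, prompt_model):
        super().__init__()
        self.prompt_model = prompt_model
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, labels=None, **batch):
        logits = self.prompt_model(batch)  # assumes a plain dict batch is accepted
        loss = self.loss_func(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

# usage sketch: trainer = Trainer(model=PromptModelForTrainer(prompt_model), ...)

I haven't verified this; the dict return is meant to mirror what `Trainer.compute_loss` extracts by default.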