ContextGNNBERT

Codebase of our work on improving the performance of language models such as BERT by infusing domain context into the top classifier layers. While the weights of BERT's self-attention layers can be trained or finetuned on a specific domain/dataset, recent efforts have combined additional information sources such as external knowledge bases or links between documents. With a similar motivation, this work identifies context words, with reference to a domain or dataset, by percentile thresholding of TF-IDF scores for the words in a given text, and replaces the classifier head of the language model with a graph convolutional network (GCN) applied over the context words and the [CLS] token. Our preliminary experiments are encouraging and show that the approach further boosts downstream classification performance even when working with pre-trained domain-specific language models, e.g. SciBERT on the scientific entailment (SciTail) dataset.
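
The context-word selection can be illustrated with a small sketch. This is not the repo's TFIDFContextMasker API, just an assumed, minimal reimplementation of percentile thresholding over TF-IDF scores using scikit-learn:

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

def context_words(corpus, text, percentile_cutoff=50):
    # Fit corpus-level TF-IDF statistics, then score the given text.
    vec = TfidfVectorizer()
    vec.fit(corpus)
    scores = vec.transform([text]).toarray()[0]
    vocab = np.array(vec.get_feature_names_out())
    present = scores > 0  # words of the text that are in the vocabulary
    # Keep words scoring at or above the chosen percentile.
    cutoff = np.percentile(scores[present], percentile_cutoff)
    return vocab[present & (scores >= cutoff)].tolist()

corpus = [
    "the cell membrane regulates molecular transport",
    "proteins fold into stable three dimensional structures",
]
print(context_words(corpus, "membrane proteins regulate transport"))

The words that survive the cutoff are the ones treated as context nodes for the GNN head; everything else is masked out.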

To fit and save a ContextMasker such as TFIDFContextMasker on a large dataset, say Wikipedia, so that it can be reused later, execute:

python3 -m ContextMasker.TFIDFContextMasker \
    --dataset_name wikipedia:20220301.en \
    --corpus_cols text \
    --tokenizer bert-base-uncased \
    --save_path data/TFIDFContextMasker.jb
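
The .jb extension suggests the fitted masker is serialized with joblib; if so, it can presumably also be reloaded outside of main.py (an assumption, not documented repo behavior):

import joblib

# Load the previously fitted masker for inspection or reuse.
masker = joblib.load("data/TFIDFContextMasker.jb")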

To use the saved context masker for dataset-specific training and evaluation on the train, validation and test sets, execute:

python3 main.py \
    --dataset_name scitail:snli_format \
    --label_col gold_label \
    --pos_label entailment \
    --input_text_cols sentence1 sentence2 \
    --context_masker_load_path data/TFIDFContextMasker.jb \
    --context_mask_fn_kwargs '{"percentile_cutoff": 50}' \
    --truncation_strategy longest_first \
    --classifier_net GATConv \
    --num_train_epochs 50 \
    --batch_size 12 \
    --optimizer_name AdamW \
    --optimizer_kwargs '{"lr": 1e-3, "eps": 1e-8}' \
    --encoder_optimizer_kwargs '{"lr": 1e-5}'
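
Note the two learning rates above: the classifier head is trained with a larger step (lr 1e-3 via --optimizer_kwargs) than the pre-trained encoder (lr 1e-5 via --encoder_optimizer_kwargs). In PyTorch, such a split is typically realized with optimizer parameter groups; a minimal sketch under that assumption (DummyModel and its attribute names are hypothetical, not the repo's classes):

import torch

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(768, 768)  # stand-in for the BERT encoder
        self.head = torch.nn.Linear(768, 2)       # stand-in for the GNN classifier head

model = DummyModel()
optimizer = torch.optim.AdamW([
    {"params": model.head.parameters(), "lr": 1e-3, "eps": 1e-8},
    {"params": model.encoder.parameters(), "lr": 1e-5},
])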

Alternatively, to fit the context masker on the specific dataset at hand, and to train and report evaluation on the train, validation and test sets, execute:

python3 main.py \
    --dataset_name scitail:snli_format \
    --label_col gold_label \
    --pos_label entailment \
    --input_text_cols sentence1 sentence2 \
    --context_corpus_splits train \
    --context_masker TFIDFContextMasker \
    --context_mask_fn_kwargs '{"percentile_cutoff": 50}' \
    --truncation_strategy longest_first \
    --classifier_net GATConv \
    --num_train_epochs 50 \
    --batch_size 12 \
    --optimizer_name AdamW \
    --optimizer_kwargs '{"lr": 1e-3, "eps": 1e-8}' \
    --encoder_optimizer_kwargs '{"lr": 1e-5}'
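
Both invocations pass GATConv, a graph-attention layer from PyTorch Geometric, as the classifier net. As an illustrative sketch only (the node and edge construction below, a star graph around [CLS], is an assumption and may differ from the repo's actual head):

import torch
from torch_geometric.nn import GATConv

class ContextGATHead(torch.nn.Module):
    def __init__(self, hidden_dim=768, num_labels=2):
        super().__init__()
        self.gat = GATConv(hidden_dim, hidden_dim, heads=1)
        self.out = torch.nn.Linear(hidden_dim, num_labels)

    def forward(self, token_embs, context_idx):
        # Nodes: the [CLS] embedding (index 0) plus the context-word embeddings.
        nodes = torch.cat([token_embs[:1], token_embs[context_idx]], dim=0)
        src = torch.arange(1, nodes.size(0))
        # Star graph: bidirectional edges between each context node and [CLS].
        edge_index = torch.cat([
            torch.stack([src, torch.zeros_like(src)]),
            torch.stack([torch.zeros_like(src), src]),
        ], dim=1)
        h = self.gat(nodes, edge_index)
        return self.out(h[0])  # logits read off the updated [CLS] node

Here token_embs would be the encoder's last hidden states for one example and context_idx the token positions selected by the context masker.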

To run either script without wandb experiment tracking, execute:

WANDB_MODE=disabled python3 main.py ...
