Merge pull request #12 from HMJiangGatech/semi
Add Semi-supervised code + Fix bugs
cliang1453 authored Jun 2, 2021
2 parents 309a64d + a21b1ce commit 32f2698
Showing 17 changed files with 453 additions and 20 deletions.
9 changes: 7 additions & 2 deletions data_utils.py
@@ -110,6 +110,8 @@ def convert_examples_to_features(
hp_label_ids = []
for word, label, hp_label in zip(example.words, example.labels, example.hp_labels):
word_tokens = tokenizer.tokenize(word)
if(len(word_tokens) == 0):
continue
tokens.extend(word_tokens)
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
label_ids.extend([label] + [pad_token_label_id] * (len(word_tokens) - 1))
@@ -216,7 +218,7 @@ def convert_examples_to_features(
return features


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode, remove_labels=False):

if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
@@ -267,6 +269,9 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
all_full_label_ids = torch.tensor([f.full_label_ids for f in features], dtype=torch.long)
all_hp_label_ids = torch.tensor([f.hp_label_ids for f in features], dtype=torch.long)
if remove_labels:
all_full_label_ids.fill_(pad_token_label_id)
all_hp_label_ids.fill_(pad_token_label_id)
all_ids = torch.tensor([f for f in range(len(features))], dtype=torch.long)

dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_full_label_ids, all_hp_label_ids, all_ids)
@@ -356,4 +361,4 @@ def get_chunks(seq, tags):


if __name__ == '__main__':
save(args)
save(args)
1 change: 1 addition & 0 deletions dataset/.gitignore
@@ -0,0 +1 @@
*/cached_*
1 change: 1 addition & 0 deletions dataset/BC5CDR-chem/dev.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dataset/BC5CDR-chem/tag_to_id.json
@@ -0,0 +1 @@
{"O": 0, "B-Chemical": 1, "I-Chemical": 2}
1 change: 1 addition & 0 deletions dataset/BC5CDR-chem/test.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dataset/BC5CDR-chem/train.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dataset/BC5CDR-chem/weak.json

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions eval.py
@@ -34,8 +34,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, best, mode, pre
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

# multi-gpu evaluate
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
#if args.n_gpu > 1:
# model = torch.nn.DataParallel(model)
#model.to(args.device)

logger.info("***** Running evaluation %s *****", prefix)
if verbose:
13 changes: 12 additions & 1 deletion model_utils.py
@@ -88,7 +88,7 @@ def multi_source_label_refine(args, hp_labels, combined_labels, pred_labels, pad
label_mask = (pred_labels.max(dim=-1)[0]>_threshold)
if args.self_training_hp_label < 5:
label_mask = label_mask & (combined_labels!=pad_token_label_id)
elif 6 <= args.self_training_hp_label <= 7:
elif 6 <= args.self_training_hp_label < 7:
_threshold = args.self_training_hp_label%1
_confidence = pred_labels.max(dim=-1)[0]
for i in range(1,pred_labels.shape[0]):
@@ -97,6 +97,17 @@ def multi_source_label_refine(args, hp_labels, combined_labels, pred_labels, pad
_distantlabel = combined_labels[i,j]
pred_labels[i,j] *= 0
pred_labels[i,j,_distantlabel] = 1
elif 7 <= args.self_training_hp_label < 9:
_threshold = args.self_training_hp_label%1
label_mask = (pred_labels.max(dim=-1)[0]>_threshold)
if args.self_training_hp_label < 8:
label_mask = label_mask & (combined_labels!=pad_token_label_id)
# overwrite by hp_labels
for i in range(0,pred_labels.shape[2]):
_labeli = [0]*pred_labels.shape[2]
_labeli[i] = 1
_labeli = torch.tensor(_labeli).to(pred_labels)
pred_labels[hp_labels==i] = _labeli
else:
raise NotImplementedError('error')

53 changes: 53 additions & 0 deletions modeling_bert.py
@@ -0,0 +1,53 @@
from transformers import BertPreTrainedModel,BertForTokenClassification
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, KLDivLoss

class BERTForTokenClassification_v2(BertForTokenClassification):

def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, inputs_embeds=None, labels=None, label_mask=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

outputs = (logits,sequence_output) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:

# Only keep active parts of the loss
if attention_mask is not None or label_mask is not None:
active_loss = True
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
if label_mask is not None:
active_loss = active_loss & label_mask.view(-1)
active_logits = logits.view(-1, self.num_labels)[active_loss]


if labels.shape == logits.shape:
loss_fct = KLDivLoss()
if attention_mask is not None or label_mask is not None:
active_labels = labels.view(-1, self.num_labels)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits, labels)
else:
loss_fct = CrossEntropyLoss()
if attention_mask is not None or label_mask is not None:
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))


outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
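
The new head switches loss functions on the label shape: per-token hard label ids go through `CrossEntropyLoss`, while a full distribution over labels (e.g. soft pseudo-labels during self-training) goes through `KLDivLoss`. A minimal usage sketch of that behaviour — the checkpoint name, label count, and tensor shapes below are illustrative assumptions, not taken from this commit:

```python
import torch
from modeling_bert import BERTForTokenClassification_v2

# Hypothetical setup: 3 labels, as in dataset/BC5CDR-chem/tag_to_id.json
model = BERTForTokenClassification_v2.from_pretrained("bert-base-cased", num_labels=3)
input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

# Hard labels of shape (batch, seq_len) -> CrossEntropyLoss branch
hard_labels = torch.randint(0, 3, (2, 16))
ce_loss = model(input_ids, attention_mask=attention_mask, labels=hard_labels)[0]

# Soft labels of shape (batch, seq_len, num_labels) -> KLDivLoss branch
soft_labels = torch.softmax(torch.randn(2, 16, 3), dim=-1)
kl_loss = model(input_ids, attention_mask=attention_mask, labels=soft_labels)[0]
```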
45 changes: 31 additions & 14 deletions run_ner.py
@@ -52,6 +52,7 @@
)

from modeling_roberta import RobertaForTokenClassification_v2
from modeling_bert import BERTForTokenClassification_v2
from data_utils import load_and_cache_examples, get_labels
from model_utils import mt_update, get_mt_loss, opt_grad
from eval import evaluate
@@ -73,7 +74,8 @@
)

MODEL_CLASSES = {
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
"bert": (BertConfig, BERTForTokenClassification_v2, BertTokenizer),
"biobert": (BertConfig, BERTForTokenClassification_v2, BertTokenizer),
"roberta": (RobertaConfig, RobertaForTokenClassification_v2, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
@@ -162,15 +164,18 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to gobal_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
try:
# set global_step to global_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except:
logger.warning(f"Unable to recover training step from {args.model_name_or_path}")

tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
@@ -198,9 +203,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
#inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[4]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2] if args.model_type in ["bert", "xlnet"] else None
batch[2] if args.model_type in ["bert", "biobert", "xlnet"] else None
) # XLM and RoBERTa don"t use segment_ids

# import ipdb; ipdb.set_trace()
outputs = model(**inputs)
loss, logits, final_embeds = outputs[0], outputs[1], outputs[2] # model outputs are always tuple in pytorch-transformers (see doc)
mt_loss, vat_loss = 0, 0
@@ -236,7 +242,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):

if args.model_type in ["roberta", "camembert", "xlmroberta"]:
word_embed = model.roberta.get_input_embeddings()
elif args.model_type == "bert":
elif args.model_type in ["bert", "biobert"]:
word_embed = model.bert.get_input_embeddings()
elif args.model_type == "distilbert":
word_embed = model.distilbert.get_input_embeddings()
@@ -251,7 +257,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2] if args.model_type in ["bert", "xlnet"] else None
batch[2] if args.model_type in ["bert", "biobert", "xlnet"] else None
) # XLM and RoBERTa don"t use segment_ids

vat_outputs = model(**vat_inputs)
@@ -275,7 +281,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2] if args.model_type in ["bert", "xlnet"] else None
batch[2] if args.model_type in ["bert", "biobert", "xlnet"] else None
) # XLM and RoBERTa don"t use segment_ids

vat_outputs = model(**vat_inputs)
@@ -514,6 +520,12 @@ def main():
parser.add_argument('--vat_loss_type', default="logits", type=str, help="subject to measure model difference, choices = [embeds, logits(default)].")


# Use data from weak.json
parser.add_argument('--load_weak', action="store_true", help = 'Load data from weak.json.')
parser.add_argument('--remove_labels_from_weak', action="store_true", help = 'Use data from weak.json and remove its labels for semi-supervised learning.')
parser.add_argument('--rep_train_against_weak', type = int, default = 1, help = 'Upsample training data against weak data. Default: 1')


args = parser.parse_args()

if (
@@ -609,6 +621,11 @@ def main():
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
# import ipdb; ipdb.set_trace()
if args.load_weak:
weak_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="weak", remove_labels=args.remove_labels_from_weak)
train_dataset = torch.utils.data.ConcatDataset([train_dataset]*args.rep_train_against_weak + [weak_dataset,])

global_step, tr_loss, best_dev, best_test = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

13 changes: 12 additions & 1 deletion run_self_training_ner.py
@@ -51,6 +51,7 @@
get_linear_schedule_with_warmup,
)

from modeling_bert import BERTForTokenClassification_v2
from modeling_roberta import RobertaForTokenClassification_v2
from data_utils import load_and_cache_examples, get_labels
from model_utils import multi_source_label_refine, soft_frequency, mt_update, get_mt_loss, opt_grad
@@ -73,7 +74,7 @@
)

MODEL_CLASSES = {
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
"bert": (BertConfig, BERTForTokenClassification_v2, BertTokenizer),
"roberta": (RobertaConfig, RobertaForTokenClassification_v2, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
@@ -570,6 +571,11 @@ def main():
parser.add_argument('--self_training_hp_label', type = float, default = 0, help = 'use high precision label.')
parser.add_argument('--self_training_ensemble_label', type = int, default = 0, help = 'use ensemble label.')

# Use data from weak.json
parser.add_argument('--load_weak', action="store_true", help = 'Load data from weak.json.')
parser.add_argument('--remove_labels_from_weak', action="store_true", help = 'Use data from weak.json and remove its labels for semi-supervised learning.')
parser.add_argument('--rep_train_against_weak', type = int, default = 1, help = 'Upsample training data against weak data. Default: 1')

args = parser.parse_args()

if (
@@ -657,6 +663,11 @@ def main():
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
# import ipdb; ipdb.set_trace()
if args.load_weak:
weak_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="weak", remove_labels=args.remove_labels_from_weak)
train_dataset = torch.utils.data.ConcatDataset([train_dataset]*args.rep_train_against_weak + [weak_dataset,])

model, global_step, tr_loss, best_dev, best_test = train(args, train_dataset, model_class, config, tokenizer, labels, pad_token_label_id)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

65 changes: 65 additions & 0 deletions semi_script/README.md
@@ -0,0 +1,65 @@
## 1. Data Format

Strongly labeled:
`dev.txt`
`test.txt`
`train.txt`

Weakly labeled:
`weak.txt`

Transform these into JSON; see `dataset/BC5CDR-chem/turn.py` for an example (the sketch below illustrates the idea).
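
For orientation, a rough sketch of the kind of conversion `turn.py` performs. The CoNLL-style input (one `token tag` pair per line, blank lines between sentences) and the output keys (`str_words`, `tags`) are assumptions for illustration; the tag ids follow `tag_to_id.json`, and each dataset file in this commit is a single JSON line.

```python
import json

def conll_to_json(txt_path, json_path, tag_to_id):
    """Hypothetical converter: 'token tag' per line, blank lines separate sentences."""
    sentences, words, tags = [], [], []
    with open(txt_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # sentence boundary
                if words:
                    sentences.append({"str_words": words, "tags": tags})
                    words, tags = [], []
                continue
            token, tag = line.split()[0], line.split()[-1]
            words.append(token)
            tags.append(tag_to_id[tag])  # e.g. {"O": 0, "B-Chemical": 1, "I-Chemical": 2}
    if words:
        sentences.append({"str_words": words, "tags": tags})
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(sentences, f)  # whole split dumped as one JSON line
```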

## 2. Train Baseline

```bash
sh ./semi_script/bc5cdr_chem_basline.sh GPUTIDS
```
where `GPUTIDS` is a comma-separated list of GPU ids, e.g., `sh ./semi_script/bc5cdr_chem_basline.sh 0,1,2,3`

## 3. Semi-Supervised Learning

**Mean Teacher**
```bash
sh ./semi_script/bc5cdr_chem_mt.sh GPUTIDS
```
Additional parameters:
1. change `MODEL_NAME` to the baseline model
2. `--mt 1` for enabling the mean teacher
3. `--load_weak` and `--remove_labels_from_weak` for loading data from weak.json and removing its labels
4. `--rep_train_against_weak N` for upsampling strongly labeled data by `N` times (see the sketch after this list)
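
Item 4 maps onto a simple `ConcatDataset`-based upsampling in `run_ner.py` (see the diff above): the weak set is loaded from `weak.json`, its labels are optionally blanked to `pad_token_label_id`, and the strongly labeled set is repeated `N` times before concatenation. A self-contained toy illustration of the effect:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Toy stand-ins: 4 strongly labeled examples, 10 weakly labeled ones
strong = TensorDataset(torch.arange(4))
weak = TensorDataset(torch.arange(100, 110))

rep_train_against_weak = 3  # corresponds to --rep_train_against_weak 3
combined = ConcatDataset([strong] * rep_train_against_weak + [weak])

print(len(combined))  # 4*3 + 10 = 22 examples seen by the sampler each epoch
```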

Other parameters
```
parser.add_argument('--mt', type = int, default = 0, help = 'mean teacher.')
parser.add_argument('--mt_updatefreq', type=int, default=1, help = 'mean teacher update frequency')
parser.add_argument('--mt_class', type=str, default="kl", help = 'mean teacher class, choices:[smart, prob, logit, kl(default), distill].')
parser.add_argument('--mt_lambda', type=float, default=1, help= "trade off parameter of the consistent loss.")
parser.add_argument('--mt_rampup', type=int, default=300, help="rampup iteration.")
parser.add_argument('--mt_alpha1', default=0.99, type=float, help="moving average parameter of mean teacher (for the exponential moving average).")
parser.add_argument('--mt_alpha2', default=0.995, type=float, help="moving average parameter of mean teacher (for the exponential moving average).")
parser.add_argument('--mt_beta', default=10, type=float, help="coefficient of mt_loss term.")
parser.add_argument('--mt_avg', default="exponential", type=str, help="moving average method, choices:[exponential(default), simple, double_ema].")
parser.add_argument('--mt_loss_type', default="logits", type=str, help="subject to measure model difference, choices:[embeds, logits(default)].")
```
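
For intuition, `--mt_alpha1`/`--mt_alpha2` parameterize an exponential moving average of the student weights that the teacher follows. Below is a minimal sketch of the standard mean-teacher update under that assumption; it is not necessarily the exact `mt_update` in `model_utils.py`.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float, global_step: int) -> None:
    # Ramp the decay up early so the freshly initialized teacher can track the student
    alpha = min(1.0 - 1.0 / (global_step + 1), alpha)
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)

student = nn.Linear(4, 2)
teacher = copy.deepcopy(student)
for step in range(1, 301):  # --mt_rampup defaults to 300
    # ... the student would take an optimizer step here ...
    ema_update(teacher, student, alpha=0.99, global_step=step)  # alpha as in --mt_alpha1
```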


**VAT**
```bash
sh ./semi_script/bc5cdr_chem_vat.sh GPUTIDS
```
Additional parameters:
1. change `MODEL_NAME` to the baseline model
2. `--vat 1` for enabling virtual adversarial training
3. `--load_weak` and `--remove_labels_from_weak` for loading data from weak.json and removing its labels
4. `--rep_train_against_weak N` for upsampling strongly labeled data by `N` times

Other parameters
```
# virtual adversarial training
parser.add_argument('--vat', type = int, default = 0, help = 'virtual adversarial training.')
parser.add_argument('--vat_eps', type = float, default = 1e-3, help = 'perturbation size for virtual adversarial training.')
parser.add_argument('--vat_lambda', type = float, default = 1, help = 'trade off parameter for virtual adversarial training.')
parser.add_argument('--vat_beta', type = float, default = 1, help = 'coefficient of the virtual adversarial training loss term.')
parser.add_argument('--vat_loss_type', default="logits", type=str, help="subject to measure model difference, choices = [embeds, logits(default)].")
```
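
To make the `--vat_*` knobs concrete, here is a self-contained toy sketch of the perturbation-plus-consistency idea. The real loop in `run_ner.py` perturbs the transformer's input embeddings and supports several `--vat_loss_type` choices, so treat this as an illustration of the mechanism only; the toy classifier and the plain Gaussian noise are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for the token classifier: embeddings -> logits over 3 labels
torch.manual_seed(0)
embed = nn.Embedding(30, 8)
classifier = nn.Linear(8, 3)

input_ids = torch.randint(0, 30, (2, 5))
clean_logits = classifier(embed(input_ids))

# Perturb the input embeddings with noise of size vat_eps
# (SMART-style VAT refines the noise with a gradient step; plain noise shown for brevity)
vat_eps = 1e-3
noisy_logits = classifier(embed(input_ids) + vat_eps * torch.randn(2, 5, 8))

# Consistency term: KL between perturbed and clean predictions, scaled by vat_beta
vat_beta = 1.0
vat_loss = vat_beta * F.kl_div(
    F.log_softmax(noisy_logits, dim=-1),
    F.softmax(clean_logits, dim=-1),
    reduction="batchmean",
)
```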