-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from HMJiangGatech/semi
Add Semi-supervised code + Fix bugs
- Loading branch information
Showing
17 changed files
with
453 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*/cached_* |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"O": 0, "B-Chemical": 1, "I-Chemical": 2} |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from transformers import BertPreTrainedModel,BertForTokenClassification | ||
import torch | ||
import torch.nn as nn | ||
from torch.nn import CrossEntropyLoss, KLDivLoss | ||
|
||
class BERTForTokenClassification_v2(BertForTokenClassification):
    """BertForTokenClassification variant with soft labels and a label mask.

    Differences from the parent class:
      * ``label_mask`` — optional per-token mask, combined with
        ``attention_mask``, restricting which tokens contribute to the loss
        (useful when weakly/partially labeled tokens must be ignored).
      * soft labels — when ``labels.shape == logits.shape`` the labels are
        treated as per-token probability distributions and a KL-divergence
        loss is used instead of cross-entropy (distillation / mean teacher).
      * the dropped-out ``sequence_output`` is returned right after the
        logits so downstream consistency losses can reuse the embeddings.
    """

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None, label_mask=None):
        """Run BERT + token classifier, optionally computing a masked loss.

        Args:
            labels: hard label ids of shape ``(batch, seq_len)`` or a soft
                distribution of shape ``(batch, seq_len, num_labels)``.
            label_mask: optional per-token mask; tokens whose mask entry is
                falsy are excluded from the loss.

        Returns:
            Tuple ``((loss,)? logits, sequence_output, (hidden_states,
            attentions)?)`` — loss is present only when ``labels`` is given.
        """
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        # Expose sequence_output alongside logits (plus hidden states /
        # attentions when the backbone returns them).
        outputs = (logits, sequence_output) + outputs[2:]

        if labels is not None:
            # Only keep the active parts of the loss: tokens selected by the
            # attention mask AND (if given) the label mask.
            active_loss = None
            if attention_mask is not None or label_mask is not None:
                active_loss = True
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                if label_mask is not None:
                    active_loss = active_loss & label_mask.view(-1)
                active_logits = logits.view(-1, self.num_labels)[active_loss]

            if labels.shape == logits.shape:
                # Soft labels. KLDivLoss expects *log*-probabilities as its
                # input and probabilities as its target, so the logits must
                # be log-softmaxed first (the previous code passed raw
                # logits, which computes an incorrect divergence).
                loss_fct = KLDivLoss()
                if active_loss is not None:
                    active_labels = labels.view(-1, self.num_labels)[active_loss]
                    loss = loss_fct(torch.log_softmax(active_logits, dim=-1), active_labels)
                else:
                    loss = loss_fct(torch.log_softmax(logits, dim=-1), labels)
            else:
                # Hard labels: standard token-level cross-entropy.
                loss_fct = CrossEntropyLoss()
                if active_loss is not None:
                    active_labels = labels.view(-1)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            outputs = (loss,) + outputs

        return outputs  # (loss), scores, sequence_output, (hidden_states), (attentions)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
## 1. Data Format | ||
|
||
Strongly labeled: | ||
`dev.txt` | ||
`test.txt` | ||
`train.txt` | ||
|
||
Weakly labeled: | ||
`weak.txt` | ||
|
||
Transform into json, e.g., see `dataset/BC5CDR-chem/turn.py` | ||
|
||
## 2. Train Baseline | ||
|
||
```bash | ||
sh ./semi_script/bc5cdr_chem_basline.sh GPUTIDS | ||
``` | ||
where GPUTIDS are the ids of gpus, e.g., `sh ./semi_script/bc5cdr_chem_basline.sh 0,1,2,3` | ||
|
||
## 3. Semi Supervised Learning | ||
|
||
**Mean Teacher** | ||
```bash | ||
sh ./semi_script/bc5cdr_chem_mt.sh GPUTIDS | ||
``` | ||
additional parameters: | ||
1. change `MODEL_NAME` to the baseline model | ||
2. `--mt 1` for enabling mean teacher | ||
3. `--load_weak` and `--remove_labels_from_weak` for loading data from weak.json and removing their labels. | ||
4. `--rep_train_against_weak N` for upsampling strongly labeled data by `N` times. | ||
|
||
Other parameters | ||
``` | ||
parser.add_argument('--mt', type = int, default = 0, help = 'mean teacher.') | ||
parser.add_argument('--mt_updatefreq', type=int, default=1, help = 'mean teacher update frequency') | ||
parser.add_argument('--mt_class', type=str, default="kl", help = 'mean teacher class, choices:[smart, prob, logit, kl(default), distill].') | ||
parser.add_argument('--mt_lambda', type=float, default=1, help= "trade off parameter of the consistent loss.") | ||
parser.add_argument('--mt_rampup', type=int, default=300, help="rampup iteration.") | ||
parser.add_argument('--mt_alpha1', default=0.99, type=float, help="moving average parameter of mean teacher (for the exponential moving average).") | ||
parser.add_argument('--mt_alpha2', default=0.995, type=float, help="moving average parameter of mean teacher (for the exponential moving average).") | ||
parser.add_argument('--mt_beta', default=10, type=float, help="coefficient of mt_loss term.") | ||
parser.add_argument('--mt_avg', default="exponential", type=str, help="moving average method, choices:[exponentail(default), simple, double_ema].") | ||
parser.add_argument('--mt_loss_type', default="logits", type=str, help="subject to measure model difference, choices:[embeds, logits(default)].") | ||
``` | ||
|
||
|
||
**VAT** | ||
```bash | ||
sh ./semi_script/bc5cdr_chem_vat.sh GPUTIDS | ||
``` | ||
additional parameters: | ||
1. change `MODEL_NAME` to the baseline model | ||
2. `--vat 1` for enabling virtual adversarial training | ||
3. `--load_weak` and `--remove_labels_from_weak` for loading data from weak.json and removing their labels. | ||
4. `--rep_train_against_weak N` for upsampling strongly labeled data by `N` times. | ||
|
||
Other parameters | ||
``` | ||
# virtual adversarial training | ||
parser.add_argument('--vat', type = int, default = 0, help = 'virtual adversarial training.') | ||
parser.add_argument('--vat_eps', type = float, default = 1e-3, help = 'perturbation size for virtual adversarial training.') | ||
parser.add_argument('--vat_lambda', type = float, default = 1, help = 'trade off parameter for virtual adversarial training.') | ||
parser.add_argument('--vat_beta', type = float, default = 1, help = 'coefficient of the virtual adversarial training loss term.') | ||
parser.add_argument('--vat_loss_type', default="logits", type=str, help="subject to measure model difference, choices = [embeds, logits(default)].") | ||
``` |
Oops, something went wrong.