
Commit f07d8a9

Author: zhangchen76
Commit message: Major update.
Parent: 263c55c

23 files changed (+3798 −6 lines)

.DS_Store

6 KB, binary file not shown.

README.md

+91 −6 lines

## StarK <img src="./ironman.png" width="22" height="22" alt="stark" align=center/>

This repository contains the code for the EMNLP 2022 paper [Sparse Teachers Can Be Dense with Knowledge](https://arxiv.org/abs/2210.03923).

**************************** **Updates** ****************************

<!-- Thanks for your interest in our repo! -->

* 10/19/22: We released our paper, code, and data. Check it out!

## Quick Links

- [Overview](#overview)
- [Getting Started](#getting-started)
- [Requirements](#requirements)
- [GLUE Data](#glue-data)
- [Training & Evaluation](#training--evaluation)
- [Bugs or Questions?](#bugs-or-questions)
- [Citation](#citation)

## Overview

Recent advances in distilling pretrained language models have discovered that, besides the expressiveness of knowledge, student-friendliness should also be taken into consideration to realize a truly knowledgeable teacher. Based on a pilot study, we find that over-parameterized teachers can produce expressive yet student-unfriendly knowledge and are thus limited in overall knowledgeableness. To remove the parameters that result in student-unfriendliness, we propose a sparse teacher trick guided by an overall knowledgeable score for each teacher parameter. The knowledgeable score is essentially an interpolation of the expressiveness and student-friendliness scores, aiming to ensure that expressive parameters are retained while student-unfriendly ones are removed. Extensive experiments on the GLUE benchmark show that the proposed sparse teachers can be dense with knowledge and lead to students with compelling performance in comparison with a series of competitive baselines.

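As a rough, hypothetical illustration of that interpolation (the exact scoring in the paper may differ; `lam` mirrors the `--lam` tradeoff argument described below):

```python
# Hypothetical sketch: per-parameter knowledgeable score as a convex combination
# of expressiveness and student-friendliness. The exact formulation in the paper
# may differ, and which score `lam` weights is an assumption here.
def knowledgeable_score(expressiveness, friendliness, lam=0.5):
    return lam * expressiveness + (1.0 - lam) * friendliness
```
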
## Getting Started

### Requirements

- PyTorch
- NumPy
- Transformers

### GLUE Data

Download the GLUE data through [this link](https://github.com/nyu-mll/jiant/blob/master/scripts/download_glue_data.py) and put it in the corresponding directory. For example, the MRPC dataset should be placed in `datasets/mrpc`. A quick layout check is sketched below.

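A minimal check of the expected layout, assuming every task's data sits under `datasets/<task_name>` in the same way as the MRPC example above:

```python
# Minimal layout check; it assumes each GLUE task lives under datasets/<task_name>,
# following the datasets/mrpc example above.
import os

for task in ["rte", "mrpc", "stsb", "sst2", "qnli", "qqp", "mnli"]:
    path = os.path.join("datasets", task)
    print(path, "ok" if os.path.isdir(path) else "missing")
```
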
### Training & Evaluation

Training and evaluation are run through several scripts. We provide example scripts as follows.

**Finetuning**

We provide an example of finetuning `bert-base-uncased` on RTE in `scripts/run_finetuning_rte.sh`. We explain some important arguments in the following (a hypothetical invocation is sketched after the list):
* `--model_type`: Variant to use; should be `ft` in this case.
* `--model_path`: Pretrained language model to start with; should be `bert-base-uncased` in this case, but can be any other model you like.
* `--task_name`: Task to use; should be chosen from `rte`, `mrpc`, `stsb`, `sst2`, `qnli`, `qqp`, `mnli`, and `mnlimm`.
* `--data_type`: Input format to use; defaults to `combined`.

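A minimal sketch of how these flags might be combined; the `run.py` entry point is an assumption, not confirmed by the repository, and `scripts/run_finetuning_rte.sh` contains the actual command:

```python
# Hypothetical sketch only: the run.py entry point is assumed; the flags come
# from the README above. See scripts/run_finetuning_rte.sh for the real command.
import subprocess

subprocess.run([
    "python", "run.py",
    "--model_type", "ft",
    "--model_path", "bert-base-uncased",
    "--task_name", "rte",
    "--data_type", "combined",
], check=True)
```
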
**Pruning**

We provide an example of pruning a finetuned checkpoint on RTE in `scripts/run_pruning_rte.sh`. The arguments should be self-explanatory.

**Distillation**

We provide an example of distilling a finetuned teacher into a layer-dropped or parameter-pruned student on RTE in `scripts/run_distillation_rte.sh`. We explain some important arguments in the following (a hypothetical invocation is sketched after the list):
* `--model_type`: Variant to use; should be `kd` in this case.
* `--teacher_model_path`: Teacher model to use; should be the path to the finetuned teacher checkpoint.
* `--student_model_path`: Student model to initialize; should be the path to the pruned or finetuned teacher checkpoint, depending on how you would like to initialize the student.
* `--student_sparsity`: Student sparsity; should be set if you would like a parameter-pruned student, e.g., 70. Otherwise, leave this argument blank.
* `--student_layer`: Student layer count; should be set if you would like a layer-dropped student, e.g., 4.

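A minimal sketch for a 4-layer (layer-dropped) student; as above, the `run.py` entry point and the checkpoint paths are assumptions, and `scripts/run_distillation_rte.sh` contains the actual command:

```python
# Hypothetical sketch only: entry point and checkpoint paths are assumptions;
# the flags come from the README above.
import subprocess

subprocess.run([
    "python", "run.py",
    "--model_type", "kd",
    "--teacher_model_path", "ckpts/rte_teacher_ft",  # assumed path to the finetuned teacher
    "--student_model_path", "ckpts/rte_teacher_ft",  # initialize the student from the finetuned teacher
    "--student_layer", "4",                          # layer-dropped student; set --student_sparsity instead for a pruned one
], check=True)
```
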
**Teacher Sparsification**

We provide an example of sparsifying the teacher based on the student on RTE in `scripts/run_sparsification_rte.sh`. We explain some important arguments in the following (a hypothetical sweep over `--lam` is sketched after the list):
* `--model_type`: Variant to use; should be `kd` in this case.
* `--teacher_model_path`: Teacher model to use; should be the path to the finetuned teacher checkpoint.
* `--student_model_path`: Student model to use; should be the path to the distilled student checkpoint.
* `--student_sparsity`: Student sparsity; should be set if you would like a parameter-pruned student, e.g., 70. Otherwise, leave this argument blank.
* `--student_layer`: Student layer count; should be set if you would like a layer-dropped student, e.g., 4.
* `--lam`: The knowledgeableness tradeoff term that balances expressiveness and student-friendliness.

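Since `--lam` trades off the two scores, one natural (hypothetical) workflow is to sweep a few values; again, the `run.py` entry point, checkpoint paths, and the example `lam` values are assumptions:

```python
# Hypothetical sweep over the expressiveness/student-friendliness tradeoff;
# entry point, paths, and lam values are assumptions.
import subprocess

for lam in ["0.1", "0.5", "0.9"]:  # example values only
    subprocess.run([
        "python", "run.py",
        "--model_type", "kd",
        "--teacher_model_path", "ckpts/rte_teacher_ft",  # assumed path
        "--student_model_path", "ckpts/rte_student_kd",  # assumed path to the distilled student
        "--student_layer", "4",
        "--lam", lam,
    ], check=True)
```
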
**Rewinding**

We provide an example of rewinding the student on RTE in `scripts/run_rewinding_rte.sh`. We explain some important arguments in the following:
* `--model_type`: Variant to use; should be `kd` in this case.
* `--teacher_model_path`: Teacher model to use; should be the path to the sparsified teacher checkpoint.
* `--student_model_path`: Student model to initialize; should be the path to the pruned or finetuned teacher checkpoint, depending on how you would like to initialize the student.
* `--student_sparsity`: Student sparsity; should be set if you would like a parameter-pruned student, e.g., 70. Otherwise, leave this argument blank.
* `--student_layer`: Student layer count; should be set if you would like a layer-dropped student, e.g., 4.
* `--lam`: The knowledgeableness tradeoff term that balances expressiveness and student-friendliness. Here, it is only used for folder names.

## Bugs or Questions?

If you have any questions related to the code or the paper, feel free to email Chen (`[email protected]`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and more quickly!

## Citation

Please cite our paper if you use the code in your work:

```bibtex
@inproceedings{yang2022sparse,
  title={Sparse Teachers Can Be Dense with Knowledge},
  author={Yang, Yi and Zhang, Chen and Song, Dawei},
  booktitle={EMNLP},
  year={2022}
}
```

File renamed without changes.

data/__init__.py

+119 lines (new file)
# -*- coding: utf-8 -*-

"""
DataReader -> reads RawData from tsv, json, etc., into a general form of Examples. [task-specific]
DataPipeline -> converts Examples into specific forms and collates Examples into Batches. [model-specific]
"""

import math

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

from data.readers import (
    SST2Reader,
    MRPCReader,
    STSBReader,
    QQPReader,
    MNLIReader,
    MNLIMMReader,
    QNLIReader,
    RTEReader,
)
from data.pipelines import (
    CombinedDataPipeline,
)


# Task name -> task-specific reader class.
READER_CLASS = {
    "sst2": SST2Reader,
    "mrpc": MRPCReader,
    "stsb": STSBReader,
    "qqp": QQPReader,
    "mnli": MNLIReader,
    "mnlimm": MNLIMMReader,
    "qnli": QNLIReader,
    "rte": RTEReader,
}

def get_reader_class(task_name):
    return READER_CLASS[task_name]


# Input format -> model-specific data pipeline class.
PIPELINE_CLASS = {
    "combined": CombinedDataPipeline,
}

def get_pipeline_class(data_type):
    return PIPELINE_CLASS[data_type]


class Dataset(IterableDataset):
    """Iterable dataset over an in-memory list of instances, with optional shuffling."""

    def __init__(self, data, shuffle=True):
        super().__init__()
        self.data = data
        self.shuffle = shuffle
        self.num_instances = len(self.data)

    def __len__(self):
        return self.num_instances

    def __iter__(self):
        if self.shuffle:
            # Draw a fresh random permutation of the instances on every pass.
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
            for idx in torch.randperm(self.num_instances, generator=generator).tolist():
                yield self.data[idx]
        else:
            for idx in range(self.num_instances):
                yield self.data[idx]


class DistributedDataset(IterableDataset):
    """Iterable dataset that shards an in-memory list of instances across replicas."""

    def __init__(self, data, num_replicas=None, rank=None, shuffle=True):
        super().__init__()
        self.data = data
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        # Do ceiling to make the data evenly divisible among devices.
        self.num_instances = math.ceil(len(self.data) / self.num_replicas)
        self.total_num_instances = self.num_instances * self.num_replicas

    def __len__(self):
        return self.num_instances

    def __iter__(self):
        if self.shuffle:
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
            indices = torch.randperm(len(self.data), generator=generator).tolist()
        else:
            indices = list(range(len(self.data)))

        # Pad the index list so it divides evenly among the replicas.
        num_padding_instances = self.total_num_instances - len(indices)
        if num_padding_instances <= len(indices):
            indices += indices[:num_padding_instances]
        else:
            indices += (indices * math.ceil(num_padding_instances / len(indices)))[:num_padding_instances]

        assert len(indices) == self.total_num_instances

        # Subsample the indices assigned to this replica.
        indices = indices[self.rank:self.total_num_instances:self.num_replicas]
        assert len(indices) == self.num_instances

        for idx in indices:
            yield self.data[idx]
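
For context, a minimal usage sketch of this module; the reader and pipeline constructor calls are commented out because their signatures are not shown in this commit, and the placeholder instances and batch size are assumptions:

```python
# Hypothetical usage sketch: only get_reader_class, get_pipeline_class, and Dataset
# come from this module; the placeholder instance and batch size are assumptions.
from torch.utils.data import DataLoader

from data import Dataset, get_pipeline_class, get_reader_class

reader_cls = get_reader_class("rte")            # -> RTEReader
pipeline_cls = get_pipeline_class("combined")   # -> CombinedDataPipeline

# examples = reader_cls(...).read(...)          # reader signature not shown here
# instances = pipeline_cls(...).convert(...)    # pipeline signature not shown here
instances = [{"input_ids": [101, 102], "label": 0}]  # placeholder instance

loader = DataLoader(Dataset(instances, shuffle=True), batch_size=32)
for batch in loader:
    pass  # feed the batch to the model
```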
