Skip to content

Commit f3b7270

Browse files
committed
update prompt-tuning codes for nlp tasks
1 parent 1bf177c commit f3b7270

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+63
-63
lines changed

.github/workflows/test_prompt.yml renamed to .github/workflows/test_fedsp.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: UnitTests for Prompt Tuning
1+
name: UnitTests for FedSP
22

33
on:
44
schedule:
@@ -44,7 +44,7 @@ jobs:
4444
- name: Test Prompt Tuning
4545
run: |
4646
python ../../main.py \
47-
--cfg federatedscope/nlp/prompt_tuning/baseline/config_alter_train.yaml \
47+
--cfg federatedscope/nlp/fedsp/baseline/config_alter_train.yaml \
4848
data.dataset_name arc_challenge \
4949
data.batch_size 1 \
5050
data.max_seq_len 32 \
@@ -53,8 +53,8 @@ jobs:
5353
federate.total_round_num 2 \
5454
federate.make_global_train True \
5555
federate.pl_init_kd True \
56-
federate.pl_kd_cfg_file federatedscope/nlp/prompt_tuning/baseline/config_init_kd_test.yaml \
57-
federate.pl_global_cfg_file federatedscope/nlp/prompt_tuning/baseline/config_global.2.yaml \
56+
federate.pl_kd_cfg_file federatedscope/nlp/fedsp/baseline/config_init_kd_test.yaml \
57+
federate.pl_global_cfg_file federatedscope/nlp/fedsp/baseline/config_global.2.yaml \
5858
model.use_fp16 True \
5959
model.model_type facebook/opt-1.3b \
6060
model.use_prefix_prj False \

federatedscope/core/auxiliaries/data_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
'dblp_org', 'csbm.*?', 'fb15k-237', 'wn18', 'adult', 'abalone',
3232
'credit', 'blog'
3333
], # Dummy for FL dataset
34-
'RawDataTranslator': ['hetero_nlp_tasks', 'pl_data'],
34+
'RawDataTranslator': ['hetero_nlp_tasks', 'fedsp_data'],
3535
}
3636
DATA_TRANS_MAP = RegexInverseMap(TRANS_DATA_MAP, None)
3737

federatedscope/core/auxiliaries/model_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ def get_model(model_config, local_data=None, backend='torch', role='client'):
197197
elif model_config.type.lower() in ['atc_model']:
198198
from federatedscope.nlp.hetero_tasks.model import ATCModel
199199
model = ATCModel(model_config)
200-
elif model_config.type.lower() in ['pl_model']:
201-
from federatedscope.nlp.prompt_tuning.model import PLModel
202-
model = PLModel(model_config, role=role)
200+
elif model_config.type.lower() in ['fedsp_model']:
201+
from federatedscope.nlp.fedsp.model import FedSPModel
202+
model = FedSPModel(model_config, role=role)
203203
else:
204204
raise ValueError('Model {} is not provided'.format(model_config.type))
205205

federatedscope/core/auxiliaries/trainer_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"cltrainer": "CLTrainer",
3030
"lptrainer": "LPTrainer",
3131
"atc_trainer": "ATCTrainer",
32-
"pl_trainer": "PLTrainer",
32+
"fedsp_trainer": "FedSPTrainer",
3333
}
3434

3535

@@ -158,7 +158,7 @@ def get_trainer(model=None,
158158
dict_path = "federatedscope.mf.trainer.trainer"
159159
elif config.trainer.type.lower() in ['atc_trainer']:
160160
dict_path = "federatedscope.nlp.hetero_tasks.trainer"
161-
elif config.trainer.type.lower() in ["pl_trainer"]:
161+
elif config.trainer.type.lower() in ["fedsp_trainer"]:
162162
dict_path = "federatedscope.nlp.prompt_tuning.trainer"
163163
else:
164164
raise ValueError

federatedscope/core/auxiliaries/worker_builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def get_client_cls(cfg):
7676
from federatedscope.nlp.hetero_tasks.worker import ATCClient
7777
return ATCClient
7878

79-
if cfg.trainer.type.lower() == 'pl_trainer':
80-
from federatedscope.nlp.prompt_tuning.worker import PLClient
81-
return PLClient
79+
if cfg.trainer.type.lower() == 'fedsp_trainer':
80+
from federatedscope.nlp.fedsp.worker import FedSPClient
81+
return FedSPClient
8282

8383
if cfg.federate.method.lower() in constants.CLIENTS_TYPE:
8484
client_type = constants.CLIENTS_TYPE[cfg.federate.method.lower()]
@@ -199,9 +199,9 @@ def get_server_cls(cfg):
199199
from federatedscope.nlp.hetero_tasks.worker import ATCServer
200200
return ATCServer
201201

202-
if cfg.trainer.type.lower() == 'pl_trainer':
203-
from federatedscope.nlp.prompt_tuning.worker import PLServer
204-
return PLServer
202+
if cfg.trainer.type.lower() == 'fedsp_trainer':
203+
from federatedscope.nlp.fedsp.worker import FedSPServer
204+
return FedSPServer
205205

206206
if cfg.federate.method.lower() in constants.SERVER_TYPE:
207207
server_type = constants.SERVER_TYPE[cfg.federate.method.lower()]

federatedscope/core/data/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ def load_dataset(config, client_cfgs=None):
9494
from federatedscope.nlp.hetero_tasks.dataloader import \
9595
load_heteroNLP_data
9696
dataset, modified_config = load_heteroNLP_data(config, client_cfgs)
97-
elif 'pl_data' in config.data.type.lower():
98-
from federatedscope.nlp.prompt_tuning.dataloader import load_pl_data
99-
dataset, modified_config = load_pl_data(config)
97+
elif 'fedsp_data' in config.data.type.lower():
98+
from federatedscope.nlp.fedsp.dataloader import load_fedsp_data
99+
dataset, modified_config = load_fedsp_data(config)
100100
elif '@' in config.data.type.lower():
101101
from federatedscope.core.data.utils import load_external_data
102102
dataset, modified_config = load_external_data(config)

federatedscope/nlp/prompt_tuning/README.md renamed to federatedscope/nlp/fedsp/README.md

Lines changed: 4 additions & 4 deletions
File renamed without changes.

federatedscope/nlp/prompt_tuning/baseline/config_alter_train.yaml renamed to federatedscope/nlp/fedsp/baseline/config_alter_train.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@ federate:
1111
pl_alter_train: True
1212
pl_save_to: ckpt
1313
data:
14-
type: pl_data
14+
type: fedsp_data
1515
batch_size: 16
1616
max_seq_len: 1024
1717
num_workers: 0
1818
model:
19-
type: pl_model
19+
type: fedsp_model
2020
model_type: facebook/opt-1.3b
2121
num_server_layers: 24
2222
prefix_hidden_size: 512
2323
server_freeze_param: ['model']
2424
alter_model_param: ['model']
2525
alter_prompt_param: ['prefix_encoder']
2626
trainer:
27-
type: pl_trainer
27+
type: fedsp_trainer
2828
train:
2929
batch_or_epoch: batch
3030
local_update_steps: 10

federatedscope/nlp/prompt_tuning/baseline/config_freeze.yaml renamed to federatedscope/nlp/fedsp/baseline/config_freeze.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ federate:
1010
make_global_eval: True
1111
pl_save_to: ckpt
1212
data:
13-
type: pl_data
13+
type: fedsp_data
1414
batch_size: 16
1515
max_seq_len: 1024
1616
num_workers: 0
1717
model:
18-
type: pl_model
18+
type: fedsp_model
1919
model_type: facebook/opt-1.3b
2020
num_server_layers: 24
2121
server_freeze_param: ['model']
2222
client_freeze_param: ['model']
2323
trainer:
24-
type: pl_trainer
24+
type: fedsp_trainer
2525
train:
2626
batch_or_epoch: batch
2727
local_update_steps: 10

0 commit comments

Comments
 (0)