Commit 61894a9 (1 parent: d4108fb)
Showing 13 changed files with 1,103 additions and 0 deletions.
100 changes: 100 additions & 0 deletions
openfl-tutorials/experimental/Federeated_Pytorch_LLM_Horovod.py
@@ -0,0 +1,100 @@
import os
import subprocess
import sys

import openfl.native as fx
import openfl.interface.workspace as workspace

# The torch_llm workspace provides the task runner and data loader used below.
sys.path.append("openfl/openfl-workspace/torch_llm")
from src.pt_model import LLMTaskRunner
from src.ptglue_inmemory import GlueMrpcFederatedDataLoader

WORKSPACE_PREFIX = os.path.join(os.path.expanduser("~"), ".local", "workspace")

# Run openfl-tutorials/experimental/setup_env.sh in your venv to set up the Horovod dependencies.
# Set up the venv on each node.
# Create the directory ~/.local/workspace on each node.
# Horovod requires passwordless SSH login; you can learn how to set it up here:
# http://www.linuxproblem.org/art_9.html

# You should set the following ENVIRONMENT VARIABLES for Horovod:
# OPENFL_HOROVOD_DEMO_NP=STR with the number of processes to run, e.g. "4"
# OPENFL_HOROVOD_DEMO_NICS=STR with the common network interface name to use on all nodes, e.g. "en01"
# OPENFL_HOROVOD_DEMO_LOCALHOSTIP=STR with the IP address of the local node, e.g. "ip1"
# OPENFL_HOROVOD_DEMO_HOSTS=STR with the IP address of each node and its number of slots, e.g. "ip1:2,ip2:2"

NP = os.environ.get("OPENFL_HOROVOD_DEMO_NP", "4")
NETWORK_INTERFACES = os.environ.get("OPENFL_HOROVOD_DEMO_NICS", "localhost")
LOCAL_HOST = os.environ.get("OPENFL_HOROVOD_DEMO_LOCALHOSTIP", "localhost")
HOSTS = os.environ.get("OPENFL_HOROVOD_DEMO_HOSTS", "localhost:4")

print("NP:", NP)
print("NETWORK_INTERFACES:", NETWORK_INTERFACES)
print("LOCAL_HOST:", LOCAL_HOST)
print("HOSTS:", HOSTS)


def propagate_workspace():
    """Copy the local workspace to every remote host listed in HOSTS."""
    remote_hosts = [
        i.split(":")[0] for i in HOSTS.split(",") if i.split(":")[0] != LOCAL_HOST
    ]
    for rem_host in remote_hosts:
        # Copy the generated workspace into ~/.local on the remote host.
        cmd = [
            "scp",
            "-r",
            WORKSPACE_PREFIX,
            rem_host + ":" + WORKSPACE_PREFIX.replace("workspace", ""),
        ]
        print(cmd)
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            raise RuntimeError(result.stderr)


def main():
    print(WORKSPACE_PREFIX)
    log_level = "INFO"
    log_file = None
    workspace.create(WORKSPACE_PREFIX, "torch_llm")
    os.chdir(WORKSPACE_PREFIX)
    sys.path.append(WORKSPACE_PREFIX)
    propagate_workspace()
    fx.setup_logging(level=log_level, log_file=log_file)
    num_collaborators = 1

    collaborator_models = [
        LLMTaskRunner(
            data_loader=GlueMrpcFederatedDataLoader(
                data_slice, 32, collaborator_count=num_collaborators
            )
        )
        for data_slice in range(num_collaborators)
    ]
    collaborators = {
        "one": collaborator_models[0],
    }

    # Collaborator one's data
    for i, model in enumerate(collaborator_models):
        print(
            f"Collaborator {i}'s training data size: {len(model.data_loader.train_set)}"
        )
        print(
            f"Collaborator {i}'s validation data size: {len(model.data_loader.valid_set)}\n"
        )
    final_fl_model = fx.run_experiment(
        collaborators,
        {"aggregator.settings.rounds_to_train": 5, "tasks.train.kwargs.epochs": 1},
    )


if __name__ == "__main__":
    main()
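For a quick local try-out, here is a minimal sketch (my suggestion, not part of the commit) of setting the four Horovod demo variables described in the comments above and then running the tutorial on a single node; the values are illustrative only.

# Hypothetical wrapper: set illustrative single-node values, then run the tutorial
# script as if it were invoked directly. Replace the values with your own topology.
import os
import runpy

os.environ.setdefault("OPENFL_HOROVOD_DEMO_NP", "2")               # number of Horovod processes
os.environ.setdefault("OPENFL_HOROVOD_DEMO_NICS", "lo")            # network interface shared by all nodes
os.environ.setdefault("OPENFL_HOROVOD_DEMO_LOCALHOSTIP", "localhost")
os.environ.setdefault("OPENFL_HOROVOD_DEMO_HOSTS", "localhost:2")  # host:slots pairs

runpy.run_path(
    "openfl-tutorials/experimental/Federeated_Pytorch_LLM_Horovod.py",
    run_name="__main__",
)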
@@ -0,0 +1,11 @@
sentencepiece
accelerate
jupyter
huggingface_hub
peft
transformers[torch]
datasets
evaluate
seqeval
torch
torchvision
@@ -0,0 +1,7 @@
pip install -U pip --no-cache
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1
pip install -r requirments_horovod.txt --no-cache
HOROVOD_WITH_PYTORCH=1 HOROVOD_WITHOUT_MPI=1 pip install horovod[pytorch] --no-cache
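After running the setup script, a quick sanity check (an addition of mine, not part of the commit) that Horovod was built with PyTorch support and can initialize:

# Verifies the horovod[pytorch] build produced by the setup script above.
import horovod.torch as hvd

hvd.init()
print(f"Horovod rank {hvd.rank()} of {hvd.size()}")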
@@ -0,0 +1,2 @@
current_plan_name: default
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:
@@ -0,0 +1,9 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# All keys under 'collaborators' correspond to a specific collaborator name; the corresponding dictionary has data_name, data_path pairs.
# Note that in the mnist case we do not store the data locally, and the data_path is used to pass an integer that helps the data object
# construct the shard of the mnist dataset to be used for this collaborator.

# collaborator_name,data_directory_path
one,1
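To make the comment above concrete, here is a hypothetical sketch of how an integer data_path could be turned into a shard of example indices for a collaborator; the helper name and its 1-based convention are assumptions for illustration, not OpenFL's actual splitter.

def shard_indices(num_examples, collaborator_count, data_path):
    # data_path is the integer from data.yaml (1 for collaborator "one"),
    # treated here as a 1-based shard number (an assumption).
    return list(range(data_path - 1, num_examples, collaborator_count))

# e.g. with 8 examples and 2 collaborators, collaborator "one" would get [0, 2, 4, 6]
print(shard_indices(8, 2, 1))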
@@ -0,0 +1,2 @@
../../workspace/plan/defaults
@@ -0,0 +1,45 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.Aggregator
  settings :
    init_state_path : save/torch_llm_init.pbuf
    best_state_path : save/torch_llm_best.pbuf
    last_state_path : save/torch_llm_last.pbuf
    rounds_to_train : 5
    log_metric_callback :
      template : src.glue_utils.write_metric

collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  defaults : plan/defaults/data_loader.yaml
  template : src.ptglue_inmemory.GlueMrpcFederatedDataLoader
  settings :
    collaborator_count : 2
    data_group_name : mnist
    batch_size : 256

task_runner :
  defaults : plan/defaults/task_runner.yaml
  template : src.pt_model.LLMTaskRunner

network :
  defaults : plan/defaults/network.yaml

assigner :
  defaults : plan/defaults/assigner.yaml

tasks :
  defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
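The tutorial script shown earlier overrides parts of this plan at run time: the dictionary it passes to fx.run_experiment uses flattened, dotted keys that address fields of this YAML (or its defaults files).

# Override dictionary from the tutorial script; each dotted path maps onto plan.yaml.
overrides = {
    "aggregator.settings.rounds_to_train": 5,  # aggregator -> settings -> rounds_to_train
    "tasks.train.kwargs.epochs": 1,            # epochs of the train task in the tasks defaults file
}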
@@ -0,0 +1,15 @@
torch
tensorboard
wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability
sentencepiece
accelerate
jupyter
huggingface_hub
peft
transformers
datasets
evaluate
seqeval
horovod
torch
torchvision
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
@@ -0,0 +1,79 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import horovod.torch as hvd
import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          DataCollatorWithPadding, get_scheduler)
from openfl.utilities.data_splitters import EqualNumPyDataSplitter

logger = getLogger(__name__)

writer = None


def get_writer():
    """Create the global writer object."""
    global writer
    if not writer:
        writer = SummaryWriter('./logs/llm', flush_secs=5)


def write_metric(node_name, task_name, metric_name, metric, round_number):
    """Write metric callback."""
    get_writer()
    writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)


def get_glue_mrpc_dataset(tokenizer):
    """Load and tokenize the GLUE MRPC dataset, returning a collator and the tokenized splits."""
    dataset = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        outputs = tokenizer(
            examples["sentence1"],
            examples["sentence2"],
            truncation=True,
            max_length=None,
        )
        return outputs

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")
    return data_collator, tokenized_datasets


class GlueMrpc(Dataset):
    """
    Has 5.8k pairs of sentences with annotations of whether the two sentences are equivalent.
    """

    def get_shape(self):
        if not hasattr(self, "saved_shape"):
            self.saved_shape = max([len(i) for i in self.data["input_ids"]])
        return self.saved_shape


def get_dataset(base_model_name="roberta-base", padding_side="right"):
    """Build the tokenizer, tokenized train/valid sets, and data collator for the base model."""
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name, padding_side=padding_side
    )
    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    data_collator, tokenized_datasets = get_glue_mrpc_dataset(tokenizer)

    train_set = GlueMrpc.from_dict(tokenized_datasets["train"].to_dict())
    valid_set = GlueMrpc.from_dict(tokenized_datasets["test"].to_dict())
    return train_set, valid_set, data_collator
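For reference, the write_metric callback defined above is wired into the federation through the plan's log_metric_callback entry (src.glue_utils.write_metric) and produces TensorBoard scalars under ./logs/llm. A minimal standalone sketch of calling it directly, with illustrative values:

# Standalone use of the callback: creates the SummaryWriter on first call and
# logs one scalar under the tag "aggregator/train/loss" for round 0.
write_metric(node_name="aggregator", task_name="train",
             metric_name="loss", metric=0.42, round_number=0)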