
Commit a281b03

some experimental code
1 parent 7806166 commit a281b03

File tree

3 files changed: +569, -0 lines changed


.experimental/local-sgd.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
import os
import os.path as op
import time

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark
from accelerate import Accelerator

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator.
    It simply runs a given number of updates independently on each device
    and averages model weights every `local_sgd_steps` steps.
    Contributed by Leo (Leonid) Boytsov.
    Although we are not aware of the true origins of this simple approach,
    the idea of local SGD is quite old and goes back to at least:
    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016).
    Parallel SGD: When does averaging help? arXiv preprint arXiv:1606.07365.
    We credit the term Local SGD to the following paper (but there might be
    earlier references we are not aware of):
    Stich, Sebastian Urban. "Local SGD Converges Fast and Communicates Little."
    ICLR 2019 - International Conference on Learning Representations. No. CONF. 2019.
    """

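    # In short (this only summarizes _sync_and_avg_model_params below): each of the W
    # processes takes `local_sgd_steps` optimizer steps on its own copy of the
    # parameters p_i, and the synchronization step then replaces every copy with the mean
    #     p <- (1 / W) * sum_i p_i,
    # computed with an all_reduce(SUM) followed by division by the world size W.
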
    def __enter__(self):
        if self.enabled:
            # Suppress DDP gradient synchronization while local steps are taken
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        if self.enabled:
            # Average all models on exit
            self._sync_and_avg_model_params()
            self.model_sync_obj.__exit__(type, value, tb)

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module,
                 local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.
        Args:
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            accelerator (`Accelerator`):
                Accelerator object.
            local_sgd_steps (`int`):
                The number of local SGD steps before model parameters are synchronized.
            enabled (`bool`):
                Local SGD is disabled if this parameter is set to `False`.
        """
        self.enabled = enabled
        self.step_qty = 0
        if self.enabled:
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def step(self):
        """
        Records a "step" and synchronizes model parameters if necessary.
        """
        self.step_qty += 1
        if not self.enabled:
            return

        if self.step_qty % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize and average model parameters across all GPUs.
        """
        import torch.distributed as dist
        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            qty = float(dist.get_world_size())
            for prm in self.model.parameters():
                # Sum each parameter tensor over all processes, then divide by
                # the world size to obtain the average
                dist.all_reduce(prm.data, op=torch.distributed.ReduceOp.SUM)
                prm.data /= qty

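# Minimal usage sketch (illustration only; it mirrors the train() loop further below):
#
#     with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8, enabled=True) as local_sgd:
#         for batch in train_loader:
#             outputs = model(batch["input_ids"],
#                             attention_mask=batch["attention_mask"],
#                             labels=batch["label"])
#             accelerator.backward(outputs["loss"])
#             optimizer.step()
#             optimizer.zero_grad()
#             local_sgd.step()  # parameters are averaged every `local_sgd_steps` steps
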
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)

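# Note: tokenize_text adds "input_ids" and "attention_mask" columns; because
# imdb_dataset.map() below is called with batched=True and batch_size=None, each split
# is tokenized as a single batch, so padding=True pads every review to the longest
# sequence in that split.
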

def train(num_epochs, model, optimizer, train_loader, val_loader, device):

    local_sgd_steps = 8

    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

        with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=local_sgd_steps, enabled=True) as local_sgd:
            for batch_idx, batch in enumerate(train_loader):
                model.train()
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)

                ### FORWARD AND BACK PROP
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                optimizer.zero_grad()
                accelerator.backward(outputs["loss"])

                ### UPDATE MODEL PARAMETERS
                optimizer.step()
                # Count the completed optimizer step; every `local_sgd_steps` steps
                # the model parameters are averaged across processes
                local_sgd.step()

                ### LOGGING
                if not batch_idx % 300:
                    print(
                        f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                    )

                model.eval()
                with torch.no_grad():
                    predicted_labels = torch.argmax(outputs["logits"], 1)
                    train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

        print(
            f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
        )


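# Assumption: this script is intended to be started with `accelerate launch
# .experimental/local-sgd.py` so that Accelerator() joins the distributed process
# group that the all_reduce in LocalSGD relies on.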
if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    accelerator = Accelerator()
    device = accelerator.device

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )
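    # Note: batch_size=12 is the per-process batch size; once accelerator.prepare() has
    # sharded the loaders below, the effective global batch size is 12 * num_processes.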

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    #########################################
    ### 5 Finetuning
    #########################################

    model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader
    )

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")
.experimental/local_dataset_utilities.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import os
import sys
import tarfile
import time

import numpy as np
import pandas as pd
from packaging import version
from torch.utils.data import Dataset
from tqdm import tqdm
import urllib.request


def reporthook(count, block_size, total_size):
    # Progress callback for urllib.request.urlretrieve: prints percent done,
    # downloaded size, and download speed.
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    speed = progress_size / (1024.0**2 * duration)
    percent = count * block_size * 100.0 / total_size

    sys.stdout.write(
        f"\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB "
        f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed"
    )
    sys.stdout.flush()


def download_dataset():
    # Fetch and unpack the IMDB movie review dataset (aclImdb) if it is not already present.
    source = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    target = "aclImdb_v1.tar.gz"

    if os.path.exists(target):
        os.remove(target)

    if not os.path.isdir("aclImdb") and not os.path.isfile("aclImdb_v1.tar.gz"):
        urllib.request.urlretrieve(source, target, reporthook)

    if not os.path.isdir("aclImdb"):

        with tarfile.open(target, "r:gz") as tar:
            tar.extractall()


def load_dataset_into_to_dataframe():
    basepath = "aclImdb"

    labels = {"pos": 1, "neg": 0}

    df = pd.DataFrame()

    with tqdm(total=50000) as pbar:
        for s in ("test", "train"):
            for l in ("pos", "neg"):
                path = os.path.join(basepath, s, l)
                for file in sorted(os.listdir(path)):
                    with open(os.path.join(path, file), "r", encoding="utf-8") as infile:
                        txt = infile.read()

                    if version.parse(pd.__version__) >= version.parse("1.3.2"):
                        x = pd.DataFrame(
                            [[txt, labels[l]]], columns=["review", "sentiment"]
                        )
                        df = pd.concat([df, x], ignore_index=False)

                    else:
                        df = df.append([[txt, labels[l]]], ignore_index=True)
                    pbar.update()
    df.columns = ["text", "label"]

    np.random.seed(0)
    df = df.reindex(np.random.permutation(df.index))

    print("Class distribution:")
    print(np.bincount(df["label"].values))

    return df


def partition_dataset(df):
    df_shuffled = df.sample(frac=1, random_state=1).reset_index()

    # 50,000 reviews in total: 35,000 for training, 5,000 for validation, 10,000 for testing
    df_train = df_shuffled.iloc[:35_000]
    df_val = df_shuffled.iloc[35_000:40_000]
    df_test = df_shuffled.iloc[40_000:]

    df_train.to_csv("train.csv", index=False, encoding="utf-8")
    df_val.to_csv("val.csv", index=False, encoding="utf-8")
    df_test.to_csv("test.csv", index=False, encoding="utf-8")


class IMDBDataset(Dataset):
    def __init__(self, dataset_dict, partition_key="train"):
        self.partition = dataset_dict[partition_key]

    def __getitem__(self, index):
        return self.partition[index]

    def __len__(self):
        return self.partition.num_rows
