import os
import os.path as op
import time


from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark
from accelerate import Accelerator

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator.
    It simply runs a given number of updates independently on each device
    and averages model weights every `local_sgd_steps` synchronization steps.

    Contributed by Leo (Leonid) Boytsov.

    Although we are not aware of the true origins of this simple approach,
    the idea of local SGD is quite old and goes back to at least:
    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016).
    Parallel SGD: When does averaging help?. arXiv preprint arXiv:1606.07365.

    We credit the term Local SGD to the following paper (but there might be
    earlier references we are not aware of):
    Stich, Sebastian Urban. "Local SGD Converges Fast and Communicates Little."
    ICLR 2019 - International Conference on Learning Representations.
    """

    def __enter__(self):
        if self.enabled:
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        if self.enabled:
            # Average all models on exit
            self._sync_and_avg_model_params()
            self.model_sync_obj.__exit__(type, value, tb)

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module,
                 local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            accelerator (`Accelerator`):
                Accelerator object.
            local_sgd_steps (`int`):
                The number of local SGD steps (before model parameters are synchronized).
            enabled (`bool`):
                Local SGD is disabled if this parameter is set to `False`.
        """
        self.enabled = enabled
        self.step_qty = 0
        if self.enabled:
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.
        """
        self.step_qty += 1
        if not self.enabled:
            return

        if self.step_qty % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize + average model parameters across all GPUs.
        """
        import torch.distributed as dist

        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            qty = float(dist.get_world_size())
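            # Summing each parameter over all ranks and dividing by the world
            # size replaces every local copy with the element-wise average.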
            for prm in self.model.parameters():
                dist.all_reduce(prm.data, op=torch.distributed.ReduceOp.SUM)
                prm.data /= qty


def tokenize_text(batch):
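    # Uses the module-level `tokenizer` created in the __main__ block below.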
    return tokenizer(batch["text"], truncation=True, padding=True)


def train(num_epochs, model, optimizer, train_loader, val_loader, device):

    local_sgd_steps = 8

    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

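        # The LocalSGD context keeps DDP gradient synchronization disabled and
        # instead averages the model weights every `local_sgd_steps` optimizer steps.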
        with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=local_sgd_steps, enabled=True) as local_sgd:
            for batch_idx, batch in enumerate(train_loader):
                model.train()
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)

                ### FORWARD AND BACK PROP
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                optimizer.zero_grad()
                accelerator.backward(outputs["loss"])

                ### UPDATE MODEL PARAMETERS
                optimizer.step()
                # Count this local update; every `local_sgd_steps` updates the
                # parameters are averaged across devices.
                local_sgd.step()

                ### LOGGING
                if not batch_idx % 300:
                    print(
                        f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                    )

                model.eval()
                with torch.no_grad():
                    predicted_labels = torch.argmax(outputs["logits"], 1)
                    train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

            print(
                f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
            )


if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    accelerator = Accelerator()
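    # Under `accelerate launch`, one process runs per device and
    # `accelerator.device` is that process's GPU, superseding the string above;
    # a plain `python` run falls back to a single device.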
    device = accelerator.device

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    #########################################
    ### 4 Initializing the Model
    #########################################

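    # Loads the pretrained DistilBERT encoder and adds a randomly initialized
    # two-class classification head (num_labels=2) for the binary sentiment labels.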
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    #########################################
    ### 5 Finetuning
    #########################################

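    # accelerator.prepare() adapts everything to the current setup: the model is
    # wrapped (e.g. in DDP on a multi-GPU launch), the dataloaders are sharded
    # across processes, and the optimizer is tied into Accelerate's device and
    # precision handling.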
    model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader
    )

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")