import os
import os.path as op
import time

from datasets import load_dataset
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


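# Tokenize a batch of raw text; relies on the module-level `tokenizer`
# instantiated in the __main__ block below before `Dataset.map` calls this helper.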
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)


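# LightningModule wrapper around a Hugging Face sequence-classification model;
# it logs the loss and tracks train/val/test accuracy via torchmetrics.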
class LightningModel(L.LightningModule):
    def __init__(self, model, learning_rate=5e-5):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("train_loss", outputs["loss"])
        with torch.no_grad():
            logits = outputs["logits"]
            predicted_labels = torch.argmax(logits, 1)
            self.train_acc(predicted_labels, batch["label"])
            self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return outputs["loss"]  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("val_loss", outputs["loss"], prog_bar=True)

        logits = outputs["logits"]
        predicted_labels = torch.argmax(logits, 1)
        self.val_acc(predicted_labels, batch["label"])
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )

        logits = outputs["logits"]
        predicted_labels = torch.argmax(logits, 1)
        self.test_acc(predicted_labels, batch["label"])
        self.log("accuracy", self.test_acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=self.learning_rate
        )
        return optimizer


if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True), flush=True)
    print("Torch CUDA available?", torch.cuda.is_available(), flush=True)

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

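    # Reload the partitioned CSV files as a Hugging Face DatasetDict with
    # train/validation/test splits.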
    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoid tokenizers fork warnings from DataLoader workers

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=1,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=1,
        drop_last=True,
    )

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    #########################################
    ### 5 Finetuning
    #########################################

    lightning_model = LightningModel(model)

    callbacks = [
        ModelCheckpoint(save_top_k=1, mode="max", monitor="val_acc")  # save top 1 model
    ]
    logger = CSVLogger(save_dir="logs/", name="my-model")

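    # The ModelCheckpoint callback above keeps only the checkpoint with the highest
    # validation accuracy; `ckpt_path="best"` below reloads it for the test run.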
    trainer = L.Trainer(
        max_epochs=3,
        callbacks=callbacks,
        accelerator="gpu",
        devices=1,
        precision="16",  # <-- NEW: 16-bit precision instead of the float32 default
        logger=logger,
        log_every_n_steps=10,
        deterministic=True,
    )

    start = time.time()
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    test_acc = trainer.test(lightning_model, dataloaders=test_loader, ckpt_path="best")
    print(test_acc)

    with open(op.join(trainer.logger.log_dir, "outputs.txt"), "w") as f:
        f.write(f"Time elapsed {elapsed/60:.2f} min\n")
        f.write(f"Test acc: {test_acc}")