import os
import os.path as op
import time

from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


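# Tokenize a batch of examples. With padding=True, the tokenizer pads to the
# longest sequence in the batch; truncation=True cuts inputs at the model's
# maximum length (512 tokens for DistilBERT). Relies on the module-level
# `tokenizer` created in the __main__ block below.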
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)


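# Aggregate the per-step rows of a metrics.csv log by epoch and save loss and
# accuracy curves as PDFs into the same log directory. (Unused in this
# script; kept as a helper for runs that produce a metrics.csv.)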
def plot_logs(log_dir):
    metrics = pd.read_csv(op.join(log_dir, "metrics.csv"))

    aggreg_metrics = []
    agg_col = "epoch"
    for i, dfg in metrics.groupby(agg_col):
        agg = dict(dfg.mean())
        agg[agg_col] = i
        aggreg_metrics.append(agg)

    df_metrics = pd.DataFrame(aggreg_metrics)
    df_metrics[["train_loss", "val_loss"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
    )
    plt.savefig(op.join(log_dir, "loss.pdf"))

    df_metrics[["train_acc", "val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Accuracy"
    )
    plt.savefig(op.join(log_dir, "acc.pdf"))


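# Plain PyTorch training loop: for each epoch, run forward/backward/step over
# the training set while accumulating a torchmetrics Accuracy for logging,
# then compute validation accuracy under torch.no_grad().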
def train(num_epochs, model, optimizer, train_loader, val_loader, device):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            optimizer.zero_grad()
            outputs["loss"].backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(
                    f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                )

            model.eval()
            with torch.no_grad():
                predicted_labels = torch.argmax(outputs["logits"], 1)
                train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

            print(
                f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
            )


if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

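    # load_dataset("csv", ...) returns a DatasetDict with one split per key in
    # data_files, so the tokenized splits can be addressed by name below.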
    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    # batch_size=None tokenizes each split as a single batch, so all examples
    # in a split are padded to one common length.
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    # Silence the fork-related parallelism warning from the fast tokenizers
    # when DataLoader workers are used below.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )
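    # Note: drop_last=True discards each loader's final incomplete batch; with
    # batch_size=12 this leaves out at most 11 examples per split, which has a
    # negligible effect on the reported accuracies.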

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
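    # num_labels=2 puts a fresh, randomly initialized binary classification
    # head on top of the pretrained DistilBERT encoder; 5e-5 is a common
    # learning rate for finetuning transformer encoders.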

    #########################################
    ### 5 Finetuning
    #########################################

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

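    # Final evaluation on the held-out test set, mirroring the per-epoch
    # validation loop inside train().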
    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")