
Commit 100bc1d: reorganizing and cleaning

1 parent 00629a5 · commit 100bc1d

20 files changed: +96554 -0 lines changed

BERT/BERTrain-Copy2.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
import os
import time
import datetime
import pickle
import random

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

import nltk
from nltk.tokenize import sent_tokenize
from nltk.translate.bleu_score import SmoothingFunction

from transformers import BertTokenizer, BertForMaskedLM, get_linear_schedule_with_warmup

print('here')
if torch.cuda.is_available():
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    device = torch.device('cpu')
    print("cpu")


path = "/global/cscratch1/sd/ajaybati/pickles/DSdata128/"

print("passed")
train_dataloader = torch.load(path + "train_dataloaderDS128.pickle")
print("pass 2")
validation_dataloader = torch.load(path + "validation_dataloaderDS128.pickle")  # load when validating
print("done")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model = model.to(device)
print("done")

optimizer = AdamW(model.parameters(),
                  lr=2e-5,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps=1e-8  # args.adam_epsilon - default is 1e-8.
                  )

training_stats = []

# Number of training epochs. The BERT authors recommend between 2 and 4.
# We chose to run for 4, but we'll see later that this may be over-fitting the
# training data.
EPOCHS = 4

# Total number of training steps is [number of batches] x [number of epochs].
# (Note that this is not the same as the number of training samples.)
total_steps = len(train_dataloader) * EPOCHS

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,  # Default value in run_glue.py
                                            num_training_steps=total_steps)
print("done")


def getSent_pred(prediction, real_labels):
    # Convert each reference sequence of ids back to tokens and keep everything
    # between [CLS] and the first [SEP].
    sentlist_real = []
    sep_list = []
    for sent2 in real_labels:
        tokenized = tokenizer.convert_ids_to_tokens(sent2)
        sep = tokenized.index('[SEP]')
        sep_list.append(sep)
        sentlist_real.append(tokenized[1:sep])

    # For the predictions, take the argmax over the vocabulary at every position,
    # then truncate at the same [SEP] index as the corresponding reference.
    sentlist_ids = []
    sentlist = []
    for sent in prediction:
        word_list = []
        for word in sent:
            word_list.append(torch.argmax(word))
        sentlist_ids.append(word_list)

    for index, sent in enumerate(sentlist_ids):
        sentlist.append(tokenizer.convert_ids_to_tokens(sent)[1:sep_list[index]])
    return sentlist, sentlist_real


def bleu(p, r):
    # Average smoothed sentence-level BLEU over the batch.
    smoothie = SmoothingFunction().method2
    bleu_list = []
    for index in range(len(p)):
        BLEUscore = nltk.translate.bleu_score.sentence_bleu(p[index], r[index], smoothing_function=smoothie)
        bleu_list.append(BLEUscore)
    return sum(bleu_list) / len(bleu_list)


def format_time(elapsed):
    elapsed_rounded = int(round(elapsed))

    return str(datetime.timedelta(seconds=elapsed_rounded))


def calc_accuracy(prediction, real_labels, mask_indices):
    # Accuracy over masked positions only: a masked position counts as correct
    # when the argmax prediction equals the original token id. Each row of
    # mask_indices is assumed to be zero-padded to length 40; a row of all
    # zeros means the sentence contains no masked positions and is skipped.
    score = 0
    total = 0
    for step, sent in enumerate(mask_indices):
        if list(sent).count(0) != 40:
            for mask in sent:
                if int(mask) != 0:
                    predicted_index = int(torch.argmax(prediction[step, int(mask)]))
                    actual = int(real_labels[step][int(mask)])
                    if predicted_index == actual:
                        score += 1
                    total += 1

    p, r = getSent_pred(prediction, real_labels)

    accuracy = score / total
    try:
        bscore = bleu(p, r)
    except Exception:
        bscore = "Unfortunately, not possible"
    return accuracy, bscore


print("done")



# ==========================================================================================

# In general, remember that there are some sentences where no masks exist.
seed_val = 42

random.seed(seed_val)
torch.manual_seed(seed_val)

break_factor = False

# Measure the total training time for the whole run.
total_t0 = time.time()

print("starting...")
# For each epoch...
for epoch_i in range(0, EPOCHS):
    print("")
    # The +5 / +4 offsets in the banner suggest this run resumes after four
    # earlier training epochs.
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 5, EPOCHS + 4))
    print('Training...')

    # Resume from the checkpoint written at the end of the previous epoch.
    checkpoint = torch.load("/global/cscratch1/sd/ajaybati/model_ckptDS" + str(epoch_i) + ".pickle")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1
    total_train_loss = 0
    step_resume = 0
    training_stats = checkpoint['training_stats']
    print('step: ', step_resume, 'total loss: ', total_train_loss, 'epoch: ', epoch)

    # Measure how long the training epoch takes.
    t0 = time.time()

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_input_ids_real = batch[2].to(device)
        b_input_mask_ids = batch[3]

        model.zero_grad()

        loss, predictions = model(b_input_ids,
                                  attention_mask=b_input_mask,
                                  masked_lm_labels=b_input_ids_real)

        total_train_loss += float(loss)

        if step % 40 == 0 and step != 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.  Percent done: {:}%  Elapsed: {:}.'.format(
                step, len(train_dataloader), step / len(train_dataloader) * 100, elapsed))
            print("*" * 50)
            print(loss)
            print("*" * 50)
            acc, bscore = calc_accuracy(predictions, b_input_ids_real, b_input_mask_ids)
            print("accuracy: ", acc, "bleu: ", bscore)
            print("=" * 100)

        loss.backward()

        # Clip the gradient norm to 1.0 to prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        scheduler.step()

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)

    # Measure how long this epoch took.
    # training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    # print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    t0 = time.time()

    # Put the model in evaluation mode -- the dropout layers behave differently
    # during evaluation.
    model.eval()

    # Tracking variables
    total_eval_loss = 0
    nb_eval_steps = 0
    total_eval_accuracy = 0
    total_bleuscore = 0

    # Evaluate data for one epoch
    for step, batch in enumerate(validation_dataloader):

        # Unpack this validation batch from our dataloader.
        #
        # As we unpack the batch, we'll also copy each tensor to the GPU using
        # the `to` method.
        #
        # `batch` contains four pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: real ids
        #   [3]: mask ids for comparison
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_input_ids_real = batch[2].to(device)
        b_input_mask_ids = batch[3]

        with torch.no_grad():

            # Use the original (unmasked) ids as labels, matching the training loop above.
            (loss, logits) = model(b_input_ids,
                                   attention_mask=b_input_mask,
                                   masked_lm_labels=b_input_ids_real)
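
        # `logits` are the raw prediction scores with shape
        # (batch_size, sequence_length, vocab_size); calc_accuracy takes the
        # argmax over the vocabulary dimension at each masked position.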

        if step % 40 == 0 and step != 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.  Elapsed: {:}.'.format(step, len(validation_dataloader), elapsed))
            print("*" * 50)
            print(loss)
            print("*" * 50)
            acc, bscore = calc_accuracy(logits, b_input_ids_real, b_input_mask_ids)
            print("accuracy: ", acc, "bleu: ", bscore)
            print("=" * 100)

        # Accumulate the validation loss, accuracy, and BLEU score.
        total_eval_loss += loss.item()
        accuracy, bleuscore = calc_accuracy(logits, b_input_ids_real, b_input_mask_ids)
        total_eval_accuracy += accuracy
        total_bleuscore += bleuscore

    # Report the final accuracy for this validation run.
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    avg_bleuscore = total_bleuscore / len(validation_dataloader)
    avg_val_loss = total_eval_loss / len(validation_dataloader)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    training_stats.append(
        {
            'Avg Accuracy': avg_val_accuracy,
            'Bleu Score': avg_bleuscore,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Validation Time': validation_time
        }
    )

    torch.save({
        'epoch': epoch_i + 4,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'total_train_loss': total_train_loss,
        'step': len(train_dataloader),
        'training_stats': training_stats},
        "/global/cscratch1/sd/ajaybati/model_ckptDS" + str(epoch_i + 1) + ".pickle")
    print(training_stats)
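
    # The checkpoint written above is what the resume block at the top of this
    # loop expects to find on the next run, e.g. (illustrative):
    #
    #     ckpt = torch.load("/global/cscratch1/sd/ajaybati/model_ckptDS1.pickle")
    #     model.load_state_dict(ckpt['model_state_dict'])
    #     optimizer.load_state_dict(ckpt['optimizer_state_dict'])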

print("")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

print("done completely")
