import os
import time
import datetime
import pickle
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import pandas as pd
from transformers import BertTokenizer, BertForMaskedLM, get_linear_schedule_with_warmup
import nltk
from nltk.tokenize import sent_tokenize
from nltk.translate.bleu_score import SmoothingFunction

print('here')
# Use the GPU if one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    device = torch.device('cpu')
    print("No GPU available, using the CPU instead.")

path = "/global/cscratch1/sd/ajaybati/pickles/DSdata128/"

print("passed")
train_dataloader = torch.load(path + "train_dataloaderDS128.pickle")
print("pass 2")
validation_dataloader = torch.load(path + "validation_dataloaderDS128.pickle")  # load when validating
print("done")
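# Note: each batch yielded by these pickled DataLoaders holds four tensors --
# masked input ids, attention masks, the real (unmasked) ids, and the indices of
# the masked positions -- as unpacked in the training/validation loops below.
# A minimal sketch of how such a loader could have been built and pickled
# (hypothetical preprocessing, not the code that produced these files):
#
#   from torch.utils.data import TensorDataset, RandomSampler
#   dataset = TensorDataset(input_ids, attention_masks, real_ids, mask_indices)
#   loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=32)
#   torch.save(loader, path + "train_dataloaderDS128.pickle")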
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model = model.to(device)
print("done")
optimizer = AdamW(model.parameters(),
                  lr=2e-5,   # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps=1e-8   # args.adam_epsilon - default is 1e-8.
                  )
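# Note: torch.optim.AdamW applies decoupled weight decay (default weight_decay=0.01)
# to every parameter here, including biases and LayerNorm weights; the reference BERT
# fine-tuning scripts usually exclude those from decay, but this script does not.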
training_stats = []

# Number of training epochs. The BERT authors recommend between 2 and 4.
# We chose to run for 4, but we'll see later that this may be over-fitting the
# training data.
EPOCHS = 4

# Total number of training steps is [number of batches] x [number of epochs].
# (Note that this is not the same as the number of training samples.)
total_steps = len(train_dataloader) * EPOCHS

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,  # Default value in run_glue.py
                                            num_training_steps=total_steps)
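# With num_warmup_steps=0, this schedule simply decays the learning rate linearly
# from 2e-5 down to 0 over the total_steps optimizer updates of the run.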
print("done")

def getSent_pred(prediction, real_labels):
    """Turn MLM logits and reference ids back into token lists for BLEU scoring."""
    # Recover the reference tokens: everything between [CLS] and the first [SEP].
    sentlist_real = []
    sep_list = []
    for sent2 in real_labels:
        tokenized = tokenizer.convert_ids_to_tokens(sent2.tolist())
        sep = tokenized.index('[SEP]')
        sep_list.append(sep)
        sentlist_real.append(tokenized[1:sep])

    # Take the argmax over the vocabulary at every position to get the predicted ids.
    sentlist_ids = []
    sentlist = []
    for sent in prediction:
        word_list = [int(torch.argmax(word)) for word in sent]
        sentlist_ids.append(word_list)

    # Convert the predicted ids to tokens, truncated at the reference [SEP] position.
    for index, sent in enumerate(sentlist_ids):
        sentlist.append(tokenizer.convert_ids_to_tokens(sent)[1:sep_list[index]])
    return sentlist, sentlist_real

def bleu(p, r):
    """Average sentence-level BLEU of the predicted sentences against the references."""
    smoothie = SmoothingFunction().method2
    bleu_list = []
    for index in range(len(p)):
        # sentence_bleu expects (list_of_references, hypothesis), so wrap the single reference.
        BLEUscore = nltk.translate.bleu_score.sentence_bleu(
            [r[index]], p[index], smoothing_function=smoothie)
        bleu_list.append(BLEUscore)
    return sum(bleu_list) / len(bleu_list)

def format_time(elapsed):
    """Format a duration in seconds as an h:mm:ss string."""
    elapsed_rounded = int(round(elapsed))
    return str(datetime.timedelta(seconds=elapsed_rounded))

def calc_accuracy(prediction, real_labels, mask_indices):
    """Accuracy over the masked positions of a batch, plus the batch-average BLEU score."""
    score = 0
    total = 0
    for step, sent in enumerate(mask_indices):
        # A row of 40 zeros means this sentence contains no masked tokens, so skip it.
        if list(sent).count(0) != 40:
            for mask in sent:
                if int(mask) != 0:
                    predicted_index = int(torch.argmax(prediction[step, int(mask)]))
                    actual = int(real_labels[step][int(mask)])
                    if predicted_index == actual:
                        score += 1
                    total += 1

    p, r = getSent_pred(prediction, real_labels)

    accuracy = score / total
    try:
        bscore = bleu(p, r)
    except Exception:
        bscore = "Unfortunately, not possible"
    return accuracy, bscore
print("done")
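# Note: the two metrics above measure different things. `accuracy` scores only the
# masked positions, while the BLEU score compares the full argmax-decoded sentences
# against the references, including positions that were never masked.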


# ==========================================================================================

# In general, remember that there are some sentences where no masks exist.
seed_val = 42

random.seed(seed_val)
torch.manual_seed(seed_val)

# Measure the total training time for the whole run.
total_t0 = time.time()

print("starting...")
# For each epoch...
for epoch_i in range(0, EPOCHS):
    print("")
    # This run resumes a previous 4-epoch run, so epochs are reported as 5 through 8.
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 5, EPOCHS + 4))
    print('Training...')

    # Reload the checkpoint written at the end of the previous epoch before training.
    checkpoint = torch.load("/global/cscratch1/sd/ajaybati/model_ckptDS" + str(epoch_i) + ".pickle")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1
    total_train_loss = 0
    step_resume = 0
    training_stats = checkpoint['training_stats']
    print('step: ', step_resume, 'total loss: ', total_train_loss, 'epoch: ', epoch)


    # Measure how long the training epoch takes.
    t0 = time.time()

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):
        # Unpack the batch and move the model inputs to the GPU.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_input_ids_real = batch[2].to(device)
        b_input_mask_ids = batch[3]

        model.zero_grad()

        # Forward pass. The `masked_lm_labels` keyword and the (loss, logits) tuple
        # return value rely on the older transformers API; newer releases expect
        # `labels=` and return a ModelOutput object instead.
        loss, predictions = model(b_input_ids,
                                  attention_mask=b_input_mask,
                                  masked_lm_labels=b_input_ids_real)

        total_train_loss += float(loss)

        # Report progress every 40 batches.
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}. Percent done: {:}%  Elapsed: {:}.'.format(
                step, len(train_dataloader), step / len(train_dataloader) * 100, elapsed))
            print("*" * 50)
            print(loss)
            print("*" * 50)
            acc, bscore = calc_accuracy(predictions, b_input_ids_real, b_input_mask_ids)
            print("accuracy: ", acc, "bleu: ", bscore)
            print("=" * 100)

        loss.backward()

        # Clip the gradient norm to 1.0 to prevent the exploding-gradients problem.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # Advance the learning rate schedule.
        scheduler.step()

    # Calculate the average training loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set.
    print("")
    print("Running Validation...")

    t0 = time.time()

    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    model.eval()

    # Tracking variables
    total_eval_loss = 0
    nb_eval_steps = 0
    total_eval_accuracy = 0
    total_bleuscore = 0

    # Evaluate data for one epoch
    for step, batch in enumerate(validation_dataloader):

        # Unpack this validation batch from our dataloader and copy each tensor
        # to the GPU using the `to` method.
        #
        # `batch` contains four pytorch tensors:
        #   [0]: input ids (with masks)
        #   [1]: attention masks
        #   [2]: real ids
        #   [3]: mask indices for comparison
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_input_ids_real = batch[2].to(device)
        b_input_mask_ids = batch[3]

        # No gradient bookkeeping is needed during evaluation.
        with torch.no_grad():
            # Score the masked inputs against the real (unmasked) ids, mirroring the
            # training forward pass.
            (loss, logits) = model(b_input_ids,
                                   attention_mask=b_input_mask,
                                   masked_lm_labels=b_input_ids_real)

        # Report progress every 40 batches.
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(validation_dataloader), elapsed))
            print("*" * 50)
            print(loss)
            print("*" * 50)
            acc, bscore = calc_accuracy(logits, b_input_ids_real, b_input_mask_ids)
            print("accuracy: ", acc, "bleu: ", bscore)
            print("=" * 100)

        # Accumulate the validation loss, accuracy, and BLEU score.
        total_eval_loss += loss.item()
        accuracy, bleuscore = calc_accuracy(logits, b_input_ids_real, b_input_mask_ids)
        total_eval_accuracy += accuracy
        total_bleuscore += bleuscore

    # Report the final metrics for this validation run.
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    avg_bleuscore = total_bleuscore / len(validation_dataloader)
    avg_val_loss = total_eval_loss / len(validation_dataloader)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'Avg Accuracy': avg_val_accuracy,
            'Bleu Score': avg_bleuscore,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Validation Time': validation_time
        }
    )

    # Save a checkpoint so the next epoch (or a restarted job) can resume from here.
    torch.save({
        'epoch': epoch_i + 4,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'total_train_loss': total_train_loss,
        'step': len(train_dataloader),
        'training_stats': training_stats}, "/global/cscratch1/sd/ajaybati/model_ckptDS" + str(epoch_i + 1) + ".pickle")
    print(training_stats)

print("")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

print("done completely")
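# To reload one of the saved checkpoints later (a minimal sketch reusing the
# dictionary keys written by torch.save above; the path/epoch number is illustrative):
#
#   checkpoint = torch.load("/global/cscratch1/sd/ajaybati/model_ckptDS4.pickle")
#   model.load_state_dict(checkpoint['model_state_dict'])
#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#   training_stats = checkpoint['training_stats']
#   model.eval()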