 from ..static_funcs_training import make_iterable_verbose
 from ..models.ensemble import EnsembleKGE
 from typing import Tuple
+import time
 
 def extract_input_outputs(z: list, device=None):
     # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
@@ -23,67 +24,99 @@ def extract_input_outputs(z: list, device=None):
     else:
         raise ValueError('Unexpected batch shape..')
 
-def find_good_batch_size(train_loader, ensemble_model, max_available_gpu_memory: float = 0.1):
+
+def find_good_batch_size(train_loader, tp_ensemble_model):
     # () Initial batch size
-    batch_size = train_loader.batch_size
-    if batch_size >= len(train_loader.dataset):
-        return batch_size
-    first_batch_size = train_loader.batch_size
-    num_datapoints = len(train_loader.dataset)
-    print(f"Increment the batch size by {first_batch_size} until the Free/Total GPU memory is reached to {1 - max_available_gpu_memory} or batch_size={num_datapoints} is achieved.")
-    while True:
-        # () Initialize a dataloader with a current batch_size
-        train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
-                                                        batch_size=batch_size,
-                                                        shuffle=True,
-                                                        sampler=None,
-                                                        batch_sampler=None,
-                                                        num_workers=0,
-                                                        collate_fn=train_loader.dataset.collate_fn,
-                                                        pin_memory=False, drop_last=False,
-                                                        timeout=0,
-                                                        worker_init_fn=None,
-                                                        persistent_workers=False)
-        loss = None
-        avg_global_free_memory = []
-        for i, z in enumerate(train_dataloaders):
-            loss = forward_backward_update_loss(z, ensemble_model)
-            global_free_memory, total_memory = torch.cuda.mem_get_info()
-            avg_global_free_memory.append(global_free_memory / total_memory)
-            if i == 3:
-                break
-        avg_global_free_memory = sum(avg_global_free_memory) / len(avg_global_free_memory)
-        print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
-        if avg_global_free_memory > max_available_gpu_memory and batch_size < num_datapoints:
-            if batch_size + first_batch_size <= num_datapoints:
-                batch_size += first_batch_size
-            else:
-                batch_size = num_datapoints
+    initial_batch_size = train_loader.batch_size
+    training_dataset_size = len(train_loader.dataset)
+    if initial_batch_size >= training_dataset_size:
+        return training_dataset_size, None
+    print("Number of training data points:", training_dataset_size)
+
+    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None):
+        assert delta is not None, "delta must be positive integer"
+        batch_sizes_and_mem_usages = []
+        num_datapoints = len(train_loader.dataset)
+        try:
+            while True:
+                start_time = time.time()
+                # () Initialize a dataloader with a current batch_size
+                train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
+                                                                batch_size=batch_size,
+                                                                shuffle=True,
+                                                                sampler=None,
+                                                                batch_sampler=None,
+                                                                num_workers=train_loader.num_workers,
+                                                                collate_fn=train_loader.dataset.collate_fn,
+                                                                pin_memory=False,
+                                                                drop_last=False,
+                                                                timeout=0,
+                                                                worker_init_fn=None,
+                                                                persistent_workers=False)
+
+                batch_loss = None
+                for i, batch_of_training_data in enumerate(train_dataloaders):
+                    batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
+                    break
+
+                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
+                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
+                rt = time.time() - start_time
+                print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")
+
+                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
+                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
+
+                # Store the batch size and the runtime
+                batch_sizes_and_mem_usages.append((batch_size, rt))
+
+                if batch_size < num_datapoints:
+                    # Increase the batch size.
+                    batch_size += int(batch_size / delta)
+                else:
+                    return batch_sizes_and_mem_usages, True
+
+        except torch.OutOfMemoryError:
+            print("torch.OutOfMemoryError caught!")
+            return batch_sizes_and_mem_usages, False
+
+    history_batch_sizes_and_mem_usages = []
+    batch_size = initial_batch_size
+
+    for delta in range(1, 5, 1):
+        result, flag = increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size, delta=delta)
+
+        history_batch_sizes_and_mem_usages.extend(result)
+
+        if flag:
+            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
+        else:
+            # CUDA error observed.
+            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-2]
+
+        if batch_size >= training_dataset_size:
+            batch_size = training_dataset_size
+            break
         else:
-            assert batch_size <= num_datapoints
-            if batch_size == num_datapoints:
-                print("Batch size equals to the training dataset size")
-            else:
-                print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")
-            return batch_size
-
-def forward_backward_update_loss(z: Tuple, ensemble_model):
-    # () Get the i-th batch of data points.
+            continue
+
+    return batch_size, batch_rt
+
+
+def forward_backward_update_loss(z: Tuple, ensemble_model) -> float:
+    # () Get a random batch of data points (z).
     x_batch, y_batch = extract_input_outputs(z)
-    # () Move the batch of labels into the master GPU : GPU-0
+    # () Move the batch of labels into the master GPU: GPU-0.
     y_batch = y_batch.to("cuda:0")
-    # () Forward Pass on the batch. Yhat located on the master GPU.
+    # () Forward pass on the batch of input data points (yhat on the master GPU).
     yhat = ensemble_model(x_batch)
-    # () Compute the loss
+    # () Compute the loss.
     loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
     # () Compute the gradient of the loss w.r.t. parameters.
     loss.backward()
     # () Parameter update.
     ensemble_model.step()
-    # () Report the batch and epoch losses.
-    batch_loss = loss.item()
-    # () Accumulate batch loss
-    return batch_loss
+    return loss.item()
 
 class TensorParallel(AbstractTrainer):
     def __init__(self, args, callbacks):
@@ -103,12 +136,14 @@ def fit(self, *args, **kwargs):
         self.on_fit_start(self, ensemble_model)
         # () Sanity checking
         assert torch.cuda.device_count() == len(ensemble_model)
-        # ()
+        # () Get DataLoader
         train_dataloader = kwargs['train_dataloaders']
-        # ()
+        # () Find a batch size so that available GPU memory is *almost* fully used.
         if self.attributes.auto_batch_finding:
+            batch_size, batch_rt = find_good_batch_size(train_dataloader, ensemble_model)
+
             train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
-                                                           batch_size=find_good_batch_size(train_dataloader, ensemble_model),
+                                                           batch_size=batch_size,
                                                            shuffle=True,
                                                            sampler=None,
                                                            batch_sampler=None,
@@ -119,29 +154,38 @@ def fit(self, *args, **kwargs):
                                                            timeout=0,
                                                            worker_init_fn=None,
                                                            persistent_workers=False)
+            if batch_rt is not None:
+                expected_training_time = batch_rt * len(train_dataloader) * self.attributes.num_epochs
+                print(f"Exp.Training Runtime: {expected_training_time / 60:.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t# of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")
 
+        # () Number of batches needed to complete a single epoch.
         num_of_batches = len(train_dataloader)
         # () Start training.
         for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
                                                         verbose=True, position=0, leave=True)):
+            # () Accumulate the batch losses.
             epoch_loss = 0
             # () Iterate over batches.
             for i, z in enumerate(train_dataloader):
+                # () Forward, Loss, Backward, and Update on a given batch of data points.
                 batch_loss = forward_backward_update_loss(z, ensemble_model)
+                # () Accumulate the batch losses to compute the epoch loss.
                 epoch_loss += batch_loss
-
+                # If verbose=True, show progress info.
                 if hasattr(tqdm_bar, 'set_description_str'):
                     tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                     if i > 0:
                         tqdm_bar.set_postfix_str(
                             f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
                     else:
                         tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
+            # Store the epoch loss.
             ensemble_model.loss_history.append(epoch_loss)
-
+        # Run on_fit_end callbacks after the training is done.
         self.on_fit_end(self, ensemble_model)
         # TODO: Later, maybe we should write a callback to save the models in disk
         return ensemble_model
+
     """
 
     def batchwisefit(self, *args, **kwargs):
@@ -242,4 +286,4 @@ def torch_buggy_fit(self, *args, **kwargs):
         torch.distributed.destroy_process_group()
         # () .
         self.on_fit_end(self, model)
-    """
+    """