from ..static_funcs_training import make_iterable_verbose
from ..models.ensemble import EnsembleKGE
from typing import Tuple
+import time

def extract_input_outputs(z: list, device=None):
    # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
@@ -27,59 +28,79 @@ def extract_input_outputs(z: list, device=None):
def find_good_batch_size(train_loader, tp_ensemble_model):
    # () Initial batch size
    initial_batch_size = train_loader.batch_size
-    if initial_batch_size >= len(train_loader.dataset):
-        return initial_batch_size
+    training_dataset_size = len(train_loader.dataset)
+    if initial_batch_size >= training_dataset_size:
+        return training_dataset_size, None
+    print("Number of training data points:", training_dataset_size)

-    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None):
+    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
        assert delta is not None, "delta must be positive integer"
        batch_sizes_and_mem_usages = []
        num_datapoints = len(train_loader.dataset)
        try:
            while True:
+                start_time = time.time()
                # () Initialize a dataloader with a current batch_size
                train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
                                                                batch_size=batch_size,
                                                                shuffle=True,
                                                                sampler=None,
                                                                batch_sampler=None,
-                                                                num_workers=0,
+                                                                num_workers=train_loader.num_workers,
                                                                collate_fn=train_loader.dataset.collate_fn,
                                                                pin_memory=False,
                                                                drop_last=False,
                                                                timeout=0,
                                                                worker_init_fn=None,
                                                                persistent_workers=False)
+
                batch_loss = None
                for i, batch_of_training_data in enumerate(train_dataloaders):
                    batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
                    break
+
                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
-
-                print(
-                    f"Random Batch Loss: {batch_loss}\tGPU Usage: {percentage_used_gpu_memory}\tBatch Size:{batch_size}")
+                rt = time.time() - start_time
+                print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")

                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
-
-                batch_sizes_and_mem_usages.append((batch_size, percentage_used_gpu_memory))
+
+                # Store the batch size and the runtime
+                batch_sizes_and_mem_usages.append((batch_size, rt))
+
                if batch_size < num_datapoints:
+                    # Increase the batch size.
                    batch_size += int(batch_size / delta)
                else:
-                    if batch_size == num_datapoints:
-                        print("Batch size equals to the training dataset size")
-                    break
+                    return batch_sizes_and_mem_usages, True
+
        except torch.OutOfMemoryError:
            print(f"torch.OutOfMemoryError caught!")
-            return batch_sizes_and_mem_usages
+            return batch_sizes_and_mem_usages, False

    history_batch_sizes_and_mem_usages = []
    batch_size = initial_batch_size
-    for delta in range(1, 10, 1):
-        history_batch_sizes_and_mem_usages.extend(increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size, delta=delta))
-        batch_size = history_batch_sizes_and_mem_usages[-2][0]
-        print(f"A best found batch size:{batch_size} in {len(history_batch_sizes_and_mem_usages)} trials. Current GPU memory usage %:{history_batch_sizes_and_mem_usages[-2][1]}")
-    return batch_size
+
+    for delta in range(1, 5, 1):
+        result, flag = increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size, delta=delta)
+
+        history_batch_sizes_and_mem_usages.extend(result)
+
+        if flag:
+            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
+        else:
+            # CUDA ERROR Observed
+            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-2]
+
+        if batch_size >= training_dataset_size:
+            batch_size = training_dataset_size
+            break
+        else:
+            continue
+
+    return batch_size, batch_rt


def forward_backward_update_loss(z: Tuple, ensemble_model) -> float:
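For intuition: the probing loop above grows the batch size by batch_size // delta per step until it either covers the whole training set or raises a CUDA out-of-memory error, and the outer loop then retries with a larger delta, i.e. a slower growth rate. A minimal, self-contained sketch of that growth schedule follows; the helper name and the numbers are hypothetical, not part of the commit.

def growth_schedule(initial_batch_size: int, num_datapoints: int, delta: int, max_steps: int = 8) -> list:
    """Return the batch sizes that the probing loop would try for a given delta (illustrative only)."""
    sizes, batch_size = [], initial_batch_size
    for _ in range(max_steps):
        sizes.append(batch_size)
        if batch_size >= num_datapoints:
            break
        batch_size += int(batch_size / delta)  # same growth rule as in the diff above
    return sizes

print(growth_schedule(1024, 100_000, delta=1))  # doubles each step: 1024, 2048, 4096, ...
print(growth_schedule(1024, 100_000, delta=4))  # grows by 25% per step: 1024, 1280, 1600, ...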
@@ -115,12 +136,14 @@ def fit(self, *args, **kwargs):
        self.on_fit_start(self, ensemble_model)
        # () Sanity checking
        assert torch.cuda.device_count() == len(ensemble_model)
-        # ()
+        # () Get DataLoader
        train_dataloader = kwargs['train_dataloaders']
-        # ()
+        # () Find a batch size so that available GPU memory is *almost* fully used.
        if self.attributes.auto_batch_finding:
+            batch_size, batch_rt = find_good_batch_size(train_dataloader, ensemble_model)
+
            train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
-                                                           batch_size=find_good_batch_size(train_dataloader, ensemble_model),
+                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           sampler=None,
                                                           batch_sampler=None,
@@ -131,29 +154,38 @@ def fit(self, *args, **kwargs):
                                                           timeout=0,
                                                           worker_init_fn=None,
                                                           persistent_workers=False)
+            if batch_rt is not None:
+                expected_training_time = batch_rt * len(train_dataloader) * self.attributes.num_epochs
+                print(f"Exp.Training Runtime: {expected_training_time / 60:.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t# of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")

+        # () Number of batches to reach a single epoch.
        num_of_batches = len(train_dataloader)
        # () Start training.
        for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
                                                        verbose=True, position=0, leave=True)):
+            # () Accumulate the batch losses.
            epoch_loss = 0
            # () Iterate over batches.
            for i, z in enumerate(train_dataloader):
+                # () Forward, Loss, Backward, and Update on a given batch of data points.
                batch_loss = forward_backward_update_loss(z, ensemble_model)
+                # () Accumulate the batch losses to compute the epoch loss.
                epoch_loss += batch_loss
-
+                # If verbose=True, show info.
                if hasattr(tqdm_bar, 'set_description_str'):
                    tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                    if i > 0:
                        tqdm_bar.set_postfix_str(
                            f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
                    else:
                        tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
+            # Store the epoch loss.
            ensemble_model.loss_history.append(epoch_loss)
-
+        # Run on_fit_end callbacks after the training is done.
        self.on_fit_end(self, ensemble_model)
        # TODO: Later, maybe we should write a callback to save the models in disk
        return ensemble_model
+
    """
    def batchwisefit(self, *args, **kwargs):
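As a quick sanity check of the runtime estimate printed in fit() above, with purely hypothetical numbers (not measurements from this commit):

# Hypothetical values, only to illustrate the expected_training_time formula used in fit().
batch_rt = 0.25        # seconds per batch, as timed during batch-size probing
num_batches = 1_000    # len(train_dataloader)
num_epochs = 200       # self.attributes.num_epochs
expected_training_time = batch_rt * num_batches * num_epochs  # 50,000 seconds
print(f"Exp.Training Runtime: {expected_training_time / 60:.3f} in mins")  # 833.333 in mins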