Commit fbf30ab

committed
WIP: Reducing the runtime of finding a good batch size & removing redundant log output
1 parent 5551241 commit fbf30ab

File tree

2 files changed: +65, -33 lines changed


dicee/sanity_checkers.py

Lines changed: 10 additions & 10 deletions
@@ -32,11 +32,11 @@ def validate_knowledge_graph(args):
 
     elif args.path_single_kg is not None:
         if args.sparql_endpoint is not None or args.path_single_kg is not None:
-            print(f'The dataset_dir and sparql_endpoint arguments '
-                  f'must be None if path_single_kg is given.'
-                  f'***{args.dataset_dir}***\n'
-                  f'***{args.sparql_endpoint}***\n'
-                  f'These two parameters are set to None.')
+            # print(f'The dataset_dir and sparql_endpoint arguments '
+            #       f'must be None if path_single_kg is given.'
+            #       f'***{args.dataset_dir}***\n'
+            #       f'***{args.sparql_endpoint}***\n'
+            #       f'These two parameters are set to None.')
             args.dataset_dir = None
             args.sparql_endpoint = None
 
@@ -61,11 +61,11 @@ def validate_knowledge_graph(args):
                              f"Use --path_single_kg **folder/dataset.format**, if you have a single file.")
 
         if args.sparql_endpoint is not None or args.path_single_kg is not None:
-            print(f'The sparql_endpoint and path_single_kg arguments '
-                  f'must be None if dataset_dir is given.'
-                  f'***{args.sparql_endpoint}***\n'
-                  f'***{args.path_single_kg}***\n'
-                  f'These two parameters are set to None.')
+            # print(f'The sparql_endpoint and path_single_kg arguments '
+            #       f'must be None if dataset_dir is given.'
+            #       f'***{args.sparql_endpoint}***\n'
+            #       f'***{args.path_single_kg}***\n'
+            #       f'These two parameters are set to None.')
             args.sparql_endpoint = None
             args.path_single_kg = None

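The behavior of both branches is unchanged: when one input source is selected, the other two are overwritten with None; only the explanatory print is commented out. A minimal sketch of that side effect, assuming a plain argparse.Namespace with the three attributes (the helper _drop_conflicting_sources is illustrative, not part of dicee):

from argparse import Namespace

def _drop_conflicting_sources(args: Namespace, keep: str) -> None:
    # Illustrative only: mirrors what each branch in validate_knowledge_graph does,
    # now silently, after the warning prints were commented out.
    for name in ("dataset_dir", "sparql_endpoint", "path_single_kg"):
        if name != keep:
            setattr(args, name, None)

# Example: the path_single_kg branch keeps that path and nulls the other two sources.
args = Namespace(dataset_dir="KGs/UMLS", sparql_endpoint=None, path_single_kg="KGs/UMLS/train.txt")
_drop_conflicting_sources(args, keep="path_single_kg")
print(args)  # dataset_dir and sparql_endpoint are now None
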
dicee/trainer/model_parallelism.py

Lines changed: 55 additions & 23 deletions
@@ -3,6 +3,7 @@
 from ..static_funcs_training import make_iterable_verbose
 from ..models.ensemble import EnsembleKGE
 from typing import Tuple
+import time
 
 def extract_input_outputs(z: list, device=None):
     # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
@@ -27,59 +28,79 @@ def extract_input_outputs(z: list, device=None):
 def find_good_batch_size(train_loader,tp_ensemble_model):
     # () Initial batch size
     initial_batch_size=train_loader.batch_size
-    if initial_batch_size >= len(train_loader.dataset):
-        return initial_batch_size
+    training_dataset_size=len(train_loader.dataset)
+    if initial_batch_size >= training_dataset_size:
+        return training_dataset_size, None
+    print("Number of training data points:",training_dataset_size)
 
-    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None):
+    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
         assert delta is not None, "delta must be positive integer"
         batch_sizes_and_mem_usages = []
         num_datapoints = len(train_loader.dataset)
         try:
             while True:
+                start_time=time.time()
                 # () Initialize a dataloader with a current batch_size
                 train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
                                                                 batch_size=batch_size,
                                                                 shuffle=True,
                                                                 sampler=None,
                                                                 batch_sampler=None,
-                                                                num_workers=0,
+                                                                num_workers=train_loader.num_workers,
                                                                 collate_fn=train_loader.dataset.collate_fn,
                                                                 pin_memory=False,
                                                                 drop_last=False,
                                                                 timeout=0,
                                                                 worker_init_fn=None,
                                                                 persistent_workers=False)
+
                 batch_loss = None
                 for i, batch_of_training_data in enumerate(train_dataloaders):
                     batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
                     break
+
                 global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
                 percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
-
-                print(
-                    f"Random Batch Loss: {batch_loss}\tGPU Usage: {percentage_used_gpu_memory}\tBatch Size:{batch_size}")
+                rt=time.time()-start_time
+                print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")
 
                 global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
                 percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
-
-                batch_sizes_and_mem_usages.append((batch_size, percentage_used_gpu_memory))
+
+                # Store the batch size and the runtime
+                batch_sizes_and_mem_usages.append((batch_size, rt))
+
                 if batch_size < num_datapoints:
+                    # Increase the batch size.
                     batch_size += int(batch_size / delta)
                 else:
-                    if batch_size == num_datapoints:
-                        print("Batch size equals to the training dataset size")
-                    break
+                    return batch_sizes_and_mem_usages,True
+
         except torch.OutOfMemoryError:
             print(f"torch.OutOfMemoryError caught!")
-            return batch_sizes_and_mem_usages
+            return batch_sizes_and_mem_usages, False
 
     history_batch_sizes_and_mem_usages=[]
     batch_size=initial_batch_size
-    for delta in range(1,10,1):
-        history_batch_sizes_and_mem_usages.extend(increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta))
-        batch_size=history_batch_sizes_and_mem_usages[-2][0]
-        print(f"A best found batch size:{batch_size} in {len(history_batch_sizes_and_mem_usages)} trials. Current GPU memory usage % :{history_batch_sizes_and_mem_usages[-2][1]}")
-    return batch_size
+
+    for delta in range(1,5,1):
+        result,flag= increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta)
+
+        history_batch_sizes_and_mem_usages.extend(result)
+
+        if flag:
+            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
+        else:
+            # CUDA ERROR Observed
+            batch_size, batch_rt=history_batch_sizes_and_mem_usages[-2]
+
+        if batch_size>=training_dataset_size:
+            batch_size=training_dataset_size
+            break
+        else:
+            continue
+
+    return batch_size, batch_rt
 
 
 def forward_backward_update_loss(z:Tuple, ensemble_model)->float:
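Read as a whole, the revised search records per-trial runtime instead of the GPU-memory share: each trial runs one forward/backward pass on a random batch, stores (batch_size, runtime), and grows the batch size by batch_size/delta until either the full dataset fits into one batch (flag True) or CUDA raises an out-of-memory error (flag False), in which case the caller falls back to the second-to-last recorded batch size. A standalone sketch of the same growth loop, with a dummy step function standing in for forward_backward_update_loss (all names here are illustrative, not dicee's API):

import time

def grow_batch_size(step, dataset_size, batch_size, delta):
    # Illustrative re-statement of the loop inside find_good_batch_size.
    trials = []  # (batch_size, runtime) pairs, like batch_sizes_and_mem_usages
    try:
        while True:
            start = time.time()
            step(batch_size)                       # one forward/backward pass on a random batch
            trials.append((batch_size, time.time() - start))
            if batch_size >= dataset_size:
                return trials, True                # reached full-batch training without OOM
            batch_size += int(batch_size / delta)  # grow by roughly 1/delta per trial
    except RuntimeError:                           # stand-in for torch.OutOfMemoryError
        return trials, False

# Example: pretend the GPU runs out of memory above 6000 data points per batch.
def fake_step(bs):
    if bs > 6000:
        raise RuntimeError("CUDA out of memory (simulated)")

trials, ok = grow_batch_size(fake_step, dataset_size=100_000, batch_size=1024, delta=1)
batch_size, batch_rt = trials[-1] if ok else trials[-2]
print(batch_size, ok)  # 2048 False: the second-to-last trial is kept as a safety margin

The runtime recorded per trial is what fit() later uses to print the expected training time.
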
@@ -115,12 +136,14 @@ def fit(self, *args, **kwargs):
         self.on_fit_start(self, ensemble_model)
         # () Sanity checking
         assert torch.cuda.device_count()== len(ensemble_model)
-        # ()
+        # () Get DataLoader
         train_dataloader = kwargs['train_dataloaders']
-        # ()
+        # () Find a batch size so that available GPU memory is *almost* fully used.
         if self.attributes.auto_batch_finding:
+            batch_size, batch_rt=find_good_batch_size(train_dataloader, ensemble_model)
+
             train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
-                                                           batch_size=find_good_batch_size(train_dataloader, ensemble_model),
+                                                           batch_size=batch_size,
                                                            shuffle=True,
                                                            sampler=None,
                                                            batch_sampler=None,
@@ -131,29 +154,38 @@ def fit(self, *args, **kwargs):
                                                            timeout=0,
                                                            worker_init_fn=None,
                                                            persistent_workers=False)
+            if batch_rt is not None:
+                expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs
+                print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")
 
+        # () Number of batches to reach a single epoch.
         num_of_batches = len(train_dataloader)
         # () Start training.
         for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
                                                         verbose=True, position=0, leave=True)):
+            # () Accumulate the batch losses.
             epoch_loss = 0
             # () Iterate over batches.
             for i, z in enumerate(train_dataloader):
+                # () Forward, Loss, Backward, and Update on a given batch of data points.
                 batch_loss = forward_backward_update_loss(z,ensemble_model)
+                # () Accumulate the batch losses to compute the epoch loss.
                 epoch_loss += batch_loss
-
+                # If verbose=True, show info.
                 if hasattr(tqdm_bar, 'set_description_str'):
                     tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                     if i > 0:
                         tqdm_bar.set_postfix_str(
                             f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
                     else:
                         tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
+            # Store the epoch loss
             ensemble_model.loss_history.append(epoch_loss)
-
+        # Run on_fit_end callbacks after the training is done.
         self.on_fit_end(self, ensemble_model)
         # TODO: Later, maybe we should write a callback to save the models in disk
         return ensemble_model
+
     """
 
     def batchwisefit(self, *args, **kwargs):
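
The new report in fit() is a straight-line extrapolation from the single measured batch: expected_training_time = batch_rt * len(train_dataloader) * num_epochs. A quick illustration with made-up numbers (not taken from any actual run):

# Hypothetical values, only to illustrate the printed estimate.
batch_rt = 0.05          # seconds for one forward/backward pass at the chosen batch size
num_of_batches = 1200    # len(train_dataloader) after rebuilding it with that batch size
num_epochs = 100

expected_training_time = batch_rt * num_of_batches * num_epochs   # 6000.0 seconds
print(f"Exp.Training Runtime: {expected_training_time / 60:.3f} in mins")   # 100.000 in mins

Since batch_rt comes from a single warm-up batch, the estimate ignores data-loading stalls and per-epoch overhead, so it should be read as a rough ballpark figure rather than a guarantee.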
