Commit 58aa98c

Merge pull request #280 from dice-group/tensor_parallel
Tensor parallel
2 parents 94ab305 + 44b9dbd commit 58aa98c

File tree

3 files changed (+112, -68 lines)


dicee/models/ensemble.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ def __init__(self, seed_model):
         for i in range(torch.cuda.device_count()):
             i_model=copy.deepcopy(seed_model)
             # TODO: Why we cant send the compile model to cpu ?
-            i_model = torch.compile(i_model)
+            #i_model = torch.compile(i_model)
             i_model.to(torch.device(f"cuda:{i}"))
             self.optimizers.append(i_model.configure_optimizers())
             self.models.append(i_model)
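
Note: the only functional change in this file disables torch.compile before the seed model is replicated across GPUs (see the TODO above about moving a compiled model back to CPU). For orientation, here is a minimal sketch of the replication pattern the constructor implements; replicate_to_gpus and the Adam optimizer are illustrative stand-ins rather than repository code (the actual constructor calls i_model.configure_optimizers()):

import copy
import torch

def replicate_to_gpus(seed_model):
    # Deep-copy the seed model onto every visible GPU and keep one optimizer
    # per replica, mirroring the EnsembleKGE constructor shown in the diff.
    # torch.compile is skipped, as in the commit.
    models, optimizers = [], []
    for i in range(torch.cuda.device_count()):
        replica = copy.deepcopy(seed_model)
        replica.to(torch.device(f"cuda:{i}"))
        optimizers.append(torch.optim.Adam(replica.parameters()))  # illustrative optimizer choice
        models.append(replica)
    return models, optimizers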

dicee/sanity_checkers.py

Lines changed: 10 additions & 10 deletions
@@ -32,11 +32,11 @@ def validate_knowledge_graph(args):
 
     elif args.path_single_kg is not None:
         if args.sparql_endpoint is not None or args.path_single_kg is not None:
-            print(f'The dataset_dir and sparql_endpoint arguments '
-                  f'must be None if path_single_kg is given.'
-                  f'***{args.dataset_dir}***\n'
-                  f'***{args.sparql_endpoint}***\n'
-                  f'These two parameters are set to None.')
+            #print(f'The dataset_dir and sparql_endpoint arguments '
+            #      f'must be None if path_single_kg is given.'
+            #      f'***{args.dataset_dir}***\n'
+            #      f'***{args.sparql_endpoint}***\n'
+            #      f'These two parameters are set to None.')
             args.dataset_dir = None
             args.sparql_endpoint = None
 
@@ -61,11 +61,11 @@ def validate_knowledge_graph(args):
                                f"Use --path_single_kg **folder/dataset.format**, if you have a single file.")
 
         if args.sparql_endpoint is not None or args.path_single_kg is not None:
-            print(f'The sparql_endpoint and path_single_kg arguments '
-                  f'must be None if dataset_dir is given.'
-                  f'***{args.sparql_endpoint}***\n'
-                  f'***{args.path_single_kg}***\n'
-                  f'These two parameters are set to None.')
+            #print(f'The sparql_endpoint and path_single_kg arguments '
+            #      f'must be None if dataset_dir is given.'
+            #      f'***{args.sparql_endpoint}***\n'
+            #      f'***{args.path_single_kg}***\n'
+            #      f'These two parameters are set to None.')
             args.sparql_endpoint = None
             args.path_single_kg = None
 
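
Note: both hunks only silence the informational prints; the precedence handling itself is unchanged. A hedged illustration of the resulting behaviour, assuming the path_single_kg branch from the first hunk is taken (the SimpleNamespace and file paths below are for demonstration only, not from the repository):

from types import SimpleNamespace

# Demonstration values only: when path_single_kg is provided, the other two
# data-source arguments are now reset to None without any console warning.
args = SimpleNamespace(path_single_kg="KGs/UMLS/train.txt",
                       dataset_dir="KGs/UMLS",
                       sparql_endpoint=None)
# After validate_knowledge_graph(args):
#   args.dataset_dir     -> None
#   args.sparql_endpoint -> None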

dicee/trainer/model_parallelism.py

Lines changed: 101 additions & 57 deletions
@@ -3,6 +3,7 @@
 from ..static_funcs_training import make_iterable_verbose
 from ..models.ensemble import EnsembleKGE
 from typing import Tuple
+import time
 
 def extract_input_outputs(z: list, device=None):
     # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
@@ -23,67 +24,99 @@ def extract_input_outputs(z: list, device=None):
     else:
         raise ValueError('Unexpected batch shape..')
 
-def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1):
+
+def find_good_batch_size(train_loader,tp_ensemble_model):
     # () Initial batch size
-    batch_size=train_loader.batch_size
-    if batch_size >= len(train_loader.dataset):
-        return batch_size
-    first_batch_size = train_loader.batch_size
-    num_datapoints=len(train_loader.dataset)
-    print(f"Increment the batch size by {first_batch_size} until the Free/Total GPU memory is reached to {1-max_available_gpu_memory} or batch_size={num_datapoints} is achieved.")
-    while True:
-        # () Initialize a dataloader with a current batch_size
-        train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
-                                                        batch_size=batch_size,
-                                                        shuffle=True,
-                                                        sampler=None,
-                                                        batch_sampler=None,
-                                                        num_workers=0,
-                                                        collate_fn=train_loader.dataset.collate_fn,
-                                                        pin_memory=False, drop_last=False,
-                                                        timeout=0,
-                                                        worker_init_fn=None,
-                                                        persistent_workers=False)
-        loss=None
-        avg_global_free_memory=[]
-        for i, z in enumerate(train_dataloaders):
-            loss = forward_backward_update_loss(z,ensemble_model)
-            global_free_memory, total_memory = torch.cuda.mem_get_info()
-            avg_global_free_memory.append(global_free_memory / total_memory)
-            if i==3:
-                break
-        avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory)
-        print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
-        if avg_global_free_memory > max_available_gpu_memory and batch_size < num_datapoints :
-            if batch_size+first_batch_size <= num_datapoints:
-                batch_size+=first_batch_size
-            else:
-                batch_size=num_datapoints
+    initial_batch_size=train_loader.batch_size
+    training_dataset_size=len(train_loader.dataset)
+    if initial_batch_size >= training_dataset_size:
+        return training_dataset_size, None
+    print("Number of training data points:",training_dataset_size)
+
+    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
+        assert delta is not None, "delta must be positive integer"
+        batch_sizes_and_mem_usages = []
+        num_datapoints = len(train_loader.dataset)
+        try:
+            while True:
+                start_time=time.time()
+                # () Initialize a dataloader with a current batch_size
+                train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
+                                                                batch_size=batch_size,
+                                                                shuffle=True,
+                                                                sampler=None,
+                                                                batch_sampler=None,
+                                                                num_workers=train_loader.num_workers,
+                                                                collate_fn=train_loader.dataset.collate_fn,
+                                                                pin_memory=False,
+                                                                drop_last=False,
+                                                                timeout=0,
+                                                                worker_init_fn=None,
+                                                                persistent_workers=False)
+
+                batch_loss = None
+                for i, batch_of_training_data in enumerate(train_dataloaders):
+                    batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
+                    break
+
+                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
+                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
+                rt=time.time()-start_time
+                print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")
+
+                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
+                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
+
+                # Store the batch size and the runtime
+                batch_sizes_and_mem_usages.append((batch_size, rt))
+
+                if batch_size < num_datapoints:
+                    # Increase the batch size.
+                    batch_size += int(batch_size / delta)
+                else:
+                    return batch_sizes_and_mem_usages,True
+
+        except torch.OutOfMemoryError:
+            print("torch.OutOfMemoryError caught!")
+            return batch_sizes_and_mem_usages, False
+
+    history_batch_sizes_and_mem_usages=[]
+    batch_size=initial_batch_size
+
+    for delta in range(1,5,1):
+        result,flag= increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta)
+
+        history_batch_sizes_and_mem_usages.extend(result)
+
+        if flag:
+            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
+        else:
+            # CUDA ERROR Observed
+            batch_size, batch_rt=history_batch_sizes_and_mem_usages[-2]
+
+        if batch_size>=training_dataset_size:
+            batch_size=training_dataset_size
+            break
         else:
-            assert batch_size<=num_datapoints
-            if batch_size == num_datapoints:
-                print("Batch size equals to the training dataset size")
-            else:
-                print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")
-            return batch_size
-
-def forward_backward_update_loss(z:Tuple, ensemble_model):
-    # () Get the i-th batch of data points.
+            continue
+
+    return batch_size, batch_rt
+
+
+def forward_backward_update_loss(z:Tuple, ensemble_model)->float:
+    # () Get a random batch of data points (z).
     x_batch, y_batch = extract_input_outputs(z)
-    # () Move the batch of labels into the master GPU : GPU-0
+    # () Move the batch of labels into the master GPU : GPU-0.
     y_batch = y_batch.to("cuda:0")
-    # () Forward Pass on the batch. Yhat located on the master GPU.
+    # () Forward pas on the batch of input data points (yhat on the master GPU).
     yhat = ensemble_model(x_batch)
-    # () Compute the loss
+    # () Compute the loss.
     loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
     # () Compute the gradient of the loss w.r.t. parameters.
     loss.backward()
     # () Parameter update.
     ensemble_model.step()
-    # () Report the batch and epoch losses.
-    batch_loss = loss.item()
-    # () Accumulate batch loss
-    return batch_loss
+    return loss.item()
 
 class TensorParallel(AbstractTrainer):
     def __init__(self, args, callbacks):
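
The rewritten find_good_batch_size above replaces the old fixed-increment search with a grow-until-out-of-memory strategy: run a single batch, record its runtime, enlarge the batch size by roughly batch_size/delta, and stop at the dataset size or at the first CUDA out-of-memory error. A simplified sketch of that strategy, assuming a user-supplied run_one_batch callback in place of forward_backward_update_loss (all names below are illustrative):

import torch

def grow_until_oom(run_one_batch, initial_batch_size, num_datapoints, delta=4):
    # Grow the batch size until CUDA runs out of memory or the whole dataset
    # fits into a single batch, then fall back to a size known to have worked.
    completed = []  # (batch_size, runtime) pairs that ran without OOM
    batch_size = initial_batch_size
    try:
        while True:
            runtime = run_one_batch(batch_size)  # one forward/backward/step
            completed.append((batch_size, runtime))
            if batch_size >= num_datapoints:
                return num_datapoints, runtime
            batch_size += max(1, batch_size // delta)
    except torch.OutOfMemoryError:
        # The commit steps back to the second-to-last recorded size as an extra
        # safety margin; this sketch simply returns the last size that completed.
        return completed[-1] if completed else (initial_batch_size, None)
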
@@ -103,12 +136,14 @@ def fit(self, *args, **kwargs):
         self.on_fit_start(self, ensemble_model)
         # () Sanity checking
         assert torch.cuda.device_count()== len(ensemble_model)
-        # ()
+        # () Get DataLoader
         train_dataloader = kwargs['train_dataloaders']
-        # ()
+        # () Find a batch size so that available GPU memory is *almost* fully used.
         if self.attributes.auto_batch_finding:
+            batch_size, batch_rt=find_good_batch_size(train_dataloader, ensemble_model)
+
             train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
-                                                           batch_size=find_good_batch_size(train_dataloader, ensemble_model),
+                                                           batch_size=batch_size,
                                                            shuffle=True,
                                                            sampler=None,
                                                            batch_sampler=None,
@@ -119,29 +154,38 @@ def fit(self, *args, **kwargs):
                                                            timeout=0,
                                                            worker_init_fn=None,
                                                            persistent_workers=False)
+            if batch_rt is not None:
+                expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs
+                print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")
 
+        # () Number of batches to reach a single epoch.
         num_of_batches = len(train_dataloader)
         # () Start training.
         for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
                                                         verbose=True, position=0, leave=True)):
+            # () Accumulate the batch losses.
             epoch_loss = 0
             # () Iterate over batches.
             for i, z in enumerate(train_dataloader):
+                # () Forward, Loss, Backward, and Update on a given batch of data points.
                 batch_loss = forward_backward_update_loss(z,ensemble_model)
+                # () Accumulate the batch losses to compute the epoch loss.
                 epoch_loss += batch_loss
-
+                # if verbose=TRue, show info.
                 if hasattr(tqdm_bar, 'set_description_str'):
                     tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                     if i > 0:
                         tqdm_bar.set_postfix_str(
                             f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
                     else:
                         tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
+            # Store the epoch loss
             ensemble_model.loss_history.append(epoch_loss)
-
+        # Run on_fit_end callbacks after the training is done.
         self.on_fit_end(self, ensemble_model)
         # TODO: Later, maybe we should write a callback to save the models in disk
         return ensemble_model
+
     """
 
     def batchwisefit(self, *args, **kwargs):
@@ -242,4 +286,4 @@ def torch_buggy_fit(self, *args, **kwargs):
             torch.distributed.destroy_process_group()
         # () .
         self.on_fit_end(self, model)
-        """
+    """
