Improved batch finding in TP
Demirrr committed Nov 28, 2024
1 parent d3081a1 commit 5551241
Showing 1 changed file with 64 additions and 51 deletions.
115 changes: 64 additions & 51 deletions dicee/trainer/model_parallelism.py
@@ -23,66 +23,79 @@ def extract_input_outputs(z: list, device=None):
else:
raise ValueError('Unexpected batch shape..')

def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.20):

def find_good_batch_size(train_loader,tp_ensemble_model):
# () Initial batch size
batch_size=train_loader.batch_size
if batch_size >= len(train_loader.dataset):
return batch_size
first_batch_size = train_loader.batch_size
num_datapoints=len(train_loader.dataset)
print(f"Increment the batch size by {first_batch_size} until the Free/Total GPU memory is reached to {1-max_available_gpu_memory} or batch_size={num_datapoints} is achieved.")
while True:
# () Initialize a dataloader with a current batch_size
train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=train_loader.dataset.collate_fn,
pin_memory=False, drop_last=False,
timeout=0,
worker_init_fn=None,
persistent_workers=False)
loss=None
avg_global_free_memory=[]
for i, z in enumerate(train_dataloaders):
loss = forward_backward_update_loss(z,ensemble_model)
global_free_memory, total_memory = torch.cuda.mem_get_info()
break

avg_global_free_memory= global_free_memory / total_memory

print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
if avg_global_free_memory > max_available_gpu_memory and batch_size < num_datapoints :
if batch_size <= num_datapoints:
batch_size+=batch_size
else:
batch_size=num_datapoints
else:
if batch_size == num_datapoints:
print("Batch size equals to the training dataset size")
else:
print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")
return batch_size

def forward_backward_update_loss(z:Tuple, ensemble_model):
# () Get the i-th batch of data points.
initial_batch_size=train_loader.batch_size
if initial_batch_size >= len(train_loader.dataset):
return initial_batch_size

def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None):
assert delta is not None, "delta must be a positive integer"
batch_sizes_and_mem_usages = []
num_datapoints = len(train_loader.dataset)
try:
while True:
# () Initialize a dataloader with a current batch_size
train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=train_loader.dataset.collate_fn,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
persistent_workers=False)
batch_loss = None
for i, batch_of_training_data in enumerate(train_dataloaders):
batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
break
global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory

print(f"Random Batch Loss: {batch_loss}\tGPU Usage: {percentage_used_gpu_memory}\tBatch Size: {batch_size}")

batch_sizes_and_mem_usages.append((batch_size, percentage_used_gpu_memory))
if batch_size < num_datapoints:
batch_size += int(batch_size / delta)
else:
if batch_size == num_datapoints:
print("Batch size equals to the training dataset size")
break
except torch.OutOfMemoryError:
print(f"torch.OutOfMemoryError caught!")
return batch_sizes_and_mem_usages

history_batch_sizes_and_mem_usages=[]
batch_size=initial_batch_size
for delta in range(1,10,1):
history_batch_sizes_and_mem_usages.extend(increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta))
batch_size=history_batch_sizes_and_mem_usages[-2][0]
print(f"A best found batch size:{batch_size} in {len(history_batch_sizes_and_mem_usages)} trials. Current GPU memory usage % :{history_batch_sizes_and_mem_usages[-2][1]}")
return batch_size


def forward_backward_update_loss(z:Tuple, ensemble_model)->float:
# () Get a random batch of data points (z).
x_batch, y_batch = extract_input_outputs(z)
# () Move the batch of labels into the master GPU : GPU-0
# () Move the batch of labels into the master GPU : GPU-0.
y_batch = y_batch.to("cuda:0")
# () Forward Pass on the batch. Yhat located on the master GPU.
# () Forward pass on the batch of input data points (yhat on the master GPU).
yhat = ensemble_model(x_batch)
# () Compute the loss
# () Compute the loss.
loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
# () Compute the gradient of the loss w.r.t. parameters.
loss.backward()
# () Parameter update.
ensemble_model.step()
# () Report the batch and epoch losses.
batch_loss = loss.item()
# () Accumulate batch loss
return batch_loss
return loss.item()

class TensorParallel(AbstractTrainer):
def __init__(self, args, callbacks):
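For reference, the search schedule introduced in this commit can be sketched independently of the repository. The snippet below is a minimal, stand-alone approximation, not code from the commit: simulate_gpu_usage, NUM_DATAPOINTS, and MAX_MEMORY_FRACTION are hypothetical stand-ins for torch.cuda.mem_get_info / torch.OutOfMemoryError, the dataset size, and the memory budget, so the growth rule (batch_size += int(batch_size / delta) for delta = 1..9, backing off to the second-to-last recorded size) can be inspected without a GPU or an ensemble model.

# Illustrative sketch only: a fake memory probe stands in for torch.cuda.mem_get_info()
# so the batch-size search schedule can be run on CPU.

NUM_DATAPOINTS = 100_000        # hypothetical dataset size
MAX_MEMORY_FRACTION = 0.9       # treat ~90% simulated usage as "out of memory"

def simulate_gpu_usage(batch_size: int) -> float:
    """Pretend GPU memory usage grows linearly with the batch size."""
    return min(1.0, 0.05 + batch_size / 50_000)

def search_batch_size(initial_batch_size: int) -> int:
    history = []                                # (batch_size, usage) pairs, as in the commit
    batch_size = initial_batch_size
    for delta in range(1, 10):                  # coarse doubling first (delta=1), then finer steps
        while batch_size < NUM_DATAPOINTS:
            usage = simulate_gpu_usage(batch_size)
            history.append((batch_size, usage))
            if usage >= MAX_MEMORY_FRACTION:    # stands in for catching torch.OutOfMemoryError
                break
            batch_size += int(batch_size / delta)
        if len(history) >= 2:
            batch_size = history[-2][0]         # back off to the second-to-last recorded size
    return batch_size

if __name__ == "__main__":
    print(search_batch_size(1024))              # converges to a size just under the budget

In the commit itself the probe is a real forward/backward pass on cuda:0 via forward_backward_update_loss, and the stopping condition is an actual torch.OutOfMemoryError rather than a usage threshold.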
