-
Notifications
You must be signed in to change notification settings - Fork 6
/
pytorch-multi-gpu-multi-process-testing.py
113 lines (91 loc) · 3.55 KB
/
pytorch-multi-gpu-multi-process-testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# --------------------------------------------------------
# Pytorch-Multi-GPU-Testing
# Written by Jingyun Liang
# --------------------------------------------------------
import os
import random
import time
import warnings
import torch
import multiprocessing
# global variables
# total number of physical GPUs on the machine; used_gpu_list holds one
# process counter per GPU
total_gpu_num = 8
# max_process_per_gpu=1 always works fine; max_process_per_gpu>=2 may get stuck on some computers or clusters.
max_process_per_gpu = 1
# shared per-GPU process counters, visible to every pool worker via a Manager proxy.
# NOTE(review): created at import time — this relies on the workers inheriting the
# proxy (fork start method); confirm behaviour on spawn-based platforms
# (Windows/macOS) where the module is re-imported in each child.
used_gpu_list = multiprocessing.Manager().list([0] * total_gpu_num)
# guards the read-modify-write of used_gpu_list across processes
lock = multiprocessing.Lock()
class CNN(torch.nn.Module):
    """A minimal one-layer CNN used as a stand-in model for the GPU tests."""

    def __init__(self):
        super().__init__()
        # single 1x1 convolution: 1 input channel -> 1 output channel
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        """Run the convolution over ``x`` and return the result."""
        return self.conv(x)
def multi_gpu_testing_wrapper(model, input, index, gpu_id=None, available_gpu_num=1):
    '''
    Multi-GPU testing wrapper: picks the least-loaded available GPU, runs the
    model on it, and returns the result to the CPU.

    :param model: The PyTorch model (on cpu, NOT on cuda).
    :param input: The model input, e.g. an image (on cpu, NOT on cuda).
    :param index: Sample index (int).
    :param gpu_id: Given gpu_id. Only used in debugging.
    :param available_gpu_num: Available GPU number. Default: 1.
    :return: Model output (on cpu, NOT on cuda), used GPU id and process id.
    '''
    # GPU assignment. BUGFIX: restrict the min() to the first
    # available_gpu_num counters. The original compared against
    # min(used_gpu_list) over all total_gpu_num slots; the unused trailing
    # slots stay at 0, so once every available GPU was busy
    # (max_process_per_gpu >= 2) no slot matched, gpu_id stayed None and
    # used_gpu_list[None] raised TypeError.
    with lock:  # `with` releases the lock even if something below raises
        if gpu_id is None:
            usage = list(used_gpu_list)[:available_gpu_num]
            least_used = min(usage)
            for i, count in enumerate(usage):
                if count < max_process_per_gpu and count == least_used:
                    gpu_id = i
                    break
        used_gpu_list[gpu_id] += 1
    torch.cuda.set_device(gpu_id)
    device = torch.device('cuda')
    print(f'testing input {index} on GPU {gpu_id}. Overall GPU usages: ', list(used_gpu_list))

    # model testing
    input = input.to(device)
    model = model.to(device)
    time.sleep(random.randrange(0, 10))  # used in this toy example to avoid deadlock
    output = model(input).detach().cpu()

    # release GPU memory manually (multiprocessing.Pool may not release GPU memory of a process in time)
    del input, model
    torch.cuda.empty_cache()

    # release the GPU slot; the print stays outside the lock to keep the
    # critical section minimal
    with lock:
        used_gpu_list[gpu_id] -= 1
    print(f'releasing input {index} on GPU {gpu_id}. Overall GPU usages: ', list(used_gpu_list))

    # return output
    return output, gpu_id, os.getpid()
def main():
    '''
    Drive the toy CNN over several inputs, distributing the work across
    multiple GPUs with a multiprocessing pool.
    '''
    # setup GPU: restrict torch to the first 4 physical devices. Setting
    # CUDA_VISIBLE_DEVICES here still works because CUDA has not been
    # initialised yet in this process.
    available_gpu_num = 4
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    assert available_gpu_num <= total_gpu_num
    if max_process_per_gpu > 1:
        warnings.warn("max_process_per_gpu>=2 may get stuck on some computers or clusters.")

    # initialize input and model
    total_input_num = 10
    input = [torch.ones(1, 1, 2, 2)] * total_input_num
    model = CNN()
    output = []

    def mycallback(arg):
        # arg is (output_tensor, gpu_id, pid); keep only the tensor
        output.append(arg[0])

    def myerror_callback(exc):
        # BUGFIX: without an error_callback, apply_async swallows worker
        # exceptions silently — a failed sample just never reaches `output`.
        print(f'a worker process failed: {exc!r}')

    # test the model on multiple GPUs distributedly
    pool = multiprocessing.Pool(available_gpu_num * max_process_per_gpu)
    for i in range(total_input_num):
        # hint: pool.apply_async cannot output informative debugging logs. Use pool.apply() for debugging.
        pool.apply_async(multi_gpu_testing_wrapper, args=(model, input[i], i, None, available_gpu_num),
                         callback=mycallback, error_callback=myerror_callback)
        # pool.apply(multi_gpu_testing_wrapper, args=(model, input[i], i, 0, available_gpu_num))
    pool.close()
    pool.join()
    print('All subprocesses done.')

    # check output quality (sometimes a process may fail due to out-of-memory error;
    # the error_callback above now reports it instead of failing silently)
    print(f'\n{len(output)}/{len(input)} processes succeeded.')
    assert len(input) == len(output)
# script entry point
if __name__ == '__main__':
    main()