Skip to content

Commit 587bf59

Browse files
committed
mpi4py block-wise mandelbrot implementation with Gatherv
1 parent 65d0d3b commit 587bf59

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,12 @@
22

33
Materials for teaching a 'Python in HPC' class. Includes python notebooks,
44
scripts, and slides.
5+
6+
<dl>
7+
<dt>profiling_and_optimizing</dt>
8+
<dd>Notebook showing how to profile and improve performance
9+
of a simple-minded Mandelbrot set implementation. Includes
10+
numba, numpy, cython, f2py, numba.vectorize, and multiprocessing.</dd>
11+
<dt>parallel_code_examples</dt>
12+
<dd>Python example code for cross-machine parallelization</dd>
13+
</dl>
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#! /usr/bin/env python
2+
"""
3+
calculates a mandel set using block distribution - i.e.
4+
rank 0 calculates lines [0,n1[
5+
rank 1 [n1,n2[
6+
7+
Run with
8+
mpiexec [mpiexec options] ./mandelbrot_mpi.py
9+
10+
This code was developed with
11+
mpi4py 2.0.0
12+
numba 0.35.0
13+
numpy 1.11.3
14+
python 2.7.13
15+
"""
16+
17+
###
18+
### imports
19+
###
20+
21+
from mpi4py import MPI
22+
import numpy as np
23+
from numba import jit
24+
25+
tic = MPI.Wtime()
26+
27+
###
28+
### globals
29+
###
30+
31+
# area1:
32+
#xmin, xmax = -2.0, 0.5
33+
#ymin, ymax = -1.25, 1.25
34+
# maxiter = 80
35+
36+
# area2:
37+
xmin, xmax = -0.74877, -0.74872
38+
ymin, ymax = 0.06505, 0.06510
39+
width, height = 3000, 3000
40+
maxiter = 2048
41+
42+
dy = (ymax - ymin) / (height - 1)
43+
44+
###
45+
### functions
46+
###
47+
48+
@jit
49+
def mandel(creal, cimag, maxiter):
    """Return the escape iteration count for the point c = creal + i*cimag.

    Starting from z = c, repeatedly applies z -> z*z + c and returns the
    index n (0-based) of the first iteration at which |z|**2 > 4.

    If the orbit has not escaped after ``maxiter`` iterations, ``maxiter``
    itself is returned.  This keeps "never escaped" distinguishable from
    "escaped on the very last iteration" (which returns ``maxiter - 1``) and
    also gives a well-defined result (0) when ``maxiter`` is 0.
    """
    real = creal
    imag = cimag
    for n in range(maxiter):
        # squares are reused for both the escape test and the update
        real2 = real * real
        imag2 = imag * imag
        if real2 + imag2 > 4.0:
            return n
        # imag must be updated first: it needs the *old* value of real
        imag = 2 * real * imag + cimag
        real = real2 - imag2 + creal
    return maxiter
60+
61+
@jit
62+
def mandel_set(xmin, xmax, ymin, ymax, width, height, maxiter):
    """Compute escape counts on a width x height grid of the complex plane.

    The real axis is sampled at ``width`` evenly spaced points in
    [xmin, xmax] and the imaginary axis at ``height`` evenly spaced points
    in [ymin, ymax] (both end points included, via ``np.linspace``).

    Returns an integer array of shape (height, width) where element [y, x]
    is ``mandel(re[x], im[y], maxiter)``.
    """
    re_axis = np.linspace(xmin, xmax, width)
    im_axis = np.linspace(ymin, ymax, height)
    counts = np.empty((height, width), dtype='i')
    # Rows in the outer loop: the array is row-major, so the inner loop
    # then writes to contiguous memory.  Each element is computed
    # independently, so the traversal order does not affect the result.
    for y in range(height):
        cimag = im_axis[y]
        for x in range(width):
            counts[y, x] = mandel(re_axis[x], cimag, maxiter)
    return counts
70+
71+
###
72+
### main
73+
###
74+
75+
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# NOTE: print is always called with a single parenthesised argument so the
# script produces the same output under Python 2 and also runs on Python 3.
print("Rank {:4d}: checking in".format(rank))

# how many rows to compute in this rank?
# for example: height 100
#              size 8
# then height % size = 4
# which means that ranks 0 - 3 each do one
# more row (13) than ranks 4 - 7 (12)
N = height // size + (height % size > rank)
print("Rank {}: will compute {} rows".format(rank, N))
N = np.array(N, dtype='i')  # so we can Gather it later on

# first row to compute here
# scan: the operation returns for each rank i the sum of send buffers of
# ranks [0,i]; subtracting this rank's own N leaves the number of rows
# handled by all lower ranks, i.e. the index of our first row.
# start_y and end_y are the first and last y value to calculate in this block
start_i = comm.scan(N) - N
start_y = ymin + start_i * dy
end_y = ymin + (start_i + N - 1) * dy
print("Rank {:4d}: will compute y = [{}, {}]".format(rank, start_y, end_y))

# calculate the local results
Cl = mandel_set(xmin, xmax, start_y, end_y, width, N, maxiter)
print("Rank {:4d}: finished computing rows; result matrix is shape {}".format(rank, Cl.shape))
print("Rank {:4d}: max value in array: {}".format(rank, Cl.max()))

# gather the number of rows calculated by each rank into 'rowcounts' on rank 0.
# note: this uses the upper-case, buffer-based 'Gather' (not the lower-case
# 'gather' for pickled python objects, which is slower), though in this case
# the data set is tiny and it wouldn't have mattered
rowcounts = 0  # has to be zero, not None b/c of the 'rowcounts * width' bit later on
C = None
if rank == 0:
    rowcounts = np.empty(size, dtype='i')
    C = np.zeros([height, width], dtype='i')

comm.Gather(sendbuf=[N, MPI.INT],
            recvbuf=[rowcounts, MPI.INT],
            root=0)

# gather the global results matrix
# note: Gatherv allows varying amounts of data from each rank. In the underlying
# MPI implementation the receiving buffer has to specify how many elements to
# expect from each rank, and at what position they should be inserted into the
# receiver buffer. I think the 'None' makes mpi4py automatically figure out
# the displacements. There is very little documentation on Gatherv in mpi4py and
# the examples i've found all differ.
comm.Gatherv(sendbuf=[Cl, MPI.INT],
             recvbuf=[C, (rowcounts * width, None), MPI.INT],
             root=0)

toc = MPI.Wtime()

# lower-case gather: collects one python float per rank into a list on rank 0
wct = comm.gather(toc - tic, root=0)
if rank == 0:
    for task, elapsed in enumerate(wct):
        print("Rank {:4d}: ran for {:8.2f}s".format(task, elapsed))
    print("max(runtime) = {:8.2f}s".format(max(wct)))
    print("min(runtime ) = {:8.2f}s".format(min(wct)))
    print("mean(runtime) = {:8.2f}s".format(sum(wct) / len(wct)))
    print("Array size: {} x {}".format(height, width))

# eye candy (requires matplotlib)
if rank == 0 and width * height <= 1e7:
    try:
        from matplotlib import pyplot as plt
        from matplotlib import colors
    except ImportError:
        print('No matplotlib found; skipping plot')
    else:
        norm = colors.PowerNorm(0.3)
        # floor division keeps the integer result the Python 2 '/' produced
        figsz = max(width, height) // 100
        fig = plt.figure(figsize=(figsz, figsz), dpi=100, tight_layout=True)
        ax = fig.add_subplot(111)
        ax.imshow(C, cmap='magma', norm=norm, origin='lower', aspect='equal')
        ax.set_xticks([])
        ax.set_yticks([])
        fig.savefig('mandelbrot.png')
MPI.COMM_WORLD.Barrier()

0 commit comments

Comments
 (0)