Skip to content

Commit c705036

Browse files
committed
begin developing a pure-MPI implementation
1 parent 086a5f1 commit c705036

File tree

2 files changed

+243
-184
lines changed

2 files changed

+243
-184
lines changed

bin/mpi-fastspecfit

Lines changed: 94 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9,92 +9,115 @@ import numpy as np
99
from fastspecfit.logger import log
1010

1111

12-
def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None,
13-
makeqa=False, samplefile=None, input_redshifts=False,
14-
outdir_data='.', templates=None, templateversion=None,
15-
fphotodir=None, fphotofile=None):
12+
def get_size(comm, mp=1):
    """Return the number of subcommunicators for an `mp`-way split of `comm`.

    The result is ``ceil(comm.size / mp)``: the number of groups obtained
    when ranks are assigned the color ``rank // mp``. It is therefore also
    the number of unique "colors" and the number of ranks that are rank 0
    within their own subcommunicator.

    Parameters
    ----------
    comm : object
        Communicator-like object exposing an integer ``size`` attribute.
    mp : int, optional
        Number of ranks per subcommunicator. Must be positive. Defaults to 1.

    Returns
    -------
    int
        Number of subcommunicators.

    Raises
    ------
    ValueError
        If `mp` is not positive.

    """
    if mp <= 0:
        raise ValueError(f'mp must be a positive number of ranks; got {mp}')

    # Ceiling division in exact integer arithmetic (no float round-trip):
    # -(-a // b) == ceil(a / b) for b > 0.
    return int(-(-comm.size // mp))
17+
18+
19+
def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=False,
20+
sample=None, input_redshifts=False, outdir_data='.', templates=None,
21+
templateversion=None, fphotodir=None, fphotofile=None):
1622

1723
import sys
1824
from desispec.parallel import stdouterr_redirected
1925
from fastspecfit.mpi import plan
2026

21-
if comm is None:
22-
rank, size = 0, 1
27+
if comm:
28+
rank = comm.rank
29+
#size = comm.size
30+
size = get_size(comm, mp=args.mp)
31+
32+
# Split the MPI.COMM_WORLD communicator into ceil(comm.size / args.mp)
33+
# subcommunicators so we can parallelize over objects in
34+
# fastspecfit.fastspec (or fastspecfit.fastphot).
35+
colors = np.arange(size) // args.mp
36+
subcomm = comm.Split(color=rank // args.mp, key=rank)
2337
else:
24-
rank, size = comm.rank, comm.size
38+
rank = 0
39+
size = 1
40+
colors = [0]
41+
subcomm = None
42+
print(comm.rank, comm.size, subcomm.rank, subcomm.size, size)
2543

2644
t0 = time.time()
2745
if rank == 0:
28-
if args.samplefile is not None:
29-
import fitsio
30-
from astropy.table import Table
31-
if not os.path.isfile(args.samplefile):
32-
log.warning(f'{args.samplefile} does not exist.')
33-
return
34-
try:
35-
readcols = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']
36-
if input_redshifts:
37-
readcols += ['Z']
38-
sample = Table(fitsio.read(args.samplefile, columns=readcols))
39-
except:
40-
if input_redshifts:
41-
errmsg = f'Sample file {args.samplefile} with --input-redshifts set missing required columns {SURVEY,PROGRAM,HEALPIX,TARGETID,Z}'
42-
else:
43-
errmsg = f'Sample file {args.samplefile} missing required columns {SURVEY,PROGRAM,HEALPIX,TARGETID}'
44-
log.critical(errmsg)
45-
raise ValueError(errmsg)
46-
47-
_, zbestfiles, outfiles, groups, ntargets = plan(
46+
if sample is not None:
47+
_, redrockfiles, outfiles, groups, ntargets = plan(
4848
size=size, specprod=args.specprod, specprod_dir=specprod_dir,
4949
sample=sample, coadd_type='healpix', makeqa=args.makeqa,
50-
mp=args.mp, fastphot=args.fastphot,
51-
outdir_data=outdir_data, overwrite=args.overwrite)
50+
mp=args.mp, fastphot=args.fastphot, outdir_data=outdir_data,
51+
overwrite=args.overwrite)
5252
else:
53-
sample = None
54-
_, zbestfiles, outfiles, groups, ntargets = plan(
53+
_, redrockfiles, outfiles, groups, ntargets = plan(
5554
size=size, specprod=args.specprod, specprod_dir=specprod_dir,
5655
coadd_type=args.coadd_type, survey=args.survey, program=args.program,
5756
healpix=args.healpix, tile=args.tile, night=args.night,
5857
makeqa=args.makeqa, mp=args.mp, fastphot=fastphot, outdir_data=outdir_data,
5958
overwrite=args.overwrite)
60-
log.info('Planning took {:.2f} sec'.format(time.time() - t0))
59+
log.info(f'Planning took {time.time() - t0:.2f} sec')
6160
else:
62-
sample = None
63-
zbestfiles, outfiles, groups, ntargets = [], [], [], []
64-
65-
if comm:
66-
zbestfiles = comm.bcast(zbestfiles, root=0)
67-
outfiles = comm.bcast(outfiles, root=0)
68-
groups = comm.bcast(groups, root=0)
69-
ntargets = comm.bcast(ntargets, root=0)
70-
sample = comm.bcast(sample, root=0)
61+
redrockfiles, outfiles, groups, ntargets = [], [], [], []
62+
63+
#if comm:
64+
# groups = comm.bcast(groups, root=0)
65+
# redrockfiles = comm.bcast(redrockfiles, root=0)
66+
# outfiles = comm.bcast(outfiles, root=0)
67+
# ntargets = comm.bcast(ntargets, root=0)
68+
# sample = comm.bcast(sample, root=0)
69+
70+
print('Size! ', size)
71+
72+
# Make sure all the ranks in subcomm have the same work.
73+
if subcomm:
74+
if subcomm.rank == 0:
75+
for subrank in range(subcomm.size):
76+
subcomm.send(groups[rank], dest=subrank)
77+
subcomm.send(redrockfiles[groups[rank]], dest=subrank)
78+
subcomm.send(outfiles[groups[rank]], dest=subrank)
79+
subcomm.send(ntargets[groups[rank]], dest=subrank)
80+
else:
81+
groups[rank] = subcomm.recv(source=0)
82+
redrockfiles[groups[rank]] = subcomm.recv(source=0)
83+
outfiles[groups[rank]] = subcomm.recv(source=0)
84+
ntargets[groups[rank]] = subcomm.recv(source=0)
7185

7286
sys.stdout.flush()
7387

7488
# all done
75-
if len(zbestfiles) == 0:
89+
if len(redrockfiles) == 0:
7690
return
7791

78-
assert(len(groups) == size)
79-
assert(len(np.concatenate(groups)) == len(zbestfiles))
92+
#assert(len(groups) == size)
93+
#assert(len(np.concatenate(groups)) == len(redrockfiles))
94+
95+
"""
96+
16 redrockfiles
97+
size = 8
98+
mp = 4
99+
colors = np.arange(size) // mp --> [0, 0, 0, 0, 1, 1, 1, 1]
100+
nsubcomm = int(np.ceil(size / mp)) --> 2
101+
"""
102+
print(groups, rank, groups[rank])
80103

81104
for ii in groups[rank]:
82105
log.debug(f'Rank {rank} started at {time.asctime()}')
83106
sys.stdout.flush()
84107

85-
# With --makeqa the desired output directories are in the 'zbestfiles'.
108+
# With --makeqa the desired output directories are in the 'redrockfiles'.
86109
if args.makeqa:
87110
from fastspecfit.qa import fastqa as fast
88111
cmd = 'fastqa'
89-
cmdargs = f'{outfiles[ii]} -o={zbestfiles[ii]} --mp={args.mp}'
112+
cmdargs = f'{outfiles[ii]} -o={redrockfiles[ii]} --mp={args.mp}'
90113
else:
91114
if fastphot:
92115
from fastspecfit.fastspecfit import fastphot as fast
93116
cmd = 'fastphot'
94117
else:
95118
from fastspecfit.fastspecfit import fastspec as fast
96119
cmd = 'fastspec'
97-
cmdargs = f'{zbestfiles[ii]} -o={outfiles[ii]} --mp={args.mp}'
120+
cmdargs = f'{redrockfiles[ii]} -o={outfiles[ii]} --mp={args.mp}'
98121

99122
if args.ignore_quasarnet:
100123
cmdargs += ' --ignore-quasarnet'
@@ -128,7 +151,7 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None,
128151

129152
if sample is not None:
130153
# assume healpix coadds; find the targetids to process
131-
_, survey, program, healpix = os.path.basename(zbestfiles[ii]).split('-')
154+
_, survey, program, healpix = os.path.basename(redrockfiles[ii]).split('-')
132155
healpix = int(healpix.split('.')[0])
133156
I = (sample['SURVEY'] == survey) * (sample['PROGRAM'] == program) * (sample['HEALPIX'] == healpix)
134157
targetids = ','.join(sample[I]['TARGETID'].astype(str))
@@ -141,7 +164,7 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None,
141164
cmdargs += f' --targetids={args.targetids}'
142165

143166
if args.makeqa:
144-
logfile = os.path.join(zbestfiles[ii], os.path.basename(outfiles[ii]).replace('.gz', '').replace('.fits', '.log'))
167+
logfile = os.path.join(redrockfiles[ii], os.path.basename(outfiles[ii]).replace('.gz', '').replace('.fits', '.log'))
145168
else:
146169
logfile = outfiles[ii].replace('.gz', '').replace('.fits', '.log')
147170

@@ -166,17 +189,11 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None,
166189
if not os.path.isdir(outdir):
167190
os.makedirs(outdir, exist_ok=True)
168191

169-
## pure-MPI
170-
#if comm is not None:
171-
# subcomm = comm.Split(color=rank)
172-
#else:
173-
# subcomm = None
174-
175192
if args.nolog:
176-
fast(args=cmdargs.split())#, comm=subcomm)
193+
fast(args=cmdargs.split(), comm=subcomm)
177194
else:
178195
with stdouterr_redirected(to=logfile, overwrite=args.overwrite):
179-
fast(args=cmdargs.split())#, comm=subcomm)
196+
fast(args=cmdargs.split(), comm=subcomm)
180197

181198
dt1 = time.time() - t1
182199
log.info(f' rank {rank} done in {dt1:.2f} sec')
@@ -191,7 +208,7 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None,
191208
log.debug(f' rank {rank} is done')
192209
sys.stdout.flush()
193210

194-
if comm is not None:
211+
if comm:
195212
comm.barrier()
196213

197214
if rank == 0 and not args.dry_run:
@@ -273,20 +290,21 @@ def main():
273290
except ImportError:
274291
comm = None
275292

276-
if comm is None:
293+
if comm:
294+
rank = comm.rank
295+
if comm.size > 1 and args.mp > 1 and comm.size < args.mp:
296+
log.warning(f'Number of MPI tasks {comm.size} should be >{args.mp} for MPI parallelism.')
297+
size = get_size(comm, mp=args.mp)
298+
else:
277299
rank, size = 0, 1
278300

279301
# https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
280302
if args.mp > 1 and 'NERSC_HOST' in os.environ:
281303
import multiprocessing
282304
multiprocessing.set_start_method('spawn')
283-
else:
284-
rank, size = comm.rank, comm.size
285305

286-
# Main rank is responsible for planning and merging.
306+
# Rank 0 is responsible for planning and merging.
287307
if rank == 0:
288-
#from fastspecfit.logger import log
289-
290308
# check the input samplefile
291309
if args.samplefile is not None:
292310
import fitsio
@@ -296,8 +314,14 @@ def main():
296314
return
297315
try:
298316
sample = Table(fitsio.read(args.samplefile, columns=['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']))
317+
log.info(f'Read {len(sample)} rows from {args.samplefile}')
299318
except:
300-
errmsg = f'Sample file {args.samplefile} missing required columns {SURVEY,PROGRAM,HEALPIX,TARGETID}'
319+
if args.input_redshifts:
320+
errmsg = f'Sample file {args.samplefile} with --input-redshifts missing required columns ' + \
321+
'{SURVEY,PROGRAM,HEALPIX,TARGETID,Z}'
322+
else:
323+
errmsg = f'Sample file {args.samplefile} missing required columns ' + \
324+
'{SURVEY,PROGRAM,HEALPIX,TARGETID}'
301325
log.critical(errmsg)
302326
raise ValueError(errmsg)
303327

@@ -376,10 +400,11 @@ def main():
376400
outdir_data=outdir_data, overwrite=args.overwrite)
377401
else:
378402
run_fastspecfit(args, comm=comm, fastphot=args.fastphot, specprod_dir=specprod_dir,
379-
makeqa=args.makeqa, outdir_data=outdir_data,
380-
samplefile=args.samplefile, input_redshifts=args.input_redshifts,
381-
templates=args.templates, templateversion=args.templateversion,
382-
fphotodir=args.fphotodir, fphotofile=args.fphotofile)
403+
makeqa=args.makeqa, outdir_data=outdir_data, sample=sample,
404+
input_redshifts=args.input_redshifts, templates=args.templates,
405+
templateversion=args.templateversion, fphotodir=args.fphotodir,
406+
fphotofile=args.fphotofile)
407+
383408

384409
if __name__ == '__main__':
385410
main()

0 commit comments

Comments
 (0)