-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathring_analysis_MPI.py
419 lines (304 loc) · 17.5 KB
/
ring_analysis_MPI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
#!/usr/bin/env python
import numpy as np
import pickle
from os import path
from glob import glob
from qcnico.lattice import cartesian_product
from qcnico.graph_tools import adjacency_matrix_sparse, count_rings, cycle_centers, hexagon_adjmat
from scipy import sparse
from time import perf_counter
def subsample_MAC_half_step(pos,l,m,n,m_max,n_max,return_global_indices=False):
    """Selects the atoms of `pos` that fall inside the half-step window labelled by `(m,n)`.

    The window keeps rows of `pos` satisfying m*l ≤ x < (m+2)*l and n*l ≤ y < (n+2)*l, i.e. it is two steps
    wide while consecutive windows are shifted by a single step `l`, so neighbouring windows overlap.
    The lower bound is dropped when the index is 0, and the upper bound is dropped when the index equals
    `m_max` (resp. `n_max`), so that the outermost windows extend to the edge of the structure.

    Returns the selected rows of `pos`. If `return_global_indices` is True, a tuple is returned instead:
    (selected rows, integer indices of those rows in `pos`)."""
    natoms = pos.shape[0]
    no_bound = np.ones(natoms, dtype=bool)  # stand-in mask for an open-ended side of the window

    # Column 0 is x, column 1 is y; a column is only read when a bound on it is actually needed.
    above_x_min = no_bound if m == 0 else pos[:, 0] >= m * l
    below_x_max = no_bound if m == m_max else pos[:, 0] < (m + 2) * l
    above_y_min = no_bound if n == 0 else pos[:, 1] >= n * l
    below_y_max = no_bound if n == n_max else pos[:, 1] < (n + 2) * l

    inside = above_x_min & below_x_max & above_y_min & below_y_max
    selected = pos[inside, :]

    if not return_global_indices:
        return selected
    return selected, inside.nonzero()[0]
def get_rings_from_subsamp(full_pos,nprocs,rank, nn, save_explicit_rings=True, max_ring_size=8,outdir='.',save_hexs_only=False):
    """Driver for one MPI rank of the parallel ring analysis of a structure.

    The structure is split into `nprocs` overlapping square subsamples (so `nprocs` must be a perfect square;
    otherwise an error is printed and None is returned). The rank `rank` picks its subsample, finds the rings in
    it with `count_rings`, and computes ring lengths, ring/hexagon centres and the local atom and hexagon
    adjacency matrices.

    Only the first two columns of `full_pos` are used.

    Files written to `outdir` (all suffixed with `-{nn}_{m}_{n}`, where `(m,n)` labels the subsample):
      * `pos_sample`, `M_hex`, `M_atoms`, `ring_centers`, `ring_lengths`, `hex_centers` (NPY);
      * if `save_explicit_rings` is True: `cycles` (pickled list of rings, each a list of global atom indices)
        and `hexs` (NPY array of the hexagons, in global atom indices).

    `nn` is only a label used in the file names, to tell apart outputs from different structures.

    If `save_hexs_only` is True, every file except `hexs` is skipped (so nothing at all is written when
    `save_explicit_rings` is False). This is meant for regenerating the hexagon lists without re-writing
    everything else."""
    if save_hexs_only:
        print('!!!!! WARNING !!!!!\n`save_hexs_only` kwarg has been set to True; `get_rings_from_subsamp` will not write many of the files necessary to rebuild the rings from this MPI-segmented run.')
    write_everything = not save_hexs_only

    full_pos = full_pos[:,:2]

    # Number of subsamples along one direction; only meaningful if nprocs is a perfect square.
    side = np.sqrt(nprocs)
    if side % 1 != 0:
        print('ERROR: Number of MPI processes must be a perfect square! It is currently set to: ', nprocs)
        print('Returning None.')
        return None
    side = int(side)

    # Step length of the half-step partition, derived from the larger of the two extents of the structure.
    extent_x = np.max(full_pos[:,0]) - np.min(full_pos[:,0])
    extent_y = np.max(full_pos[:,1]) - np.min(full_pos[:,1])
    box_length = np.max(np.ceil([extent_x, extent_y]))
    l = box_length // (side + 1)

    grid = cartesian_product(np.arange(side), np.arange(side))
    m, n = grid[rank]
    print(f'[{rank+1}] Sample indices: ({m,n})')

    def out_file(stem, ext='npy'):
        # Every output of this rank shares the same '-{nn}_{m}_{n}' suffix.
        return path.join(outdir, f'{stem}-{nn}_{m}_{n}.{ext}')

    # Global atom indices are only needed when the rings themselves are saved.
    if not save_explicit_rings:
        pos = subsample_MAC_half_step(full_pos, l, m, n, side-1, side-1, return_global_indices=False)
    else:
        pos, iatoms = subsample_MAC_half_step(full_pos, l, m, n, side-1, side-1, return_global_indices=True)

    if write_everything:
        np.save(out_file('pos_sample'), pos)

    rCC = 1.8  # distance cutoff handed to `count_rings` (presumably the max C-C bond length; confirm units)
    _, rings, M = count_rings(pos, rCC, max_size=max_ring_size, return_cycles=True, return_M=True)

    hexs = [ring for ring in rings if len(ring) == 6]
    ring_lengths = np.array([len(ring) for ring in rings])
    ring_centers = cycle_centers(rings, pos)
    hex_centers = cycle_centers(hexs, pos)
    Mhex = hexagon_adjmat(hexs)

    if write_everything:
        outputs = (
            ('M_hex', Mhex),
            ('M_atoms', M),
            ('ring_centers', ring_centers),
            ('ring_lengths', ring_lengths),
            ('hex_centers', hex_centers),
        )
        for stem, data in outputs:
            np.save(out_file(stem), data)

    if not save_explicit_rings:
        return

    # Translate subsample-local atom indices into indices of the full structure.
    rings_global = [[iatoms[i] for i in ring] for ring in rings]
    hexs_global = np.array([[iatoms[i] for i in hexagon] for hexagon in hexs])

    if write_everything:
        with open(out_file('cycles', ext='pkl'), 'wb') as fo:
            pickle.dump(rings_global, fo)

    # Hexagons get their own file: it is what is needed later to find which atoms sit in crystalline clusters.
    np.save(out_file('hexs'), hexs_global)
def get_a(datadir, nn, prefix='M_hex'):
    """Estimates the number of partitions along one direction from the output of an MPI-parallelised ring
    analysis job (i.e. variable `a` from `get_rings_from_subsamp`), using the number of 'M_hex' files by default.

    Assumes that all of the output files are formatted as '{prefix}-{nn}_{m}_{n}.npy' and that the partition is
    square: the result is the largest `n` index found, plus one.
    Another set of files can be used (other than 'M_hex') if `prefix` is changed.

    Raises ValueError if no matching file is found in `datadir` (or if a trailing index is not an integer)."""
    pattern = f'{prefix}-{nn}_*.npy'
    ref_files = glob(path.join(datadir, pattern))
    if not ref_files:
        # Without this check, `max` would fail with a message that says nothing about the missing files.
        raise ValueError(f"No files matching '{pattern}' found in '{datadir}'; cannot determine the number of partitions.")

    # The `n` index is whatever sits between the last underscore of the file name and the first dot after it.
    # Only the file name is parsed, so the directory part of the path cannot interfere.
    nvals = [int(path.basename(f).rsplit('_', 1)[-1].split('.')[0]) for f in ref_files]
    return max(nvals) + 1
def rebuild_rings(nn,datadir=None):
    """Reconstructs rings from subsampled MPI runs:
      * Places all of the rings (identified by their center of mass) into a single array
      * Creates a single array of all of the ring lengths, ordered in the same way as the ring
        centers array.
      * Constructs the adjacency matrix of all hexagons in the structure and uses it to
        determine which hexagons are crystalline

    This whole procedure basically removes redundant rings from overlapping subsamples and stitches the
    hexagon network of the full structure back together using the local hexagon adjacency matrices.
    Rings seen in several subsamples are recognised by the exact equality of their centre coordinates.

    Inputs are read from `datadir` (default: 'sample-{nn}'): the 'hex_centers', 'ring_centers', 'ring_lengths'
    and 'M_hex' files written by `get_rings_from_subsamp` for every subsample `(m,n)`.

    Outputs written to `datadir`:
      * 'hex_global-{nn}.npy': dense boolean adjacency matrix of all hexagons
      * 'centres_hashmap-{nn}.pkl': dict mapping hexagon centre (tuple) --> global hexagon index
      * 'neighbs_dict-{nn}.pkl': dict mapping global hexagon index --> tuple of neighbour indices
      * 'cryst_cluster_sizes-{nn}.npy' and 'clusters-{nn}.pkl': sizes and contents of the clusters
        returned by `get_clusters`
      * 'all_ring_centers-{nn}.npy' and 'all_ring_lengths-{nn}.npy': centres and lengths of all distinct rings

    Returns None."""
    from qcnico.jitted_cluster_utils import get_clusters

    if datadir is None:
        datadir = f'sample-{nn}'

    a = get_a(datadir,nn)
    slice_inds = cartesian_product(np.arange(a),np.arange(a))

    start = perf_counter()

    # --- Seed the global containers with the first subsample; its local indices become global indices ---
    m,n = slice_inds[0,:]
    print(f'Initialising hash maps (m,n) = ({m,n})',flush=True)
    hex_pos_global = {tuple(r):k for k,r in enumerate(np.load(path.join(datadir,f'hex_centers-{nn}_{m}_{n}.npy')))} # global hashtable mapping hexagon centers to integer indices
    all_pos_global = {tuple(r):k for k,r in enumerate(np.load(path.join(datadir,f'ring_centers-{nn}_{m}_{n}.npy')))} # global hashtable mapping ring centers to integer indices
    all_lengths = np.load(path.join(datadir, f'ring_lengths-{nn}_{m}_{n}.npy'))
    M = np.load(path.join(datadir, f'M_hex-{nn}_{m}_{n}.npy'))
    neighb_list = {k:tuple(M[k,:].nonzero()[0]) for k in range(M.shape[0])}

    ncentres_tot = M.shape[0] # running count of distinct hexagons; also the next free global hexagon index
    ncentres_all_tot = all_lengths.shape[0] # idem, for rings of any size

    # `hex_pos_global` is a dict, so its size is obtained with `len` (it has no `.shape`).
    assert len(hex_pos_global) == ncentres_tot, f'Mismatch between number of centers ({len(hex_pos_global)}) and shape of hexagon adjacency matrix {M.shape}!'
    print('Done! Commencing loop over other subsamples...',flush=True)

    # --- Merge every other subsample into the global containers ---
    for mn in slice_inds[1:]:
        m,n = mn
        print(f'\n------ {(m,n)} ------',flush=True)
        hex_pos = np.load(path.join(datadir, f'hex_centers-{nn}_{m}_{n}.npy'))
        print(f'{hex_pos.shape[0]} distinct crystalline centers.', flush=True)
        local_map_hex = {k:-1 for k in range(hex_pos.shape[0])} # hashtable that maps centre indices local to the NPY being processed to their global index (i.e. in `hex_pos_global`)

        print('Loop 1: ', end='')
        # first, update the global hashtable to properly index centers in subsample (m,n)
        for k, r in enumerate(hex_pos):
            r = tuple(r)
            if r in hex_pos_global:
                local_map_hex[k] = hex_pos_global[r] # if this centre has been seen, add its global index to the local hashmap
            else:
                hex_pos_global[r] = ncentres_tot # if this centre hasn't yet been seen; assign a new index to it in global hashmap
                local_map_hex[k] = ncentres_tot # idem for local hashmap
                ncentres_tot += 1 # prepare index for next unseen centre
        print('Done!',flush=True)

        print('Loop 2: ', end='',flush=True)
        # next, update neighbour list using global hashmap
        M = np.load(path.join(datadir, f'M_hex-{nn}_{m}_{n}.npy'))
        for k in range(hex_pos.shape[0]):
            k_global = local_map_hex[k]
            ineighbs_local = tuple(M[:,k].nonzero()[0])
            # a hexagon already seen in another subsample keeps its old neighbours and gains the new ones
            # (duplicates are harmless: they are removed with `np.unique` when neighbours are counted)
            if k_global in neighb_list:
                neighb_list[k_global] = neighb_list[k_global] + tuple(local_map_hex[p] for p in ineighbs_local)
            else:
                neighb_list[k_global] = tuple(local_map_hex[p] for p in ineighbs_local)
        print('Done!',flush=True)

        print('Loop 3 (all rings): ', end='', flush=True)
        all_pos = np.load(path.join(datadir, f'ring_centers-{nn}_{m}_{n}.npy'))
        lengths = np.load(path.join(datadir, f'ring_lengths-{nn}_{m}_{n}.npy'))
        new_lengths_local = []
        for k, r in enumerate(all_pos):
            r = tuple(r)
            if r in all_pos_global:
                continue
            else:
                all_pos_global[r] = ncentres_all_tot
                new_lengths_local.append(lengths[k]) # ring lengths are sorted in same order as ring centres
                ncentres_all_tot += 1
        # Stacking an empty list would silently turn the integer lengths into floats, so skip it.
        if new_lengths_local:
            all_lengths = np.hstack((all_lengths,new_lengths_local))
        print(f'Done! Added {len(new_lengths_local)} rings to global hashmap.', flush=True)

    end = perf_counter()
    print(f'\n**** Building hashtables took {end-start} seconds. ****\n',flush=True)

    # --- Build the global hexagon adjacency matrix and classify hexagons by their number of neighbours ---
    print('Constructing global hexagon adjacency matrix...', flush=True)
    start = perf_counter()
    Mglobal = np.zeros((ncentres_tot,ncentres_tot),dtype=bool)
    isnucleus = np.zeros(ncentres_tot,dtype=bool) # hexagons with exactly 6 distinct neighbouring hexagons
    isweird = np.zeros(ncentres_tot,dtype=bool) # hexagons with more than 6 distinct neighbouring hexagons
    for k in range(ncentres_tot):
        ineighbs = neighb_list[k]
        Mglobal[k,ineighbs] = True
        Mglobal[ineighbs,k] = True
        nb_neighbs = np.unique(ineighbs).shape[0]
        if nb_neighbs == 6:
            isnucleus[k] = True
        elif nb_neighbs > 6:
            isweird[k] = True
    end = perf_counter()
    print(f'**** Done! [{end - start} seconds] ****\nSaving stuff.', flush=True)

    np.save(path.join(datadir, f'hex_global-{nn}.npy'),Mglobal)
    with open(path.join(datadir, f'centres_hashmap-{nn}.pkl'), 'wb') as fo:
        pickle.dump(hex_pos_global,fo)
    with open(path.join(datadir, f'neighbs_dict-{nn}.pkl'), 'wb') as fo:
        pickle.dump(neighb_list,fo)

    nuclei = isnucleus.nonzero()[0]
    print(f'*** Found {nuclei.shape[0]} crystalline nuclei ***', flush=True)
    if isweird.sum() > 0:
        weird = isweird.nonzero()[0]
        print(f'!!!! Foundi {weird.shape[0]} weird nuclei !!!! Printing their number of neighbours now: ', flush=True)
        for w in weird:
            print(f'{w} --> {Mglobal[w,:].sum()}', flush=True)

    # --- Find the clusters grown from the nuclei ---
    print('Searching for clusters...',flush=True)
    start = perf_counter()
    Mglobal = sparse.csr_array(Mglobal.astype(np.int8)) #use sparse CSR matrix: DRAMATICALLY speeds up matrix product
    nuclei_neighbs = np.unique(Mglobal[nuclei,:].nonzero()[1])
    Mglobal2 = Mglobal @ Mglobal # nonzero entry (i,j) <=> hexagons i and j are linked by a two-step path
    nuclei_next_neighbs = np.unique(Mglobal2[nuclei,:].nonzero()[1])
    # nuclei, their neighbours, and the hexagons two steps away from a nucleus
    strict_6c = set(np.concatenate((nuclei,nuclei_neighbs,nuclei_next_neighbs)))
    cluster_start = perf_counter()
    print(f'[{cluster_start - start} seconds later] Starting `get_clusters`...',flush=True)
    # `get_clusters` is handed a dense array, so the CSR matrix is densified directly.
    # NOTE(review): exact semantics of `get_clusters` are not visible here; its result is treated as a list of clusters.
    crystalline_clusters = get_clusters(nuclei, Mglobal.toarray(), strict_6c)
    end = perf_counter()
    print(f'**** Done! Total time = {end - start} seconds. Time spent in `get_cluster` = {end - cluster_start} seconds ****',flush=True)

    cluster_sizes = np.array([len(c) for c in crystalline_clusters])
    np.save(path.join(datadir, f'cryst_cluster_sizes-{nn}.npy'),cluster_sizes)

    # --- Save the centres of all distinct rings, in the same order as their lengths ---
    print('Building all_centres to match order in `all_lengths`...')
    start = perf_counter()
    all_centres = np.zeros((all_lengths.shape[0], 2))
    for r, k in all_pos_global.items():
        all_centres[k] = r
    end = perf_counter()
    np.save(path.join(datadir, f'all_ring_centers-{nn}.npy'), all_centres)
    np.save(path.join(datadir, f'all_ring_lengths-{nn}.npy'), all_lengths)
    print(f'Done! Total time = {end - start} seconds.',flush=True)

    with open(path.join(datadir, f'clusters-{nn}.pkl'), 'wb') as fo:
        pickle.dump(crystalline_clusters,fo)
def ring_stats_rebuild(datadir, xlim=np.array([-np.inf,np.inf]), ylim=np.array([-np.inf, np.inf]),nn=0):
    """Counts the rings of each size in a structure from the REBUILT output of the MPI parallel ring analysis
    (i.e. the files written by `rebuild_rings` in `datadir`), distinguishing crystalline hexagons (6c) from
    the other hexagons (6i).

    If `xlim` and/or `ylim` contain a finite bound, only the rings whose centre lies in the region
    xlim[0] ≤ x ≤ xlim[1], ylim[0] ≤ y ≤ ylim[1] are counted. In that case a sanity check comparing the number
    of hexagons obtained from the ring lengths and from the hexagon centres is printed.

    `nn` is the label of the structure, as used in the file names.

    Returns an array of 8 counts, ordered as: [3-rings, 4-rings, 5-rings, 6i, 6c, 7-rings, 8-rings, 9-rings]."""
    rl_filename = f'all_ring_lengths-{nn}.npy'
    cluster_filename = f'clusters-{nn}.pkl'

    all_ring_lengths = np.load(path.join(datadir, rl_filename))

    ring_lengths = np.arange(3,10)
    ring_stats = np.zeros(ring_lengths.shape[0] + 1) # add one slot to accomodate for 6i/6c distinction

    # Join directory and file name properly, so that `datadir` works with or without a trailing separator.
    with open(path.join(datadir, cluster_filename), 'rb') as fo:
        cryst_clusters = pickle.load(fo)

    # Indices of all hexagons that belong to some cluster; stays empty if no cluster was found.
    cryst_hexs = set().union(*cryst_clusters)

    if np.any(~np.isinf(xlim)) or np.any(~np.isinf(ylim)):
        all_centres = np.load(path.join(datadir, f'all_ring_centers-{nn}.npy'))

        # find inds of all n-rings that lie in the desired region
        x_mask = (all_centres[:,0] >= xlim[0]) & (all_centres[:,0] <= xlim[1])
        y_mask = (all_centres[:,1] >= ylim[0]) & (all_centres[:,1] <= ylim[1])
        mask = x_mask & y_mask

        all_ring_lengths = all_ring_lengths[mask] #keep only the rings in region of interest

        # find inds of hexagons that lie in the desired region
        with open(path.join(datadir, f'centres_hashmap-{nn}.pkl'), 'rb') as fo:
            hex_centres_dict = pickle.load(fo)

        # this loads the dictionary into an array which preserves the hex centre --> index mapping
        hex_centres = np.zeros((len(hex_centres_dict),2))
        for r, k in hex_centres_dict.items():
            hex_centres[k] = r

        x_mask = (hex_centres[:,0] >= xlim[0]) & (hex_centres[:,0] <= xlim[1])
        y_mask = (hex_centres[:,1] >= ylim[0]) & (hex_centres[:,1] <= ylim[1])
        mask = x_mask & y_mask

        hexs_filtered = set(mask.nonzero()[0])
        nhexs = (all_ring_lengths == 6).sum()
        print('Number of hexagons match: ', nhexs == len(hexs_filtered)) #sanity check

        # finally keep only crystalline hexagons that lie in the desired region
        cryst_hexs = cryst_hexs & hexs_filtered

    nb_6c = len(cryst_hexs)
    nb_6i = (all_ring_lengths == 6).sum() - nb_6c

    # rings smaller than hexagons: sizes 3, 4, 5
    for k in range(3):
        ring_stats[k] = (all_ring_lengths == 3+k).sum()

    # store 6c and 6i in 'wrong' order (wrt fig in Tian paper) bc plotting function swaps them anyways
    ring_stats[3] = nb_6i
    ring_stats[4] = nb_6c

    for k, n in enumerate(ring_lengths[4:]): #ring lengths starts at 3; only consider ring lengths > 6 here
        ring_stats[5+k] = (all_ring_lengths == n).sum()

    return ring_stats
def crystalline_atoms(full_pos, nn,datadir=None):
    """Generates a mask `m` filtering which atoms in a given structure belong to a crystalline cluster from the
    REBUILT output of MPI parallel ring analysis: m[n] = True iff nth atom is in a crystalline hexagon.

    Files are read from `datadir` (default: 'sample-{nn}'): 'clusters-{nn}.pkl' and 'centres_hashmap-{nn}.pkl'
    (written by `rebuild_rings`), plus the per-subsample 'hex_centers' and 'hexs' files (the latter are only
    written by `get_rings_from_subsamp` when `save_explicit_rings` is True).

    Only the number of rows of `full_pos` is used. If no crystalline cluster was found, the mask is all False."""
    if datadir is None:
        datadir = f'sample-{nn}'

    N = full_pos.shape[0]
    crystalline_mask = np.zeros(N,dtype=bool)

    # build set of the indices of all crystalline hexagons
    with open(path.join(datadir, f'clusters-{nn}.pkl'), 'rb') as fo:
        cryst_clusters = pickle.load(fo)
    cryst_hexs = set().union(*cryst_clusters)

    # No cluster at all: no atom is crystalline, and there is no need to read any other file.
    if not cryst_hexs:
        return crystalline_mask

    # obtain global hashmap hex center ---> hex index
    with open(path.join(datadir, f'centres_hashmap-{nn}.pkl'), 'rb') as fo:
        hex_centres_hashmap = pickle.load(fo)

    # invert the hashmap into an array: row k holds the centre of the hexagon with global index k
    hex_centres = np.zeros((len(hex_centres_hashmap),2))
    for r, k in hex_centres_hashmap.items():
        hex_centres[k] = r

    # crystalline hexagons are matched to the local outputs through their centre coordinates
    cryst_centres = hex_centres[list(cryst_hexs)]
    cryst_centres = set(tuple(r) for r in cryst_centres)

    # loop over all local outputs to determine which atoms are in the cryst clusters
    a = get_a(datadir,nn)
    slice_inds = cartesian_product(np.arange(a),np.arange(a))

    for mn in slice_inds:
        m,n = mn
        local_hex_centres = np.load(path.join(datadir, f'hex_centers-{nn}_{m}_{n}.npy'))
        local_hex_atoms = np.load(path.join(datadir, f'hexs-{nn}_{m}_{n}.npy')) # row k: global atom indices of local hexagon k
        for k,r in enumerate(local_hex_centres):
            if tuple(r) in cryst_centres:
                icryst = local_hex_atoms[k]
                crystalline_mask[icryst] = True

    return crystalline_mask