Skip to content

Commit

Permalink
Implement adapt parallelisation option in global_params
Browse files Browse the repository at this point in the history
  • Loading branch information
avigan committed Oct 4, 2024
1 parent ed97b33 commit c29c291
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
6 changes: 3 additions & 3 deletions ForMoSA/adapt/adapt_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def tpool_adapt(idx, global_params, wav_mod_nativ, res_mod_obs_merge, obs_name,
grid_photo[(..., ) + idx] = mod_photo


def adapt_grid(global_params, wav_obs_spectro, wav_obs_photo, res_mod_obs_merge, obs_name='', indobs=0, parallel=True):
def adapt_grid(global_params, wav_obs_spectro, wav_obs_photo, res_mod_obs_merge, obs_name='', indobs=0):
"""
Adapt the synthetic spectra of a grid to make them comparable with the data.
Expand Down Expand Up @@ -175,8 +175,8 @@ def adapt_grid(global_params, wav_obs_spectro, wav_obs_photo, res_mod_obs_merge,
def update(*a):
pbar.update()

if parallel:
ncpu = mp.cpu_count() // 2
if global_params.parallel:
ncpu = mp.cpu_count()
with ThreadPool(processes=ncpu, initializer=tpool_adapt_init, initargs=(grid_input_shape, grid_input_data, grid_spectro_shape, grid_spectro_data, grid_photo_shape, grid_photo_data)) as pool:
for idx in np.ndindex(shape):
pool.apply_async(tpool_adapt, args=(idx, global_params, wav_mod_nativ, res_mod_obs_merge, obs_name, indobs, attr['key'], attr['title'], values), callback=update)
Expand Down
14 changes: 10 additions & 4 deletions ForMoSA/main_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def diag_mat(rem=[], result=np.empty((0, 0))):
class GlobFile:
'''
Class that import all the parameters from the config file and make them GLOBAL FORMOSA VARIABLES.
Author: Paulina Palma-Bifani
'''

Expand All @@ -76,7 +76,7 @@ def __init__(self, config_file_path):
model_name = model_name[0]
self.model_name = model_name

if type(config['config_adapt']['wav_for_adapt']) != list: # Create lists if only one obs in the loop
if type(config['config_adapt']['wav_for_adapt']) != list: # Create lists if only one obs in the loop
# [config_adapt] (5)
self.wav_for_adapt = [config['config_adapt']['wav_for_adapt']]
self.adapt_method = [config['config_adapt']['adapt_method']]
Expand All @@ -101,6 +101,12 @@ def __init__(self, config_file_path):
self.logL_type = config['config_inversion']['logL_type']
self.wav_fit = config['config_inversion']['wav_fit']

# parallelisation of adapt
try:
self.parallel = config['config_adapt']['parallel']
except KeyError:
self.parallel = False

self.ns_algo = config['config_inversion']['ns_algo']
self.npoint = config['config_inversion']['npoint']

Expand All @@ -121,7 +127,7 @@ def __init__(self, config_file_path):
self.bb_R = config['config_parameter']['bb_R']

self.ck = None

# [config_nestle] (5, some mutually exclusive) (n_ prefix for params)
self.n_method = config['config_nestle']['method']
self.n_maxiter = eval(config['config_nestle']['maxiter'])
Expand Down Expand Up @@ -152,7 +158,7 @@ def __init__(self, config_file_path):
# self.p_init_MPI = config['config_pymultinest']['init_MPI']
# self.p_dump_callback = config['config_pymultinest']['dump_callback']
# self.p_use_MPI = config['config_pymultinest']['use_MPI']

# [config_dinesty] & [config_ultranest] CHECK THIS

# ## create OUTPUTS Sub-Directories: interpolated grids and results
Expand Down

0 comments on commit c29c291

Please sign in to comment.