Description
I was running large simulations with LBS and passing the data to BrahMap for on-the-fly map-making. The goal was to demonstrate that using BrahMap for very large simulations is computationally feasible. To keep the map-making workload from being spread too thinly across nodes, I tried to place as many detectors (i.e., as many samples) as possible on each node. While running these simulations, I ran into two issues. The first is not exactly a problem with LBS, but I would like to discuss how we can handle it better. The second happens on the LBS side and leads to an unexpected out-of-memory (OOM) error.
During initial tests on Galileo100 (384 GB RAM/node), the pipeline hit OOM errors with just 3 detectors per node. Manual profiling showed that `sim.prepare_pointings()` allocated ~105 GB, followed by another ~138 GB from BrahMap preprocessing. Then, during map-making, some intermediate arrays are created that eventually fill the remaining memory, leading to the OOM error. To fix this, I optimized my script in many places, including deleting the quaternion arrays `sim.observations[i].pointing_provider.bore2ecliptic_quats.quats` right after the BrahMap preprocessing. This frees ~105 GB and allows the jobs to run successfully.

In this case BrahMap consumes a large amount of memory. This can be optimized to some extent, but that alone won't be enough to resolve the OOM errors for large runs. In its preprocessing step, BrahMap computes and keeps in memory the pointing indices, the sine and cosine of the polarization angles, and some other quantities that are used repeatedly during map-making. Precomputing these quantities reduces the computational cost of map-making by a large margin. I guess this is true for other map-makers as well, and it is not totally avoidable.

For the specific pipeline and configuration I was using, I noticed that storing the pointing indices and polarization angles for each sample would have been cheaper than keeping the quaternion array. For a very simple pipeline, all we ultimately need are the pointing indices and polarization angles, so keeping only those in memory would be better from both a computational and a memory perspective. However, I understand that this could be disruptive for complex pipelines.
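To make the memory argument concrete, here is a minimal NumPy sketch of the "very simple pipeline" case: each sample is reduced to a pixel index and a polarization angle, cos 2ψ and sin 2ψ are precomputed once, and Pᵀd for an (I, Q, U) map is evaluated with `np.bincount`. The function names, the toy map size, and the assumption of one float64 quaternion per sample used in the storage comparison are illustrative only; they are not taken from BrahMap or litebird_sim.

```python
import numpy as np


def precompute_pointing(pix, psi):
    """Precompute the per-sample quantities that every solver iteration reuses."""
    return {
        "pix": pix.astype(np.int32),  # enough for nside <= 8192 (npix < 2**31)
        "cos2psi": np.cos(2.0 * psi),
        "sin2psi": np.sin(2.0 * psi),
    }


def apply_pt(pre, tod, npix):
    """Compute P^T d for an (I, Q, U) pointing matrix.

    Sample k adds tod[k] * (1, cos 2psi_k, sin 2psi_k) to pixel pix[k].
    """
    pix = pre["pix"]
    return np.stack(
        [
            np.bincount(pix, weights=tod, minlength=npix),
            np.bincount(pix, weights=tod * pre["cos2psi"], minlength=npix),
            np.bincount(pix, weights=tod * pre["sin2psi"], minlength=npix),
        ]
    )


rng = np.random.default_rng(0)
nside = 8  # toy map size, not the nside=512 used in the script
npix = 12 * nside**2
n_samples = 100_000
pix = rng.integers(0, npix, size=n_samples)
psi = rng.uniform(0.0, np.pi, size=n_samples)
tod = rng.normal(size=n_samples)

pre = precompute_pointing(pix, psi)
pt_d = apply_pt(pre, tod, npix)  # shape (3, npix)

# Illustrative storage comparison, ASSUMING one float64 quaternion per sample.
quat_bytes = n_samples * 4 * np.dtype(np.float64).itemsize
pointing_bytes = sum(arr.nbytes for arr in pre.values())
print(f"quaternions: {quat_bytes / 1e6:.1f} MB, pix + cos/sin: {pointing_bytes / 1e6:.1f} MB")
```

Under that assumption a quaternion costs 32 bytes per sample, while an int32 index plus cos 2ψ and sin 2ψ costs 20 bytes, and the precomputed arrays are reused in every CG iteration instead of recomputing trigonometric functions. For nside = 512, npix = 12 · 512² = 3,145,728, so int32 pixel indices are sufficient.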
Next, I used the same script on Perlmutter (512 GB RAM/node). Here I tested the script multiple times on 2 nodes, with 4 detectors per node. However, when I scaled the code to 100+ nodes (still with 4 detectors per node), the job failed with an OOM error while calling `sim.prepare_pointings()`. This is unexpected: increasing the number of nodes should not proportionally increase the per-node memory occupied by the quaternion arrays or the intermediate arrays. The issue persisted even when I reduced the number of samples per rank by 20%. Due to resource constraints, I was not able to investigate this issue further. While "Save memory in PointingProvider.get_pointings #488" definitely looks good, I am not totally sure it would solve the issue, and at the moment I don't have enough computing resources to test it on a very large run.
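One way to check whether per-node memory really grows with the number of nodes is to record each rank's peak resident set size around `sim.prepare_pointings()` and aggregate it per host. The sketch below uses only the standard library plus an optional mpi4py communicator; the helper names `peak_rss_gib` and `track_peak` are hypothetical and not part of LBS or BrahMap.

```python
import resource
import socket
import sys
from contextlib import contextmanager


def peak_rss_gib():
    """Peak resident set size of this process in GiB.

    ru_maxrss is in kilobytes on Linux but in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    scale = 1 if sys.platform == "darwin" else 1024
    return peak * scale / 2**30


@contextmanager
def track_peak(label, comm=None):
    """Report how far a code block raised the peak RSS.

    ru_maxrss is a high-water mark, so the reported increase is how much the
    block pushed the peak up, not the total it allocated. With an mpi4py
    communicator, rank 0 also sums the per-rank peaks by hostname; that sum is
    an upper bound on the node's simultaneous usage.
    """
    result = {}
    before = peak_rss_gib()
    yield result
    after = peak_rss_gib()
    result["peak_gib"] = after
    result["delta_gib"] = after - before
    if comm is None:
        print(f"{label}: peak {after:.2f} GiB (+{after - before:.2f} GiB)")
        return
    # Collective call: every rank must reach this point.
    records = comm.gather((socket.gethostname(), after), root=0)
    if comm.Get_rank() == 0:
        per_node = {}
        for host, gib in records:
            per_node[host] = per_node.get(host, 0.0) + gib
        result["per_node_gib"] = per_node
        worst = max(per_node, key=per_node.get)
        print(f"{label}: largest per-node sum of peaks {per_node[worst]:.2f} GiB on {worst}")
```

In the script this would wrap the suspect call, e.g. `with track_peak("prepare_pointings", comm=comm): sim.prepare_pointings()`, so that the 2-node and 100+-node runs can be compared directly.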
For reference, here is a sample script that I was using:
### For testing on nersc
import sys
import gc
import numpy as np
from mpi4py import MPI
from astropy import units as u
import litebird_sim as lbs
import brahmap
comm = lbs.MPI_COMM_WORLD
comm_size = comm.Get_size()
comm_rank = comm.Get_rank()
if len(sys.argv) != 13:
    if comm_rank == 0:
        print(
            f"Expected 12 command-line arguments, got {len(sys.argv) - 1}",
            file=sys.stderr,
        )
    comm.Abort(1)
##################################
##### Setting the parameters #####
##################################
DEBUG_OUTPUT = False
CLEAN_QUAT = True
### Simulation parameters, to be read from command line
sim_params = {}
sim_params["spin_rate_rpm"] = float(sys.argv[1]) # 0.05
sim_params["spin_sun_angle_deg"] = float(sys.argv[2]) # alpha: 45.0
sim_params["precession_period_min"] = float(sys.argv[3]) # 192.348
sim_params["spin_boresight_angle_deg"] = (
95.0 - sim_params["spin_sun_angle_deg"]
)
sim_params["sampling_rate_hz"] = float(sys.argv[4])
sim_params["net_ukrts"] = float(sys.argv[5]) # 55.65
sim_params["fknee_mhz"] = float(sys.argv[6]) # 20.0
sim_params["alpha"] = float(sys.argv[7])
sim_params["chunk_days"] = float(sys.argv[8])
sim_params["total_days"] = float(sys.argv[9])
sim_params["job_id"] = int(sys.argv[10])
sim_params["task_arr_id"] = int(sys.argv[11])
sim_params["host"] = "nersc" # "nersc", "vivobook" #!
sim_params["comm_size"] = comm_size
sim_tag = sys.argv[12]
if comm_rank == 0:
    print(f"comm_size = {comm_size}\n")
    print(f"\n\n\n======== Execution started for tag {sim_tag} ========\n")
if DEBUG_OUTPUT and comm_rank == 0:
    print("Setting up the simulation...")
### Other LBS parameters
lbs_params = {}
lbs_params["realization_set"] = range(1)
lbs_params["n_blocks_time"] = 48 #!
lbs_params["n_blocks_det"] = 4 #!
lbs_params["num_of_obs_per_detector"] = 23 #!
lbs_params["nside"] = 512 #!
lbs_params["delta_quat_sample_s"] = 60.0
SCAN_THE_SKY = True
lbs_params["scan_cmb_sky"] = True
lbs_params["seed_cmb"] = 123
lbs_params["noise_type"] = "one_over_f"
lbs_params["noise_cov_type"] = (
"diagonal"
)
lbs_params["toep_precond_maxiter"] = 100 #!
lbs_params["toep_precond_atol"] = 1.0e-6 #!
lbs_params["compute_chi2"] = False
lbs_params["dtype_float"] = np.float64
lbs_params["imo"] = lbs.Imo(
flatfile_location="/global/homes/a/aanand/brahmap_25_09_01/production06/v1.3/"
)
lbs_params["imo_version"] = "v1.3"
lbs_params["telescope"] = "MFT"
lbs_params["channel"] = "M2-119"
lbs_params["det_names_file"] = "detector_list_MFT_M2-119_8dets.txt"
### Input path for detector list and other quantities
if sim_params["host"] == "vivobook":
    lbs_params["input_path"] = (
        "/mnt/Data/Projects/uniroma2/coding/playground/brahmap_0825_01/production02/"
    )
    lbs_params["container_path"] = (
        "/mnt/Data/Projects/uniroma2/coding/playground/brahmap_0825_01/production02/test_container/"
    )
elif sim_params["host"] == "nersc":
    lbs_params["input_path"] = "/global/homes/a/aanand/brahmap_25_09_01/production06/"
    ### Path of the dir on a typical scratch partition
    lbs_params["container_path"] = (
        "/pscratch/sd/a/aanand/brahmap_25_09_01/production06/"
    )
else:
    if comm_rank == 0:
        print(
            f"The script is not ready to be launched at host {sim_params['host']}",
            file=sys.stderr,
        )
    comm.Abort(1)
### GLS parameters
gls_parameters = brahmap.LBSimGLSParameters(
use_iterative_solver=True,
isolver_threshold=1.0e-6,
isolver_max_iterations=600, #!
callback_function=None,
return_processed_samples=False,
return_hit_map=False,
)
### Creating I/O directories
lbs_base_path = (
lbs_params["container_path"]
+ f"/aux_lbs/temp_{sim_params['job_id']}_{sim_params['task_arr_id']}_{sim_tag}"
)
###################################################
##### Creating observations with litebird_sim #####
###################################################
if DEBUG_OUTPUT and comm_rank == 0:
print("Creating observations with litebird_sim...")
### Initializing the simulation
sim = lbs.Simulation(
name="brahmap",
base_path=lbs_base_path,
start_time=0,
duration_s=(sim_params["total_days"] * u.d).to("s").value,
random_seed=00, # Should not matter as we are going to update seeds before generating noise-only tods
imo=lbs_params["imo"],
mpi_comm=comm,
# numba_threads=None, # set using env var
numba_threading_layer="omp",
)
### Setting the scanning strategy
sim.set_scanning_strategy(
scanning_strategy=lbs.SpinningScanningStrategy(
spin_sun_angle_rad=np.deg2rad(sim_params["spin_sun_angle_deg"]),
spin_rate_hz=1.0 / (1 / sim_params["spin_rate_rpm"] * u.min).to("s").value,
precession_rate_hz=1.0
/ (sim_params["precession_period_min"] * u.min).to("s").value,
),
delta_time_s=lbs_params["delta_quat_sample_s"],
)
### Instrument definition
sim.set_instrument(
instrument=lbs.InstrumentInfo(
name="LB",
spin_boresight_angle_rad=np.deg2rad(sim_params["spin_boresight_angle_deg"]),
### All other parameters can be ignored
)
)
### Detector names
det_names = np.loadtxt(
lbs_params["input_path"] + lbs_params["det_names_file"],
dtype=str,
)
### Loading the detector info
det_list = []
for detector in det_names:
    det_obj = lbs.DetectorInfo.from_imo(
        url=f"/releases/{lbs_params['imo_version']}/satellite/{lbs_params['telescope']}/{lbs_params['channel']}/{detector}/detector_info",
        imo=lbs_params["imo"],
    )
    det_obj.sampling_rate_hz = sim_params["sampling_rate_hz"]
    det_obj.net_ukrts = sim_params["net_ukrts"]
    det_obj.fknee_mhz = sim_params["fknee_mhz"]
    det_obj.alpha = sim_params["alpha"]
    det_list.append(det_obj)
if DEBUG_OUTPUT and comm_rank == 0:
    print(f"numba_threads = {sim.numba_threads}")
    print(f"numba_threading_layer = {sim.numba_threading_layer}")
    print("Creating the observation object...")
### Create observations
sim.create_observations(
detectors=det_list,
num_of_obs_per_detector=lbs_params["num_of_obs_per_detector"],
n_blocks_time=lbs_params["n_blocks_time"],
n_blocks_det=lbs_params["n_blocks_det"],
split_list_over_processes=False,
tod_dtype=lbs_params["dtype_float"],
tods=[
lbs.TodDescription(
name="tod",
dtype=lbs_params["dtype_float"],
description="Signal",
),
],
)
if DEBUG_OUTPUT and comm_rank == 0:
    print("Preparing the pointings...")
### Compute pointings
sim.prepare_pointings()
##############################
##### Preparing the TODs #####
##############################
### Scanning the cmb sky
if lbs_params["scan_cmb_sky"]:
    if DEBUG_OUTPUT and comm_rank == 0:
        print("Scanning the input sky...")
    ### Loading the channel info
    ch_info = lbs.FreqChannelInfo.from_imo(
        lbs_params["imo"],
        f"/releases/{lbs_params['imo_version']}/satellite/{lbs_params['telescope']}/{lbs_params['channel']}/channel_info",
    )
    ### Producing the input CMB maps
    mbs_params = lbs.MbsParameters(
        make_cmb=True,
        seed_cmb=lbs_params["seed_cmb"],
        cmb_r=0.0,
        make_fg=False,
        gaussian_smooth=True,
        bandpass_int=False,
        nside=lbs_params["nside"],
        units="uK_CMB",
        maps_in_ecliptic=False,
        output_string="mbs_cmb",
        # save=True,
    )
    mbs_obj = lbs.Mbs(
        simulation=sim,
        parameters=mbs_params,
        channel_list=[ch_info],
    )
    if comm.rank == 0:
        input_maps = mbs_obj.run_all()
    else:
        input_maps = None
    # Distributing the maps to all MPI processes
    input_maps = comm.bcast(input_maps, 0)
    comm.barrier()
#######################################
##### Preparing the GLS operators #####
#######################################
if DEBUG_OUTPUT and comm_rank == 0:
    print("Preparing the GLS operators...")
### Creating noise operators
if lbs_params["noise_cov_type"] is None:
    inv_cov = None
elif lbs_params["noise_cov_type"] == "diagonal":
    inv_cov = brahmap.LBSim_InvNoiseCovLO_UnCorr(
        sim.observations,
        dtype=lbs_params["dtype_float"],
    )
else:
    print("Invalid noise covariance type!!!", file=sys.stderr)
    comm.Abort(1)
### Processing the lbs samples
processed_samples = brahmap.LBSimProcessTimeSamples(
nside=lbs_params["nside"],
observations=sim.observations,
noise_weights=inv_cov.diag if inv_cov is not None else None,
dtype_float=lbs_params["dtype_float"],
)
### Defining the linear operators to be used in map-making
pointing_operator = brahmap.PointingLO(
processed_samples=processed_samples,
solver_type=gls_parameters.solver_type,
)
blockdiagprecond_operator = brahmap.BlockDiagonalPreconditionerLO(
processed_samples=processed_samples,
solver_type=gls_parameters.solver_type,
)
A = pointing_operator.T * inv_cov * pointing_operator
######################################################################
##### Producing noise tod and creating maps for each realization #####
######################################################################
if DEBUG_OUTPUT and comm_rank == 0:
    print("Adding noise and doing map-making for each realization...")
for realization in lbs_params["realization_set"]:
    # Note that realization index is supplied as the seed to the noise generator. This seed is used to create rng hierarchy for the detectors
    try:
        sim.nullify_tod(components="tod")
    except Exception:
        for obs in sim.observations:
            obs.tod = np.zeros(
                [obs.n_detectors, obs.n_samples],
                dtype=lbs_params["dtype_float"],
            )
    if lbs_params["scan_cmb_sky"] and SCAN_THE_SKY:
        ### Scanning the sky
        lbs.scan_map_in_observations(
            sim.observations,
            maps=input_maps[0][lbs_params["channel"]],
            pointings_dtype=lbs_params["dtype_float"],
            component="tod",
        )
    ### Adding noise to the TOD
    if lbs_params["noise_type"] is None:
        pass
    elif lbs_params["noise_type"] in ("white", "one_over_f"):
        lbs.add_noise_to_observations(
            observations=sim.observations,
            noise_type=lbs_params["noise_type"],
            dets_random=sim.dets_random,  # Will be ignored due to user_seed
            user_seed=realization,
            component="tod",
            scale=1.0e6,  # For K to uK conversion
        )
    else:
        raise NotImplementedError(
            f"{lbs_params['noise_type']} noise type is not implemented yet"
        )
    ### Concatenating the tod
    time_ordered_data = np.concatenate(
        [obs.tod for obs in sim.observations],
        axis=None,
    )
    if CLEAN_QUAT:
        for obs in sim.observations:
            del obs.tod
            del obs.pointing_provider.bore2ecliptic_quats.quats
        gc.collect()
    b = inv_cov * time_ordered_data
    b = pointing_operator.T * b
    ### Finally the map-making
    num_iterations = 0
    residual_arr = []
    if gls_parameters.use_iterative_solver:

        def callback_function(x, r, norm_residual):
            global num_iterations
            num_iterations += 1
            global residual_arr
            residual_arr.append(norm_residual)

        map_vector, pcg_status = brahmap.math.cg(
            A=A,
            b=b,
            atol=gls_parameters.isolver_threshold,
            maxiter=gls_parameters.isolver_max_iterations,
            M=blockdiagprecond_operator,
            callback=callback_function,
            parallel=True,
        )
    else:
        pcg_status = 0
        # The block-diagonal preconditioner acts on map-space vectors, so it is
        # applied to b = P^T N^-1 d rather than to the TOD itself
        map_vector = blockdiagprecond_operator * b
    output_maps = brahmap.separate_map_vectors(
        map_vector=map_vector,
        processed_samples=processed_samples,
    )
    del b
    del time_ordered_data
    time_stamp10 = MPI.Wtime()
    if pcg_status != 0:
        convergence_status = False
    else:
        convergence_status = True
sim.flush()
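The `CLEAN_QUAT` branch of the script frees memory by deleting `obs.pointing_provider.bore2ecliptic_quats.quats`. A small helper that reports how many bytes those arrays hold before dropping them makes it easier to confirm the ~105 GB figure rank by rank. The sketch below relies only on the attribute path used in the script; the helper names are hypothetical, and the `SimpleNamespace` objects are stand-ins, not litebird_sim classes.

```python
import gc
from types import SimpleNamespace

import numpy as np


def quaternion_nbytes(observations):
    """Total size of the bore2ecliptic quaternion arrays still attached."""
    total = 0
    for obs in observations:
        quats = getattr(obs.pointing_provider.bore2ecliptic_quats, "quats", None)
        if quats is not None:
            total += quats.nbytes
    return total


def drop_quaternions(observations):
    """Detach the quaternion arrays and return their combined size in bytes.

    The memory can only be released once no other object (for example a view
    or a cached pointing array) still references the arrays.
    """
    size = quaternion_nbytes(observations)
    for obs in observations:
        rot = obs.pointing_provider.bore2ecliptic_quats
        if hasattr(rot, "quats"):
            del rot.quats
    gc.collect()
    return size


# Toy stand-ins mimicking the attribute path used in the script (hypothetical).
observations = [
    SimpleNamespace(
        pointing_provider=SimpleNamespace(
            bore2ecliptic_quats=SimpleNamespace(quats=np.zeros((1000, 4)))
        )
    )
    for _ in range(3)
]
freed = drop_quaternions(observations)
print(f"dropped {freed / 2**20:.3f} MiB of quaternions")
```

As in the script, dropping the quaternions only makes sense after everything that still needs the pointings (sky scanning and BrahMap preprocessing) has run.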