Skip to content

Commit 23fd316

Browse files
authored
Merge pull request simopt-admin#44 from simopt-admin/dev_wagrocho_multiprocessing
Merge multiprocessing into main
2 parents 6c889e7 + 6c1f656 commit 23fd316

File tree

1 file changed

+97
-81
lines changed

1 file changed

+97
-81
lines changed

simopt/experiment_base.py

Lines changed: 97 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import csv
2020
import itertools
2121
from mrg32k3a.mrg32k3a import MRG32k3a
22-
import multiprocessing
23-
from multiprocessing import Process
22+
from multiprocessing import Pool
2423

2524

2625
from .base import Solution
@@ -371,6 +370,7 @@ class ProblemSolver(object):
371370
file_name_path : str, optional
372371
Path of .pickle file for saving ``experiment_base.ProblemSolver`` objects.
373372
"""
373+
374374
def __init__(self, solver_name=None, problem_name=None, solver_rename=None, problem_rename=None, solver=None, problem=None, solver_fixed_factors=None, problem_fixed_factors=None, model_fixed_factors=None, file_name_path=None):
375375
"""There are two ways to create a ProblemSolver object:
376376
1. Provide the names of the solver and problem to look up in ``directory.py``.
@@ -446,12 +446,11 @@ def run(self, n_macroreps):
446446
"""
447447
print("Running Solver", self.solver.name, "on Problem", self.problem.name + ".")
448448

449+
# Initialize variables
449450
self.n_macroreps = n_macroreps
450-
self.timings = []
451-
# Create variables for recommended solutions and intermediate budgets
452-
# so we can append to them in parallel.
453-
self.all_recommended_xs = multiprocessing.Manager().dict()
454-
self.all_intermediate_budgets = multiprocessing.Manager().dict()
451+
self.all_recommended_xs = [None] * n_macroreps
452+
self.all_intermediate_budgets = [None] * n_macroreps
453+
self.timings = [None] * n_macroreps
455454

456455
# Create, initialize, and attach random number generators
457456
# Stream 0: reserved for taking post-replications
@@ -465,42 +464,28 @@ def run(self, n_macroreps):
465464
# Streams 3, 4, ..., n_macroreps + 2: reserved for
466465
# macroreplications
467466
rng0 = MRG32k3a(s_ss_sss_index=[2, 0, 0]) # Currently unused.
468-
rng1 = MRG32k3a(s_ss_sss_index=[2, 1, 0])
469-
rng2 = MRG32k3a(s_ss_sss_index=[2, 2, 0])
470-
rng3 = MRG32k3a(s_ss_sss_index=[2, 3, 0])
471-
self.solver.attach_rngs([rng1, rng2, rng3])
467+
rng_list = [MRG32k3a(s_ss_sss_index=[2, i + 1, 0]) for i in range(3)]
468+
self.solver.attach_rngs(rng_list)
472469

473470
# Start a timer
474-
tic = time.perf_counter()
475-
476-
# If we're only doing one macroreplication or we have fewer than 4 cores, run macroreps in serial.
477-
# It just isn't worth the overhead to run in parallel.
478-
if n_macroreps == 1 or os.cpu_count() < 4:
471+
self.function_start = time.time()
472+
473+
print("Starting macroreplications in parallel")
474+
with Pool() as process_pool:
475+
# Start the macroreplications in parallel (async)
476+
result = process_pool.map_async(self.run_multithread, range(n_macroreps))
477+
# Wait for the results to be returned (or 1 second)
478+
while (not result.ready()):
479+
# Update status bar here
480+
result.wait(1)
481+
482+
# Grab all the data out of the result
479483
for mrep in range(n_macroreps):
480-
self.run_multithread(mrep)
481-
else:
482-
# Run n_macroreps of the solver on the problem.
483-
# Report recommended solutions and corresponding intermediate budgets.
484-
# Create an array of Process objects, one for each macroreplication.
485-
Processes = [Process(target=self.run_multithread, args=(mrep,)) for mrep in range(self.n_macroreps)]
486-
# Start each process.
487-
for mrep in range(self.n_macroreps):
488-
Processes[mrep].start()
489-
# Wait for each process to finish.
490-
for mrep in range(self.n_macroreps):
491-
Processes[mrep].join()
492-
# Stop the threads.
493-
for mrep in range(self.n_macroreps):
494-
Processes[mrep].terminate()
484+
self.all_recommended_xs[mrep], self.all_intermediate_budgets[mrep], self.timings[mrep] = result.get()[mrep]
485+
print("Finished running {} macroreplications in {} seconds.".format(n_macroreps, round(time.time() - self.function_start, 3)))
495486

496-
# Stop the timer.
497-
toc = time.perf_counter()
498-
# Print the total runtime.
499-
print(f"Total runtime: {toc - tic:0.4f} seconds ({(toc - tic) / self.n_macroreps:0.4f} seconds per macroreplication)\r\n")
500-
501-
# Convert the budgets and solutions into lists.
502-
self.all_recommended_xs = [self.all_recommended_xs[i] for i in range(len(self.all_recommended_xs.keys()))]
503-
self.all_intermediate_budgets = [self.all_intermediate_budgets[i] for i in range(len(self.all_intermediate_budgets.keys()))]
487+
# Delete stuff we don't need to save
488+
del self.function_start
504489

505490
# Save ProblemSolver object to .pickle file.
506491
self.record_experiment_results()
def run_multithread(self, mrep):
    """Run a single macroreplication of the solver on the problem.

    Designed to be dispatched by a multiprocessing pool, so all results
    are returned (not stored on ``self``, which is a per-worker copy).

    Parameters
    ----------
    mrep : int
        Index of the macroreplication (0-based).

    Returns
    -------
    tuple
        ``(recommended_xs, intermediate_budgets, runtime)`` where
        ``recommended_xs`` is the list of decision-variable vectors
        extracted from the recommended solutions.
    """
    # Time the solve itself.
    start = time.perf_counter()
    recommended_solns, intermediate_budgets = self.solver.solve(problem=self.problem)
    finish = time.perf_counter()
    runtime = finish - start
    print(f"Macroreplication {mrep + 1}: Finished Solver {self.solver.name} on Problem {self.problem.name} in {runtime:0.4f} seconds.")

    # Drop any solutions recommended after the final budget.
    recommended_solns, intermediate_budgets = trim_solver_results(
        problem=self.problem,
        recommended_solns=recommended_solns,
        intermediate_budgets=intermediate_budgets,
    )
    # Hand back only picklable data: x-vectors, budgets, and the runtime.
    xs = [solution.x for solution in recommended_solns]
    return (xs, intermediate_budgets, runtime)
535517

536518
def check_run(self):
537519
"""Check if the experiment has been run.
def post_replicate(self, n_postreps, crn_across_budget=True, crn_across_macroreps=False):
    """Run postreplications at solutions recommended by the solver.

    Postreplications for each macroreplication are run in parallel via a
    multiprocessing pool; each worker returns its postreplicates and
    runtime, which are collected here in macroreplication order.

    Parameters
    ----------
    n_postreps : int
        Number of postreplications to take at each recommended solution.
    crn_across_budget : bool, default=True
        True if CRN used for post-replications at solutions recommended at
        different times, otherwise False.
    crn_across_macroreps : bool, default=False
        True if CRN used for post-replications at solutions recommended on
        different macroreplications, otherwise False.
    """
    print("Setting up {} postreplications for {} macroreplications of {} on {}.".format(n_postreps, self.n_macroreps, self.solver.name, self.problem.name))

    self.n_postreps = n_postreps
    self.crn_across_budget = crn_across_budget
    self.crn_across_macroreps = crn_across_macroreps
    # Initialize result holders, one slot per macroreplication.
    # FIX: original wrote `[] * len(...)`, which always evaluates to []
    # (list repetition of an empty list); the slots are filled from the
    # pool results below, so pre-size with None like run() does.
    self.all_post_replicates = [None] * self.n_macroreps
    self.timings = [None] * self.n_macroreps

    self.function_start = time.time()

    print("Starting postreplications in parallel")
    with Pool() as process_pool:
        # Start the postreplications in parallel (async).
        result = process_pool.map_async(self.post_replicate_multithread, range(self.n_macroreps))
        # Wait for the results to be returned (poll every 1 second).
        while not result.ready():
            # Update status bar here
            result.wait(1)
        # Grab all the data out of the result.
        # FIX: call result.get() once instead of once per macroreplication
        # inside the loop — each get() is a blocking call on the pool.
        results = result.get()
        for mrep in range(self.n_macroreps):
            self.all_post_replicates[mrep], self.timings[mrep] = results[mrep]

    # Store estimated objective for each macrorep for each budget.
    self.all_est_objectives = [[np.mean(self.all_post_replicates[mrep][budget_index]) for budget_index in range(len(self.all_intermediate_budgets[mrep]))] for mrep in range(self.n_macroreps)]
    print("Finished running {} postreplications in {} seconds.".format(self.n_macroreps, round(time.time() - self.function_start, 3)))

    # Delete stuff we don't need to save when pickling.
    del self.function_start

    # Save ProblemSolver object to .pickle file.
    self.record_experiment_results()
def post_replicate_multithread(self, mrep):
    """Run postreplications for one macroreplication's recommended solutions.

    Designed to be dispatched by a multiprocessing pool; results are
    returned rather than stored on ``self`` (a per-worker copy).

    Parameters
    ----------
    mrep : int
        Index of the macroreplication (0-based).

    Returns
    -------
    tuple
        ``(post_replicates, runtime)`` — one list of postreplicate
        objective values per intermediate budget, plus elapsed seconds.
    """
    print(f"Macroreplication {mrep + 1}: Starting postreplications for {self.solver.name} on {self.problem.name}.")
    # Build the RNG list for this macroreplication. With CRN across
    # macroreps every mrep starts from the same substreams; otherwise
    # each mrep gets its own bank of n_rngs substreams.
    substream_base = self.problem.model.n_rngs if self.crn_across_macroreps else self.problem.model.n_rngs * (mrep + 1)
    baseline_rngs = [
        MRG32k3a(s_ss_sss_index=[0, substream_base + rng_index, 0])
        for rng_index in range(self.problem.model.n_rngs)
    ]

    start = time.perf_counter()

    post_replicates = []
    # Postreplicate each recommended solution in budget order.
    for budget_index, _ in enumerate(self.all_intermediate_budgets[mrep]):
        fresh_soln = Solution(self.all_recommended_xs[mrep][budget_index], self.problem)
        # With CRN across budgets, attach a copy of the RNGs so every
        # budget's solution sees the same random-number stream state.
        fresh_soln.attach_rngs(rng_list=baseline_rngs, copy=self.crn_across_budget)
        self.problem.simulate(solution=fresh_soln, m=self.n_postreps)
        # 0 <- assuming only one objective
        post_replicates.append(list(fresh_soln.objectives[:fresh_soln.n_reps][:, 0]))
    elapsed = time.perf_counter() - start
    print(f"\t{mrep + 1}: Finished in {round(elapsed, 3)} seconds")

    return (post_replicates, elapsed)
605619
def check_postreplicate(self):
606620
"""Check if the experiment has been postreplicated.
607621
@@ -2891,6 +2905,7 @@ def read_group_experiment_results(file_name_path):
28912905
groupexperiment = pickle.load(file)
28922906
return groupexperiment
28932907

2908+
28942909
def find_unique_solvers_problems(experiments):
28952910
"""Identify the unique problems and solvers in a collection
28962911
of experiments.
@@ -2918,6 +2933,7 @@ def find_unique_solvers_problems(experiments):
29182933
unique_problems.append(experiment.problem)
29192934
return unique_solvers, unique_problems
29202935

2936+
29212937
def find_missing_experiments(experiments):
29222938
"""Identify problem-solver pairs that are not part of a list
29232939
of experiments.

0 commit comments

Comments (0)