diff --git a/src/runner.py b/src/runner.py index 2c00a93..43255d1 100644 --- a/src/runner.py +++ b/src/runner.py @@ -1,40 +1,46 @@ import os -from multiprocessing import Pool +import asyncio +import concurrent.futures +import logging +import sys from running.args_parse import get_args -from running.main import run, RunParams +from running.run import Run, RunParams from hopfield.input import read_data from running.paths import Paths from storage.data_storage import * - -def run_wrapper(arg_list): - run(arg_list[0], arg_list[1], arg_list[2]) +async def run_series(params_list, stores_list, root_path): + log = logging.getLogger("run_series") + executor = concurrent.futures.ProcessPoolExecutor() + event_loop = asyncio.get_event_loop() + runs = [Run(params_list[i], stores_list[i], root_path) for i in range(0, len(params_list))] + run_tasks = [event_loop.run_in_executor(executor, runs[i].begin) + for i in range(0, len(runs))] + log.info(f"Preparing {len(run_tasks)} runs to run") + await asyncio.wait(run_tasks) def main(): + logging.basicConfig( + level=logging.INFO, + format='PID %(process)5s %(name)18s: %(message)s', + stream=sys.stderr + ) root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) print("Root at: " + root) paths = Paths(root) data = read_data(fr"{paths.input()}\burma14.txt") args = get_args() - process_input_args = [] data_store = DataStorage(paths.results(), args.tag) - - run_index = 0 - for seed in args.seeds: - for size_adj in args.size_adjs: - process_input_args.append([ - RunParams( - seed, args.steps, size_adj, data, args.freq, args.tag, args.images, args.video, - paths), - data_store.open_run_store(run_index), - root]) - run_index += 1 - - pool = Pool() - pool.map(run_wrapper, process_input_args) + event_loop = asyncio.get_event_loop() + series_params = [RunParams(seed, args.steps, size_adj, data, args.freq, + args.freq, args.tag, args.images, args.video) + for seed in args.seeds for size_adj in args.size_adjs] + run_stores = [data_store.open_run_store(run_index) + for run_index in range(len(args.seeds) * len(args.size_adjs))] + event_loop.run_until_complete(run_series(series_params, run_stores, root)) if __name__ == '__main__': diff --git a/src/running/main.py b/src/running/main.py deleted file mode 100644 index d33254e..0000000 --- a/src/running/main.py +++ /dev/null @@ -1,62 +0,0 @@ -import sys -import time -import os - -from hopfield.hopfield_np import HopfieldNet -from hopfield.input import distance_matrix, normalize, normalize_cords -from running.paths import Paths -from storage.image_generator import GraphicalGenerator - - -class RunParams: - def __init__(self, seed, steps, size_adj, data, freq, tag, do_images, do_video, paths: Paths): - self.seed = seed - self.steps = steps - self.size_adj = size_adj - self.data = data - self.freq = freq - self.tag = tag - self.do_images = do_images - self.do_video = do_video - self.paths = paths - - -def run(params: RunParams, run_store, root_path): - net, normalized_distances = initialize(params) - run_store.store_net_config(net.get_net_configuration()) - graphical_generator = GraphicalGenerator(run_store, - os.path.join(root_path, "ffmpeg")) - - print("\nAnnealing network") - optimize_network(run_store, params.freq, net, params.steps) - print("\nAnnealing done!\n") - - if params.do_images or params.do_video: - graphical_generator.generate_run_images(params, normalize_cords(params.data), - normalized_distances) - if params.do_video: - graphical_generator.generate_run_video(params) - - print("Run Ended\n") - - -def initialize(params): - print(f"Seed: {params.seed}; Steps: {params.steps}; " - f"Size_Adj: {params.size_adj}; Freq: {params.freq}") - normalized_distances = normalize(distance_matrix(params.data)) - net = HopfieldNet(normalized_distances, params.seed, params.size_adj, params.paths) - - return net, normalized_distances - - -def optimize_network(runStore, freq, net, steps): - old = time.time() - for step in range(0, steps): - aligned_step = '{:>5}'.format(step) - sys.stdout.write(f"Step: {aligned_step} Time: {time.time() - old:.2}\r") - old = time.time() - net.update() - - if step % freq == 0: - runStore.add_data_point(net.get_net_state()) - runStore.commit_data() diff --git a/src/running/run.py b/src/running/run.py new file mode 100644 index 0000000..e1b5560 --- /dev/null +++ b/src/running/run.py @@ -0,0 +1,65 @@ +import sys +import time +import os +import logging + +from hopfield.hopfield_np import HopfieldNet +from hopfield.input import distance_matrix, normalize, normalize_cords +from running.paths import Paths +from storage.image_generator import GraphicalGenerator + +class RunParams: + def __init__(self, seed, steps, size_adj, data, freq, tag, do_images, do_video, paths: Paths): + self.seed = seed + self.steps = steps + self.size_adj = size_adj + self.data = data + self.freq = freq + self.tag = tag + self.do_images = do_images + self.do_video = do_video + self.paths = paths + +class Run: + def __init__(self, run_params, run_store, root_path): + self.run_params = run_params + self.run_store = run_store + self.root_path = root_path + + def begin(self): + net, normalized_distances = self.initialize() + self.run_store.store_net_config(net.get_net_configuration()) + graphical_generator = GraphicalGenerator(self.run_store, + os.path.join(self.root_path, "ffmpeg")) + + print("Annealing network") + self.optimize_network(net) + print("Annealing done!") + + if self.run_params.do_images or self.run_params.do_video: + graphical_generator.generate_run_images(self.run_params, normalize_cords(self.run_params.data), + normalized_distances) + if self.run_params.do_video: + graphical_generator.generate_run_video(self.run_params) + + print("Run Ended\n") + + def initialize(self): + print(f"Seed: {self.run_params.seed}; Steps: {self.run_params.steps}; " + f"Size_Adj: {self.run_params.size_adj}; Freq: {self.run_params.freq}") + normalized_distances = normalize(distance_matrix(self.run_params.data)) + net = HopfieldNet(normalized_distances, self.run_params.seed, self.run_params.size_adj, self.run_params.paths) + + return net, normalized_distances + + def optimize_network(self, net): + old = time.time() + for step in range(0, self.run_params.steps): + aligned_step = '{:>5}'.format(step) + sys.stdout.write(f"Step: {aligned_step} Time: {time.time() - old:.2}\r") + old = time.time() + net.update() + + if step % self.run_params.freq == 0: + self.run_store.add_data_point(net.get_net_state()) + self.run_store.commit_data()