Skip to content

Commit ed33dc3

Browse files
committed
added test of EvalParallel2 and polish code
1 parent 67feff2 commit ed33dc3

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

cma/optimization_tools.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""
33
from __future__ import absolute_import, division, print_function #, unicode_literals
44
import sys
5+
import warnings
56
import numpy as np
67
from multiprocessing import Pool as ProcessingPool
78
# from pathos.multiprocessing import ProcessingPool
@@ -207,8 +208,10 @@ class EvalParallel2(object):
207208
208209
Examples:
209210
210-
>>> import cma
211211
>>> from cma.optimization_tools import EvalParallel2
212+
>>> for n_jobs in [None, -1, 0, 1, 2, 4]:
213+
... with EvalParallel2(cma.fitness_functions.elli, n_jobs) as eval_all:
214+
... res = eval_all([[1,2], [3,4]])
212215
>>> # class usage, don't forget to call terminate
213216
>>> ep = EvalParallel2(cma.fitness_functions.elli, 4)
214217
>>> ep([[1,2], [3,4], [4, 5]]) # doctest:+ELLIPSIS
@@ -244,11 +247,11 @@ class EvalParallel2(object):
244247
"""
245248
def __init__(self, fitness_function=None, number_of_processes=None):
246249
self.fitness_function = fitness_function
247-
self.processes = number_of_processes
248-
if self.processes is not None and self.processes <= 0:
249-
self.pool = None
250-
else:
250+
self.processes = number_of_processes # for the record
251+
if self.processes is None or self.processes > 0:
251252
self.pool = ProcessingPool(self.processes)
253+
else:
254+
self.pool = None
252255

253256
def __call__(self, solutions, fitness_function=None, args=(), timeout=None):
254257
"""evaluate a list/sequence of solution-"vectors", return a list
@@ -269,7 +272,7 @@ def __call__(self, solutions, fitness_function=None, args=(), timeout=None):
269272
warning_str = ("`fitness_function` must be a function, not a"
270273
" `lambda` or an instancemethod, in order to work with"
271274
" `multiprocessing` under Python 2")
272-
if sys.version[0] == '2': # not necessary anymore?
275+
if sys.version[0] == '2':
273276
if isinstance(fitness_function, type(self.__init__)):
274277
warnings.warn(warning_str)
275278
jobs = [self.pool.apply_async(fitness_function, (x,) + args)

cma/test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,15 @@ def various_doctests():
292292
293293
For VD- and VkD-CMA, see `cma.restricted_gaussian_sampler`.
294294
295+
>>> import sys
295296
>>> import cma
296297
>>> assert cma.interfaces.EvalParallel2 is not None
298+
>>> try:
299+
... with warnings.catch_warnings(record=True) as warn:
300+
... with cma.optimization_tools.EvalParallel2(cma.ff.elli) as eval_all:
301+
... res = eval_all([[1,2], [3,4]])
302+
... except:
303+
... assert sys.version[0] == '2'
297304
298305
"""
299306

0 commit comments

Comments
 (0)