fix problem when a learner has no more points for BalancingLearner #214

Open · wants to merge 5 commits into base: main
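
For context, a minimal reproduction sketch of the behaviour this PR targets; it mirrors the new test added below, assumes adaptive's SequenceLearner and BalancingLearner as imported in that test, and is not part of the PR's diff:

```python
from adaptive.learner import BalancingLearner, SequenceLearner

# A single wrapped learner that can only ever provide 5 points.
learner = SequenceLearner(lambda x: x, range(5))
bl = BalancingLearner([learner])

# The new test suggests that asking for far more points than the wrapped
# learners can supply previously misbehaved; with this fix, ask() should
# simply return whatever is still available.
points, loss_improvements = bl.ask(100)
print(len(points))
```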
38 changes: 24 additions & 14 deletions adaptive/learner/balancing_learner.py
@@ -119,27 +119,34 @@ def strategy(self, strategy):
' strategy="npoints", or strategy="cycle" is implemented.'
)

def _ask_all_learners(self, total_points):
to_select = []
for index, learner in enumerate(self.learners):
# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
points, loss_improvements = self._ask_cache[index]
if not points: # cannot ask for more points
return to_select
Contributor commented:

There are a couple of things I don't understand here:

  • It seems this return should be a continue; just because learner i could not give any more points does not mean that no other learners can (see the sketch below).
  • I cannot see when this branch will ever be executed: learner.ask(1) is guaranteed to return a point, and at the moment there is no way for a learner to indicate that it has "no more points". If a learner returns no points, it is in violation of the API and other code is liable to break.
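
For context, a minimal self-contained sketch of the continue-based alternative suggested above. Here `learners` is a hypothetical list of (points, loss_improvements) pairs standing in for adaptive's learner objects, and `ask_all` is an illustration, not the PR's `_ask_all_learners`:

```python
def ask_all(learners, total_points):
    """Skip exhausted learners instead of returning early."""
    to_select = []
    for index, (points, loss_improvements) in enumerate(learners):
        if not points:   # this learner has no more points...
            continue     # ...but the others may still have some
        to_select.append(
            ((index, points[0]), (loss_improvements[0], -total_points[index]))
        )
    return to_select


# Learner 0 is exhausted; learners 1 and 2 still offer one candidate each.
learners = [([], []), ([0.5], [1.0]), ([0.25], [2.0])]
print(ask_all(learners, total_points=[3, 1, 2]))
# [((1, 0.5), (1.0, -1)), ((2, 0.25), (2.0, -2))]
```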

to_select.append(
((index, points[0]), (loss_improvements[0], -total_points[index]))
)
return to_select

def _ask_and_tell_based_on_loss_improvements(self, n):
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
to_select = []
for index, learner in enumerate(self.learners):
# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
points, loss_improvements = self._ask_cache[index]
Member (author) commented:

Here one would need a double break, but because that doesn't exist in Python I made it into a function (see the sketch below).
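
A self-contained illustration of this point (not the PR's code): Python has no double break, so the inner loop moves into a helper whose early return plays that role, and the caller falls back to a single break. `ask_one_round`, `ask_many`, and the list-of-lists sources are hypothetical stand-ins for the learner machinery:

```python
def ask_one_round(sources):
    """Inner loop as a helper: its early return replaces the inner break."""
    candidates = []
    for index, source in enumerate(sources):
        if not source:              # this source is exhausted; stop collecting
            return candidates
        candidates.append((index, source[-1]))
    return candidates


def ask_many(sources, n):
    """Outer loop: a single break suffices once the helper returns nothing."""
    selected = []
    for _ in range(n):
        candidates = ask_one_round(sources)
        if not candidates:          # the helper signalled "no more points"
            break
        index, point = candidates[0]
        sources[index].pop()        # consume the selected point
        selected.append((index, point))
    return selected


# Two toy "learners" with 2 and 1 remaining points; asking for 100 stops early.
print(ask_many([[1, 2], [3]], n=100))
```

Note that the early return in the helper also reproduces the behaviour questioned in the review comment above: one exhausted source ends the whole round.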

to_select.append(
((index, points[0]), (loss_improvements[0], -total_points[index]))
)

to_select = self._ask_all_learners(total_points)
if not to_select: # cannot ask for more points
break
# Choose the optimal improvement.
(index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1))
total_points[index] += 1
selected.append(((index, point), loss_improvement))
self.tell_pending((index, point))

points, loss_improvements = map(list, zip(*selected))
points, loss_improvements = map(list, zip(*selected)) if selected else ([], [])
return points, loss_improvements

def _ask_and_tell_based_on_loss(self, n):
@@ -156,11 +163,12 @@ def _ask_and_tell_based_on_loss(self, n):
if index not in self._ask_cache:
self._ask_cache[index] = self.learners[index].ask(n=1)
points, loss_improvements = self._ask_cache[index]

if not points: # cannot ask for more points
break
selected.append(((index, points[0]), loss_improvements[0]))
self.tell_pending((index, points[0]))

points, loss_improvements = map(list, zip(*selected))
points, loss_improvements = map(list, zip(*selected)) if selected else ([], [])
return points, loss_improvements

def _ask_and_tell_based_on_npoints(self, n):
@@ -172,11 +180,13 @@
if index not in self._ask_cache:
self._ask_cache[index] = self.learners[index].ask(n=1)
points, loss_improvements = self._ask_cache[index]
if not points: # cannot ask for more points
break
total_points[index] += 1
selected.append(((index, points[0]), loss_improvements[0]))
self.tell_pending((index, points[0]))

points, loss_improvements = map(list, zip(*selected))
points, loss_improvements = map(list, zip(*selected)) if selected else ([], [])
return points, loss_improvements

def _ask_and_tell_based_on_cycle(self, n):
12 changes: 11 additions & 1 deletion adaptive/tests/test_balancing_learner.py
@@ -2,7 +2,7 @@

import pytest

from adaptive.learner import BalancingLearner, Learner1D
from adaptive.learner import BalancingLearner, Learner1D, SequenceLearner
from adaptive.runner import simple


@@ -44,6 +44,16 @@ def test_distribute_first_points_over_learners(strategy):
assert len(set(i_learner)) == len(learners)


@pytest.mark.parametrize("strategy", strategies)
def test_asking_more_points_than_available(strategy):
def dummy(x):
return x

bl = BalancingLearner([SequenceLearner(dummy, range(5))], strategy=strategy)
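# The single SequenceLearner can only supply 5 points; asking for 100 (twice)
# should not raise and should simply return whatever is still available.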
bl.ask(100)
bl.ask(100)


@pytest.mark.parametrize("strategy", strategies)
def test_ask_0(strategy):
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]