Skip to content

Commit

Permalink
Add option to stop search once a solution is found (#13)
Browse files Browse the repository at this point in the history
*Rationale*: Even though our main use case for `syntheseus` so far has
been to run search for a fixed amount of time and analyze all routes in
the resulting search graph, there are some scenarios where we only care
about whether or not a route is found (e.g. benchmarking). This is
pretty simple to check so I thought it is worth supporting.

*Implementation*: previously all algorithms called `time_limit_reached`
each iteration to decide whether to stop early. I just augmented this
function to also check for stopping once a solution is found. This
involved:

1. Adding a `stop_on_first_solution` kwarg to the base algorithm class
(defaults to False, which is the most sensible default in my opinion)
2. Adding logic to `time_limit_reached` which checks
`self.stop_on_first_solution and graph.root_node.has_solution`. This
required adding a `graph` argument to the function (previously it took
no arguments), causing changes in all algorithm files.
3. Because of this added functionality, `time_limit_reached` no longer
seemed like a descriptive name for the function, so I changed it to
`should_stop_search`
4. I added a test to `syntheseus/tests/search/algorithms/test_base.py`
which tests the effectiveness of this kwarg for all algorithms (because
algorithm-specific tests all subclass the base test)

*Rejected alternative*: although it would be appealing to generalize
this to "stop after n solutions" or "stop after n diverse solutions"
this is not simple or quick to check so I don't think we should support
it out-of-the-box. If users want this then they could implement it
themselves in a subclass.

---------

Co-authored-by: Krzysztof Maziarz <[email protected]>
  • Loading branch information
AustinT and kmaziarz committed Jul 13, 2023
1 parent 4049576 commit 6465917
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

### Added

- Add option to terminate search when the first solution is found ([#13](https://github.com/microsoft/syntheseus/pull/13)) ([@austint])
- Add code to extract routes in order found instead of by minimum cost ([#9](https://github.com/microsoft/syntheseus/pull/9)) ([@austint])
- Declare support for type checking ([#4](https://github.com/microsoft/syntheseus/pull/4)) ([@kmaziarz])

Expand Down
25 changes: 19 additions & 6 deletions syntheseus/search/algorithms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
unique_nodes: bool = False,
random_state: Optional[random.Random] = None,
prevent_repeat_mol_in_trees: bool = False,
stop_on_first_solution: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -88,6 +89,7 @@ def __init__(
self.expand_purchasable_mols = expand_purchasable_mols
self.set_depth = set_depth
self.set_has_solution = set_has_solution
self.stop_on_first_solution = stop_on_first_solution

# Unique nodes
if self.requires_tree and unique_nodes:
Expand Down Expand Up @@ -148,14 +150,25 @@ def _run_from_graph_after_setup(self, graph: GraphType) -> AlgReturnType:
"""Main method for subclasses to override, which forces them to do setup and teardown."""
raise NotImplementedError

def time_limit_reached(self) -> bool:
def should_stop_search(self, graph) -> bool:
"""
Return true if the search time limit has been reached.
"Time" here refers to ANY time metric (e.g. wall clock time, calls to rxn model).
Generic checking function for whether search should stop.
Base implementation checks whether the time limit has been reached
(both wall clock time and calls to the reaction model)
and whether to stop search because a solution was found (only if `stop_on_first_solution is True`).
Importantly, this function does NOT check whether the iteration limit is reached:
this is because an "iteration" means different things for different algorithms.
We recommend putting this check in the main loop of the algorithm.
"""
elapsed_time = (datetime.now() - self._start_time).total_seconds()
return (elapsed_time >= self.time_limit_s) or (
self.reaction_model.num_calls() >= self.limit_reaction_model_calls
elapsed_time = (
datetime.now() - self._start_time
).total_seconds() # NOTE: `self._start_time` is set in `setup`
return (
(elapsed_time >= self.time_limit_s)
or (self.reaction_model.num_calls() >= self.limit_reaction_model_calls)
or (self.stop_on_first_solution and graph.root_node.has_solution)
)

def set_node_values(
Expand Down
2 changes: 1 addition & 1 deletion syntheseus/search/algorithms/best_first/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _run_from_graph_after_setup(self, graph: GraphType) -> int:
# Run search until time limit or queue is empty
step = 0
for step in range(self.limit_iterations):
if self.time_limit_reached() or len(queue) == 0:
if self.should_stop_search(graph) or len(queue) == 0:
break

# Pop node and potentially expand it
Expand Down
2 changes: 1 addition & 1 deletion syntheseus/search/algorithms/breadth_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _run_from_graph_after_setup(self, graph: GraphType) -> int:
queue = collections.deque([node for node in graph._graph.nodes() if not node.is_expanded])
step = 0 # initialize this variable in case loop is not entered
for step in range(self.limit_iterations):
if self.time_limit_reached() or len(queue) == 0:
if self.should_stop_search(graph) or len(queue) == 0:
break

# Pop node and potentially expand it
Expand Down
2 changes: 1 addition & 1 deletion syntheseus/search/algorithms/mcts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _run_from_graph_after_setup(self, graph: GraphType) -> int:
# Run search until time limit or queue is empty
step = 0 # define explicitly to handle 0 iteration edge case
for step in range(self.limit_iterations):
if self.time_limit_reached():
if self.should_stop_search(graph):
break

# Visit root node
Expand Down
22 changes: 22 additions & 0 deletions syntheseus/tests/search/algorithms/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def _run_alg_and_extract_routes(
time_limit_s: float,
limit_iterations: int = 10_000,
max_routes: int = 100,
**kwargs,
) -> list[SynthesisGraph]:
"""Utility function to run an algorithm and extract routes."""

Expand All @@ -456,6 +457,7 @@ def _run_alg_and_extract_routes(
mol_inventory=task.inventory,
limit_iterations=limit_iterations,
time_limit_s=time_limit_s,
**kwargs,
)
output_graph, _ = alg.run_from_mol(task.target_mol)

Expand Down Expand Up @@ -499,3 +501,23 @@ def test_found_routes2(self, retrosynthesis_task2: RetrosynthesisTask) -> None:
for incorrect_route in retrosynthesis_task2.incorrect_routes.values():
route_matches = [incorrect_route == r for r in route_objs]
assert not any(route_matches)

def test_stop_on_first_solution(self, retrosynthesis_task1: RetrosynthesisTask) -> None:
"""
Test that `stop_on_first_solution` really does stop the algorithm once a solution is found.
The test for this is to run the same search as in `test_found_routes1` but with
`stop_on_first_solution=True`. This should find exactly one route for this problem.
Note however that `stop_on_first_solution=True` does not guarantee finding at most one route
because several routes could possibly be found at the same time. The test works for this specific
problem because there is only one route found in the first iteration.
"""

route_objs = self._run_alg_and_extract_routes(
retrosynthesis_task1,
time_limit_s=0.1,
limit_iterations=10_000,
stop_on_first_solution=True,
)
assert len(route_objs) == 1

0 comments on commit 6465917

Please sign in to comment.