Skip to content

Commit

Permalink
Add an option to set time limit for route extraction (#104)
Browse files Browse the repository at this point in the history
Our route extraction utilities can behave very differently depending on
the structure of the underlying graph, and thus it's hard to predict how
long extracting a given number of routes will take. In some cases, even
the time between yielding one route and the next can be quite
substantial. To allow for more control over route extraction, this PR
adds a `max_time_s` argument to those utilities, which can be used to
cap the total time spent. This time constraint is checked on every
iteration of the search loop (as opposed to only when a route is found),
so the limit is followed much more closely.
  • Loading branch information
kmaziarz authored Oct 16, 2024
1 parent 78b2caf commit eb2ed8f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

- Reuse search results when given a partially filled directory ([#98](https://github.com/microsoft/syntheseus/pull/98)) ([@kmaziarz])
- Expose `stop_on_first_solution` as a CLI flag ([#100](https://github.com/microsoft/syntheseus/pull/100)) ([@kmaziarz])
- Add an option to set time limit for route extraction ([#104](https://github.com/microsoft/syntheseus/pull/104)) ([@kmaziarz])
- Extend single-step evaluation with stereo-agnostic results ([#102](https://github.com/microsoft/syntheseus/pull/102)) ([@kmaziarz])

### Fixed
Expand Down
31 changes: 22 additions & 9 deletions syntheseus/search/analysis/route_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import heapq
import math
from collections.abc import Collection, Iterator
from datetime import datetime
from typing import Callable, Optional, TypeVar

from syntheseus.search.graph.and_or import AndNode, OrNode
Expand All @@ -18,6 +19,7 @@ def _iter_top_routes(
cost_fn: Callable[[Collection[NodeType], RetrosynthesisSearchGraph[NodeType]], float],
cost_lower_bound: Callable[[Collection[NodeType], RetrosynthesisSearchGraph[NodeType]], float],
max_routes: int,
max_time_s: float = math.inf,
yield_partial_routes: bool = False,
) -> Iterator[tuple[float, Collection[NodeType]]]:
"""
Expand All @@ -29,14 +31,15 @@ def _iter_top_routes(
this is the case but are not sure in general.
Args:
graph: graph to iterate over. Could be tree, but does not need to be.
graph: Graph to iterate over. Could be tree, but does not need to be.
cost_fn: Gives the cost of a route (specified by the set of nodes).
A cost of inf means the route will not be returned.
cost_lower_bound: A lower bound of the cost. The lower bound means that
if the function is evaluated on a set A, the cost of any set B >= A
(i.e. any superset of A) will never be below this lower bound.
This function will always be evaluated on partial routes.
max_routes: Maximum number of routes to return.
max_time_s: Maximum total number of seconds to spend extracting routes.
yield_partial_routes: if True, will yield routes whose leaves
have children in the full graph. This could be useful if, for example,
there are purchasable molecules which have children.
Expand All @@ -54,10 +57,15 @@ def _iter_top_routes(
(-math.inf, False, 0, {graph.root_node}, [graph.root_node])
]
tie_breaker = 1
start_time = datetime.now()

# Do best-first search
num_routes_yielded = 0
while len(queue) > 0 and num_routes_yielded < max_routes:
while (
len(queue) > 0
and num_routes_yielded < max_routes
and (datetime.now() - start_time).total_seconds() < max_time_s
):
# Pop route
cost, is_true_cost, _, partial_route, route_frontier = heapq.heappop(queue)
assert cost < math.inf, "Infinite cost routes should not be in the queue."
Expand Down Expand Up @@ -162,6 +170,7 @@ def _min_route_partial_cost(
def iter_routes_cost_order(
graph: RetrosynthesisSearchGraph,
max_routes: int,
max_time_s: float = math.inf,
stop_cost: Optional[float] = None,
) -> Iterator[Collection[BaseGraphNode]]:
"""
Expand All @@ -172,17 +181,19 @@ def iter_routes_cost_order(
It is also assumed that `node.has_solution` is set beforehand.
Args:
graph: graph whose routes to extract
max_routes: maximum number of routes to yield.
stop_cost: if provided, iterator will terminate once a route of cost
>= stop_cost is encountered
graph: Graph to extract routes from.
max_routes: Maximum number of routes to yield.
max_time_s: Maximum total number of seconds to spend extracting routes.
stop_cost: If provided, iterator will terminate once a route of cost
greater than or equal to `stop_cost` is encountered.
"""

for cost, route in _iter_top_routes(
graph=graph,
cost_fn=_min_route_cost,
cost_lower_bound=_min_route_partial_cost,
max_routes=max_routes,
max_time_s=max_time_s,
yield_partial_routes=False,
):
if stop_cost is not None and cost >= stop_cost:
Expand Down Expand Up @@ -211,7 +222,7 @@ def _route_time_partial_cost(nodes, _) -> float:


def iter_routes_time_order(
graph: RetrosynthesisSearchGraph, max_routes: int
graph: RetrosynthesisSearchGraph, max_routes: int, max_time_s: float = math.inf
) -> Iterator[Collection[BaseGraphNode]]:
"""
Iterate over all solved routes from `graph` in the order they were created
Expand All @@ -221,15 +232,17 @@ def iter_routes_time_order(
Creation time is measured by `node.creation_time`.
Args:
graph: graph whose routes to extract
max_routes: maximum number of routes to yield.
graph: Graph to extract routes from.
max_routes: Maximum number of routes to yield.
max_time_s: Maximum total number of seconds to spend extracting routes.
"""

for _, r in _iter_top_routes(
graph=graph,
cost_fn=_route_time_cost,
cost_lower_bound=_route_time_partial_cost,
max_routes=max_routes,
max_time_s=max_time_s,
yield_partial_routes=False,
):
yield r

0 comments on commit eb2ed8f

Please sign in to comment.