Skip to content

Commit

Permalink
Make it easier to run search with custom models (#97)
Browse files Browse the repository at this point in the history
Our single-step evaluation CLI defines the evaluation-specific arguments
in `BaseEvalConfig`, and then creates `EvalConfig` by combining
`BaseEvalConfig` and `BackwardModelConfig` as base classes. This design
allows one to inject custom models by defining a modified
`BackwardModelConfig`, then creating a modified `EvalConfig` (by mixing
in `BaseEvalConfig`), and passing that to the syntheseus code (relying
on a bit of duck typing). Internally, this allowed us to easily
benchmark internal models using the shared eval script from syntheseus.
This PR adapts the search CLI to follow a similar convention, so that
one can also easily integrate custom models into search.
  • Loading branch information
kmaziarz authored Aug 14, 2024
1 parent 38f475f commit 4f965f8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
6 changes: 5 additions & 1 deletion syntheseus/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import sys
from typing import Callable, Dict

from syntheseus.cli import eval_single_step, search


def main() -> None:
supported_commands = {"search": search.main, "eval-single-step": eval_single_step.main}
supported_commands: Dict[str, Callable] = {
"search": search.main,
"eval-single-step": eval_single_step.main,
}
supported_command_names = ", ".join(supported_commands.keys())

if len(sys.argv) == 1:
Expand Down
17 changes: 11 additions & 6 deletions syntheseus/cli/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ class PDVNConfig:


@dataclass
class SearchConfig(BackwardModelConfig):
"""Config for running search for given search targets."""

class BaseSearchConfig:
# Molecule(s) to search for (either as a single explicit SMILES or a file)
search_target: str = MISSING
search_targets_file: str = MISSING
Expand Down Expand Up @@ -136,6 +134,13 @@ class SearchConfig(BackwardModelConfig):
num_routes_to_plot: int = 5 # Number of routes to extract and plot for a quick check


@dataclass
class SearchConfig(BackwardModelConfig, BaseSearchConfig):
"""Config for running search for given search targets."""

pass


def run_from_config(config: SearchConfig) -> Path:
set_random_seed(0)

Expand Down Expand Up @@ -343,8 +348,8 @@ def build_node_evaluator(key: str) -> None:
return results_dir_current_run


def main(argv: Optional[List[str]] = None) -> Path:
config: SearchConfig = cli_get_config(argv=argv, config_cls=SearchConfig)
def main(argv: Optional[List[str]] = None, config_cls: Any = SearchConfig) -> Path:
config = cli_get_config(argv=argv, config_cls=config_cls)

def _warn_will_not_use_defaults(message: str) -> None:
logger.warning(f"{message}; no model-specific search hyperparameters will be used")
Expand Down Expand Up @@ -378,7 +383,7 @@ def _warn_will_not_use_defaults(message: str) -> None:
# we did not know the search algorithm and model class before the first parsing).
config = cli_get_config(
argv=argv,
config_cls=SearchConfig,
config_cls=config_cls,
defaults={f"{config.search_algorithm}_config": relevant_defaults},
)

Expand Down

0 comments on commit 4f965f8

Please sign in to comment.