Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type-hints to SKOptLearner #376

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions adaptive/learner/skopt_learner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
from typing import Callable

import numpy as np
from skopt import Optimizer
Expand All @@ -25,8 +26,8 @@ class SKOptLearner(Optimizer, BaseLearner):
Arguments to pass to ``skopt.Optimizer``.
"""

def __init__(self, function, **kwargs):
self.function = function
def __init__(self, function: Callable, **kwargs) -> None:
self.function = function # type: ignore
self.pending_points = set()
self.data = collections.OrderedDict()
self._kwargs = kwargs
Expand All @@ -36,7 +37,7 @@ def new(self) -> SKOptLearner:
"""Return a new `~adaptive.SKOptLearner` without the data."""
return SKOptLearner(self.function, **self._kwargs)

def tell(self, x, y, fit=True):
def tell(self, x: float | list[float], y: float, fit: bool = True) -> None:
if isinstance(x, collections.abc.Iterable):
self.pending_points.discard(tuple(x))
self.data[tuple(x)] = y
Expand All @@ -55,7 +56,7 @@ def remove_unfinished(self):
pass

@cache_latest
def loss(self, real=True):
def loss(self, real: bool = True) -> float:
if not self.models:
return np.inf
else:
Expand All @@ -65,7 +66,12 @@ def loss(self, real=True):
# estimator of loss, but it is the cheapest.
return 1 - model.score(self.Xi, self.yi)

def ask(self, n, tell_pending=True):
def ask(
self, n: int, tell_pending: bool = True
) -> (
tuple[list[float], list[float]]
| tuple[list[list[float]], list[float]] # XXX: this indicates a bug!
):
if not tell_pending:
raise NotImplementedError(
"Asking points is an irreversible "
Expand All @@ -79,7 +85,7 @@ def ask(self, n, tell_pending=True):
return [p[0] for p in points], [self.loss() / n] * n

@property
def npoints(self):
def npoints(self) -> int:
"""Number of evaluated points."""
return len(self.Xi)

Expand Down