Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes (WIP) #452

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from __future__ import annotations

from math import sqrt
from typing import Callable
from typing import TYPE_CHECKING, Callable

import cloudpickle
import numpy as np

from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Float, Int, Real
from adaptive.utils import (
assign_defaults,
cache_latest,
partial_function_from_dataframe,
)

if TYPE_CHECKING:
from adaptive.types import Float, Int, Real

try:
import pandas
import pandas as pd

with_pandas = True

Expand Down Expand Up @@ -47,6 +49,7 @@ class AverageLearner(BaseLearner):
Points that still have to be evaluated.
npoints : int
Number of evaluated points.

"""

def __init__(
Expand All @@ -57,7 +60,8 @@ def __init__(
min_npoints: int = 2,
) -> None:
if atol is None and rtol is None:
raise Exception("At least one of `atol` and `rtol` should be set.")
msg = "At least one of `atol` and `rtol` should be set."
raise Exception(msg)
if atol is None:
atol = np.inf
if rtol is None:
Expand Down Expand Up @@ -92,7 +96,7 @@ def to_dataframe( # type: ignore[override]
function_prefix: str = "function.",
seed_name: str = "seed",
y_name: str = "y",
) -> pandas.DataFrame:
) -> pd.DataFrame:
"""Return the data as a `pandas.DataFrame`.

Parameters
Expand All @@ -116,10 +120,12 @@ def to_dataframe( # type: ignore[override]
------
ImportError
If `pandas` is not installed.

"""
if not with_pandas:
raise ImportError("pandas is not installed.")
df = pandas.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
msg = "pandas is not installed."
raise ImportError(msg)
df = pd.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
df.attrs["inputs"] = [seed_name]
df.attrs["output"] = y_name
if with_default_function_args:
Expand All @@ -128,12 +134,12 @@ def to_dataframe( # type: ignore[override]

def load_dataframe( # type: ignore[override]
self,
df: pandas.DataFrame,
df: pd.DataFrame,
with_default_function_args: bool = True,
function_prefix: str = "function.",
seed_name: str = "seed",
y_name: str = "y",
):
) -> None:
"""Load data from a `pandas.DataFrame`.

If ``with_default_function_args`` is True, then ``learner.function``'s
Expand All @@ -153,11 +159,14 @@ def load_dataframe( # type: ignore[override]
The ``seed_name`` used in ``to_dataframe``, by default "seed"
y_name : str, optional
The ``y_name`` used in ``to_dataframe``, by default "y"

"""
self.tell_many(df[seed_name].values, df[y_name].values)
if with_default_function_args:
self.function = partial_function_from_dataframe(
self.function, df, function_prefix
self.function,
df,
function_prefix,
)

def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]]:
Expand All @@ -168,7 +177,7 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]
points = list(
set(range(self.n_requested + n))
- set(self.data)
- set(self.pending_points)
- set(self.pending_points),
)[:n]

loss_improvements = [self._loss_improvement(n) / n] * n
Expand Down Expand Up @@ -199,7 +208,8 @@ def mean(self) -> Float:
@property
def std(self) -> Float:
"""The corrected sample standard deviation of the values
in `data`."""
in `data`.
"""
n = self.npoints
if n < self.min_npoints:
return np.inf
Expand All @@ -211,10 +221,7 @@ def std(self) -> Float:

@cache_latest
def loss(self, real: bool = True, *, n=None) -> Float:
if n is None:
n = self.npoints if real else self.n_requested
else:
n = n
n = (self.npoints if real else self.n_requested) if n is None else n
if n < self.min_npoints:
return np.inf
standard_error = self.std / sqrt(n)
Expand All @@ -232,7 +239,7 @@ def _loss_improvement(self, n: int) -> Float:
else:
return np.inf

def remove_unfinished(self):
def remove_unfinished(self) -> None:
"""Remove uncomputed data from the learner."""
self.pending_points = set()

Expand All @@ -242,7 +249,9 @@ def plot(self):
Returns
-------
holoviews.element.Histogram
A histogram of the evaluated data."""
A histogram of the evaluated data.

"""
hv = ensure_holoviews()
vals = [v for v in self.data.values() if v is not None]
if not vals:
Expand Down
Loading
Loading