5 changes: 2 additions & 3 deletions param/__init__.py
@@ -49,7 +49,7 @@
ParamOverrides, Undefined, get_logger
)
from .parameterized import (batch_watch, output, script_repr,
discard_events, edit_constant)
discard_events, edit_constant, serializer)
from .parameterized import shared_parameters
from .parameterized import logging_level
from .parameterized import DEBUG, VERBOSE, INFO, WARNING, ERROR, CRITICAL
@@ -152,8 +152,7 @@

#: Top-level object to allow messaging not tied to a particular
#: Parameterized object, as in 'param.main.warning("Invalid option")'.
main=Parameterized(name="main")

main = Parameterized(name="main")

# A global random seed (integer or rational) available for controlling
# the behaviour of Parameterized objects with random state.
131 changes: 80 additions & 51 deletions param/_utils.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import collections
import contextvars
import datetime as dt
import functools
@@ -10,16 +9,17 @@
import re
import sys
import traceback
import typing as t
import warnings
from collections import OrderedDict, abc, defaultdict
from contextlib import contextmanager
from numbers import Real
from textwrap import dedent
from threading import get_ident
from typing import TYPE_CHECKING, Callable, TypeVar
from typing import TYPE_CHECKING, Callable, Protocol, SupportsFloat, TypeVar, runtime_checkable

if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec
from param.parameterized import Parameter

P = ParamSpec("P")
R = TypeVar("R")
@@ -65,8 +65,8 @@ class Skip(Exception):
"""Exception that allows skipping an update when resolving a reference."""


def _deprecated(extra_msg="", warning_cat=ParamDeprecationWarning):
def decorator(func):
def _deprecated(extra_msg: str = "", warning_cat: type[Warning] = ParamDeprecationWarning):
def decorator(func: Callable[..., t.Any]) -> Callable[..., t.Any]:
"""Mark a function or method as deprecated.

This internal decorator issues a warning when the decorated function
@@ -96,7 +96,7 @@ def inner(*args, **kwargs):
return decorator
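
A hedged usage sketch of the factory above; the helper name and message below are invented for illustration, not taken from this PR.

# Hypothetical usage: calling the decorated helper should emit the configured warning.
@_deprecated(extra_msg="Use `_new_helper` instead.")
def _old_helper(value):
    return value

_old_helper(1)  # expected to emit a ParamDeprecationWarning mentioning _old_helper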


def _deprecate_positional_args(func):
def _deprecate_positional_args(func: Callable[..., t.Any]) -> Callable[..., t.Any]:
"""Issue warnings for methods using deprecated positional arguments.

This internal decorator warns when arguments after the `*` separator
@@ -171,8 +171,8 @@ def wrapper(self, *args, **kwargs):
return decorating_function
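
A hedged sketch of the intended effect, assuming the usual pattern this docstring describes; the example class and its signature are invented.

class _Example:
    @_deprecate_positional_args
    def __init__(self, default=None, *, bounds=None):
        self.default, self.bounds = default, bounds

_Example(1, bounds=(0, 2))  # fine: `bounds` passed by keyword
_Example(1, (0, 2))         # expected to warn; the wrapper is assumed to remap the
                            # extra positional argument to `bounds`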


def _is_auto_name(class_name, instance_name):
return re.match('^'+class_name+'[0-9]{5}$', instance_name)
def _is_auto_name(class_name: str, instance_name: str) -> bool:
return bool(re.match(f'^{class_name}[0-9]{{5}}$', instance_name))
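
A quick check of the pattern above: an auto-generated name is the class name followed by exactly five digits.

assert _is_auto_name('Parameterized', 'Parameterized00003')
assert not _is_auto_name('Parameterized', 'my_object')
assert not _is_auto_name('Parameterized', 'Parameterized123')  # not five digits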


def _find_pname(pclass):
@@ -182,12 +182,14 @@ def _find_pname(pclass):
"""
stack = traceback.extract_stack()
for frame in stack:
match = re.match(r"^(\S+)\s*=\s*(param|pm)\." + pclass + r"\(", frame.line)
if frame.line is None:
continue
match = re.match(rf"^(\S+)\s*=\s*(param|pm)\.{pclass}\(", frame.line)
if match:
return match.group(1)
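
An illustration of the assignment pattern the stack scan above is meant to catch; 'Number' is just an example Parameter class name.

import re

line = "width = param.Number(default=1)"
match = re.match(r"^(\S+)\s*=\s*(param|pm)\.Number\(", line)
assert match is not None and match.group(1) == "width"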


def _validate_error_prefix(parameter, attribute=None):
def _validate_error_prefix(parameter: Parameter, attribute: str | None = None) -> str:
"""
Generate an error prefix suitable for Parameters when they raise a validation
error.
@@ -376,10 +378,10 @@ def _hashable(x):
part of the object has changed. Does not (currently) recursively
replace mutable subobjects.
"""
if isinstance(x, collections.abc.MutableSequence):
if isinstance(x, abc.MutableSequence):
return tuple(x)
elif isinstance(x, collections.abc.MutableMapping):
return tuple([(k,v) for k,v in x.items()])
elif isinstance(x, abc.MutableMapping):
return tuple([(k, v) for k, v in x.items()])
else:
return x
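
Spot checks of the conversions above, with illustrative values: mutable sequences become tuples, mutable mappings become tuples of items, and anything else is returned unchanged.

assert _hashable([1, 2, 3]) == (1, 2, 3)
assert _hashable({'a': 1, 'b': 2}) == (('a', 1), ('b', 2))
assert _hashable('immutable') == 'immutable'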

@@ -446,42 +448,62 @@ def named_objs(objlist, namesdict=None):
"""
return _named_objs(objlist, namesdict=namesdict)


def _get_min_max_value(min, max, value=None, step=None):
def _get_min_max_value(
min: SupportsFloat | None,
max: SupportsFloat | None,
value: SupportsFloat | None = None,
step: SupportsFloat | None = None,
) -> tuple[float, float, float]:
"""Return min, max, value given input values with possible None."""
# Either min and max need to be given, or value needs to be given
fmin = float(min) if min is not None else None
fmax = float(max) if max is not None else None

if value is None:
if min is None or max is None:
raise ValueError(
f'unable to infer range, value from: ({min}, {max}, {value})'
)
diff = max - min
value = min + (diff / 2)
# Ensure that value has the same type as diff
if not isinstance(value, type(diff)):
value = min + (diff // 2)
else: # value is not None
if not isinstance(value, Real):
raise TypeError('expected a real number, got: %r' % value)
# Infer min/max from value
if value == 0:
# This gives (0, 1) of the correct type
vrange = (value, value + 1)
elif value > 0:
vrange = (-value, 3*value)
else:
vrange = (3*value, -value)
if min is None:
min = vrange[0]
if max is None:
max = vrange[1]
if fmin is None or fmax is None:
raise ValueError(f"unable to infer range, value from: ({min}, {max}, {value})")
fvalue = (fmin + fmax) / 2.0
else:
fvalue = float(value)
if fmin is None or fmax is None:
if fvalue == 0.0:
low, high = 0.0, 1.0
elif fvalue > 0.0:
low, high = -fvalue, 3.0 * fvalue
else:
low, high = 3.0 * fvalue, -fvalue
if fmin is None:
fmin = low
if fmax is None:
fmax = high

# Safety: ensure bounds exist
if fmin is None or fmax is None:
raise RuntimeError("internal error: bounds not resolved")

# Normalize so fmin <= fmax
if fmin > fmax:
fmin, fmax = fmax, fmin

# Snap to step if requested
if step is not None:
# ensure value is on a step
tick = int((value - min) / step)
value = min + tick * step
if not min <= value <= max:
raise ValueError(f'value must be between min and max (min={min}, value={value}, max={max})')
return min, max, value
fstep = abs(float(step))
if fstep == 0.0:
raise ValueError("step must be non-zero")
ticks = round((fvalue - fmin) / fstep) # nearest tick; use math.floor for always-down
fvalue = fmin + ticks * fstep
# Clamp after snapping
if fvalue < fmin:
fvalue = fmin
if fvalue > fmax:
fvalue = fmax

if not (fmin <= fvalue <= fmax):
raise ValueError(f"value must be between min and max (min={fmin}, value={fvalue}, max={fmax})")

return fmin, fmax, fvalue
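
A few spot checks of the rewritten behaviour, inferred only from the code above rather than taken from this PR's tests:

# midpoint inferred when only bounds are given
assert _get_min_max_value(0, 10) == (0.0, 10.0, 5.0)
# bounds inferred from a positive value as (-value, 3 * value)
assert _get_min_max_value(None, None, value=5) == (-5.0, 15.0, 5.0)
# value snapped to the nearest step, then clamped into [fmin, fmax]
assert _get_min_max_value(0, 10, value=3.2, step=0.5) == (0.0, 10.0, 3.0)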


def _deserialize_from_path(ext_to_routine, path, type_name):
Expand Down Expand Up @@ -660,12 +682,13 @@ def exceptions_summarized():
except Exception:
import sys
etype, value, tb = sys.exc_info()
print(f"{etype.__name__}: {value}", file=sys.stderr)
if etype is not None:
print(f"{etype.__name__}: {value}", file=sys.stderr)


def _in_ipython():
try:
get_ipython
get_ipython() # type: ignore[name-defined]
return True
except NameError:
return False
@@ -685,21 +708,27 @@ def async_executor(func):
else:
event_loop.run_until_complete(func())

@runtime_checkable
class _HasTypes(Protocol):
@classmethod
def types(cls) -> abc.Iterable[type]: ...

class _GeneratorIsMeta(type):
def __instancecheck__(cls, inst):
def __instancecheck__(cls: type[_HasTypes], inst) -> bool:
return isinstance(inst, tuple(cls.types()))

def __subclasscheck__(cls, sub):
def __subclasscheck__(cls: type[_HasTypes], sub: type) -> bool:
return issubclass(sub, tuple(cls.types()))

def __iter__(cls):
def __iter__(cls: type[_HasTypes]) -> abc.Iterator[type]:
yield from cls.types()

class _GeneratorIs(metaclass=_GeneratorIsMeta):
@classmethod
def __iter__(cls):
def __iter__(cls: type[_HasTypes]) -> abc.Iterator[type]:
yield from cls.types()
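
An illustration of what the metaclass above enables; the _IntLike class is invented here, while param itself appears to build such classes via gen_types below.

class _IntLike(metaclass=_GeneratorIsMeta):
    @classmethod
    def types(cls):
        return (int, bool)

assert isinstance(3, _IntLike)        # delegates to isinstance(3, (int, bool))
assert issubclass(bool, _IntLike)     # delegates to issubclass(bool, (int, bool))
assert list(_IntLike) == [int, bool]  # metaclass __iter__ yields from cls.types()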


def gen_types(gen_func):
"""Decorate a generator function to support type checking.
