diff --git a/README.md b/README.md index 730d63d..7287b17 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,7 @@ To learn more, dive into the following resources: * [Option for customising task execution backends](https://ben-denham.github.io/labtech/runners) * [Diagramming tools](https://ben-denham.github.io/labtech/diagram) * [Distributing across multiple machines](https://ben-denham.github.io/labtech/distributed) +* [Extensible parameter types](https://ben-denham.github.io/labtech/params) * [More examples](https://github.com/ben-denham/labtech/tree/main/examples) diff --git a/docs/cookbook.md b/docs/cookbook.md index 20688cd..7d421e4 100644 --- a/docs/cookbook.md +++ b/docs/cookbook.md @@ -76,6 +76,7 @@ object as a parameter to a task: * Constructing the object in a dependent task * Passing the object in an `Enum` parameter * Passing the object in the lab context +* Defining a custom parameter handler #### Constructing objects in dependent tasks @@ -285,6 +286,78 @@ lab = labtech.Lab( results = lab.run_tasks(experiments) ``` +#### Defining a custom parameter handler + +Advanced users may want to extend Labtech to support additional types +of parameters. To do so, you can declare a custom parameter type +handler class with the [`@param_handler`](/params) decorator. + +The example below demonstrates defining a handler for [Scipy +probability distributions](https://docs.scipy.org/doc/scipy/reference/stats.html): + +``` {.python .code} +import scipy.stats +from scipy.stats.distributions import rv_frozen + + +@labtech.param_handler +class DistributionParamHandler: + """ + There are two important limitations to this implementation: + + 1. Distributions with complex arguments are not supported + (e.g. rv_histogram, which takes arrays as arguments). + 2. Equivalent distributions expressed with different arguments + (e.g. positional vs keyword arguments) will be treated as + different parameter values for the purposes of caching. 
+ + """ + + def handles(self, value): + return isinstance(value, rv_frozen) + + def find_tasks(self, value, *, find_tasks_in_param): + return [] + + def serialize(self, value, *, serializer): + return { + 'name': value.dist.name, + 'args': [serializer.serialize_value(arg) for arg in value.args], + 'kwds': { + key: serializer.serialize_value(kwd) for key, kwd in + sorted(value.kwds.items(), key=lambda pair: pair[0]) + }, + } + + def deserialize(self, serialized, *, serializer): + dist_cls = getattr(scipy.stats, serialized['name']) + args = [serializer.deserialize_value(arg) for arg in serialized['args']] + kwds = { + key: serializer.deserialize_value(kwd) + for key, kwd in serialized['kwds'].items() + } + return dist_cls(*args, **kwds) + + +@labtech.task +class Experiment: + distribution: rv_frozen + + def run(self): + return self.distribution.mean() + + +experiments = [ + Experiment(distribution=distribution) + for distribution in [ + scipy.stats.norm(loc=42), + scipy.stats.expon(loc=2), + ] +] +lab = labtech.Lab(storage=None) +results = lab.run_tasks(experiments) +``` + ### How can I control multi-processing myself within a task? diff --git a/docs/params.md b/docs/params.md new file mode 100644 index 0000000..b2cd3ac --- /dev/null +++ b/docs/params.md @@ -0,0 +1,11 @@ +To extend the types of data that can be used as Labtech parameters, +you can define a class that implements the +[`ParamHandler`][labtech.types.ParamHandler] protocol and decorate it +with the [`@param_handler`][labtech.param_handler] decorator. A full +example is given [in the cookbook](/cookbook#defining-a-custom-parameter-handler). 
+ +::: labtech.param_handler + +::: labtech.types.ParamHandler + +::: labtech.types.Serializer diff --git a/docs/runners.md b/docs/runners.md index 68af181..a7610e9 100644 --- a/docs/runners.md +++ b/docs/runners.md @@ -36,12 +36,12 @@ See: [Multi-Machine Clusters](./distributed.md) You can define your own Runner Backend to execute tasks with a different form of parallelism or distributed computing platform by defining an implementation of the -[`RunnerBackend`][labtech.types.RunnerBackend] abstract base class: +[`RunnerBackend`][labtech.runners.RunnerBackend] abstract base class: -::: labtech.types.RunnerBackend +::: labtech.runners.RunnerBackend options: heading_level: 4 -::: labtech.types.Runner +::: labtech.runners.Runner options: heading_level: 4 diff --git a/labtech/__init__.py b/labtech/__init__.py index 68e04dd..acd5655 100644 --- a/labtech/__init__.py +++ b/labtech/__init__.py @@ -34,6 +34,7 @@ def run(self): __version__ = '0.7.1' from .lab import Lab +from .params import param_handler from .tasks import task from .types import is_task, is_task_type from .utils import logger @@ -41,6 +42,7 @@ def run(self): __all__ = [ 'is_task_type', 'is_task', + 'param_handler', 'task', 'Lab', 'logger', diff --git a/labtech/cache.py b/labtech/cache.py index b685ca5..9fe264e 100644 --- a/labtech/cache.py +++ b/labtech/cache.py @@ -10,8 +10,8 @@ from . 
import __version__ as labtech_version from .exceptions import CacheError, TaskNotFound -from .serialization import Serializer -from .types import Cache, ResultMeta, ResultT, Storage, Task, TaskResult, TaskT +from .serialization import DefaultSerializer +from .types import Cache, ResultMeta, ResultT, Serializer, Storage, Task, TaskResult, TaskT class NullCache(Cache): @@ -50,7 +50,7 @@ class BaseCache(Cache): METADATA_FILENAME = 'metadata.json' def __init__(self, *, serializer: Optional[Serializer] = None): - self.serializer = serializer or Serializer() + self.serializer = serializer or DefaultSerializer() def cache_key(self, task: Task) -> str: serialized_str = json.dumps(self.serializer.serialize_task(task)).encode('utf-8') diff --git a/labtech/exceptions.py b/labtech/exceptions.py index b67be65..987a69a 100644 --- a/labtech/exceptions.py +++ b/labtech/exceptions.py @@ -24,6 +24,15 @@ class TaskError(LabtechError): """Raised for failures when handling Task objects.""" +class ParamHandlerError(LabtechError): + """Raised for failures in custom parameter handlers.""" + + +class UnregisteredParamHandlerError(LabtechError): + """Raised when attempting to lookup a custom parameter handler + that is not registered.""" + + class StorageError(LabtechError): """Raised for failures when interacting with Storage objects.""" diff --git a/labtech/lab.py b/labtech/lab.py index 477252d..813119e 100644 --- a/labtech/lab.py +++ b/labtech/lab.py @@ -11,10 +11,11 @@ from .exceptions import LabError, TaskNotFound from .monitor import TaskMonitor -from .runners import ForkRunnerBackend, SerialRunnerBackend, SpawnRunnerBackend, ThreadRunnerBackend +from .params import ParamHandlerManager +from .runners import ForkRunnerBackend, RunnerBackend, SerialRunnerBackend, SpawnRunnerBackend, ThreadRunnerBackend from .storage import LocalStorage, NullStorage from .tasks import get_direct_dependencies -from .types import LabContext, ResultMeta, ResultT, RunnerBackend, Storage, Task, TaskT, 
is_task, is_task_type +from .types import LabContext, ResultMeta, ResultT, Storage, Task, TaskT, is_task, is_task_type from .utils import OrderedSet, base_tqdm, is_ipython, logger, tqdm, tqdm_notebook @@ -201,13 +202,14 @@ def run(self, tasks: Sequence[Task]) -> dict[Task, Any]: runner = self.lab.runner_backend.build_runner( context=self.lab.context, max_workers=self.lab.max_workers, + param_handler_manager=ParamHandlerManager.get(), storage=self.lab._storage, ) task_monitor = None if not self.disable_top: task_monitor = TaskMonitor( - runner=runner, + get_task_infos=runner.get_task_infos, top_format=self.top_format, top_sort=self.top_sort, top_n=self.top_n, @@ -359,7 +361,7 @@ def __init__(self, *, useful when troubleshooting issues running tasks on different threads and processes. * Any instance of a - [`RunnerBackend`][labtech.types.RunnerBackend], + [`RunnerBackend`][labtech.runners.RunnerBackend], allowing for custom task management implementations. For details on the differences between `'fork'` and diff --git a/labtech/monitor.py b/labtech/monitor.py index 6f52e16..3c66136 100644 --- a/labtech/monitor.py +++ b/labtech/monitor.py @@ -2,12 +2,12 @@ from datetime import datetime from itertools import zip_longest from string import Template -from typing import Optional, Sequence, cast +from typing import Callable, Optional, Sequence, cast import psutil from .exceptions import LabError -from .types import Runner, TaskMonitorInfo, TaskMonitorInfoItem, TaskMonitorInfoValue +from .types import TaskMonitorInfo, TaskMonitorInfoItem, TaskMonitorInfoValue from .utils import tqdm @@ -82,9 +82,9 @@ def show(self) -> None: class TaskMonitor: - def __init__(self, *, runner: Runner, notebook: bool, + def __init__(self, *, get_task_infos: Callable[[], Sequence[TaskMonitorInfo]], notebook: bool, top_format: str, top_sort: str, top_n: int): - self.runner = runner + self.get_task_infos = get_task_infos self.top_template = Template(top_format) self.top_sort = top_sort 
self.top_sort_key = top_sort @@ -103,7 +103,7 @@ def _top_task_lines(self) -> tuple[int, list[str]]: # Make (shallow) copies of dictionaries to avoid mutating # original dictionaries provided by runner. info.copy() - for info in self.runner.get_task_infos() + for info in self.get_task_infos() ] total_task_count = len(task_infos) diff --git a/labtech/params.py b/labtech/params.py new file mode 100644 index 0000000..4ef2657 --- /dev/null +++ b/labtech/params.py @@ -0,0 +1,105 @@ +#from functools import cached_property +from inspect import isclass +from typing import Optional, Type, TypedDict + +from .exceptions import ParamHandlerError, UnregisteredParamHandlerError +from .types import ParamHandler +from .utils import fully_qualified_class_name + + +class ParamHandlerEntry(TypedDict): + handler: ParamHandler + priority: int + + +class ParamHandlerManager: + + def __init__(self) -> None: + self._entries: dict[str, ParamHandlerEntry] = {} + self._prioritised_handlers: Optional[list[ParamHandler]] = None + + def register(self, cls: Type[ParamHandler], *, priority: int) -> None: + if not isinstance(cls, ParamHandler): + raise ParamHandlerError( + (f"Cannot register '{cls.__qualname__}' as a custom parameter handler, " + "as it does not implement all methods of the 'ParamHandler' protocol.") + ) + + self._entries[fully_qualified_class_name(cls)] = ParamHandlerEntry( + handler=cls(), + priority=priority, + ) + # Clear cache + self._prioritised_handlers = None + + def lookup(self, fq_class_name: str) -> ParamHandler: + try: + entry = self._entries[fq_class_name] + except KeyError: + raise UnregisteredParamHandlerError(fq_class_name) + return entry['handler'] + + def clear(self) -> None: + self._entries = {} + # Clear cache + self._prioritised_handlers = None + + @property + def prioritised_handlers(self) -> list[ParamHandler]: + if self._prioritised_handlers is None: + self._prioritised_handlers = [ + entry['handler'] for entry in + # Sort param handlers by 
priority, keeping insertion order + # where priorities are equal. + sorted(self._entries.values(), key=lambda entry: entry['priority']) + ] + return self._prioritised_handlers + + def instantiate(self) -> None: + global _PARAM_HANDLER_MANAGER + _PARAM_HANDLER_MANAGER = self + + @staticmethod + def get() -> 'ParamHandlerManager': + return _PARAM_HANDLER_MANAGER + + +_PARAM_HANDLER_MANAGER = ParamHandlerManager() + + +def param_handler(*args, priority: int = 1000): + """Class decorator for declaring custom parameter handlers that + can define how Labtech should handle the processing, + serialization, and deserialization of additional parameter types. + + Defining a custom parameter handler is an advanced feature of + Labtech, and you are responsible for ensuring: + + * The decorated class implements all methods of the + [`ParamHandler`][labtech.types.ParamHandler] protocol. + * To ensure tasks are reproducible, you should only define + handlers for custom parameter types that are **immutable and + composed only of immutable elements**. + * Because tasks are hashable representations of their parameters, + you should only define handlers for custom parameter types that + are **hashable and composed only of hashable elements**. + * Because serialized parameters will reference the module path and + class name of the custom parameter handler that was used to + serialize them, you should avoid moving or renaming custom + parameter handlers once they are in use. + + Args: + priority: Determines the order in which custom parameter handlers are + applied when processing a parameter value. Lower priority values + are applied first. 
+ + """ + + def decorator(cls): + ParamHandlerManager.get().register(cls, priority=priority) + return cls + + if len(args) > 0 and isclass(args[0]): + return decorator(args[0], *args[1:]) + else: + return decorator diff --git a/labtech/runners/__init__.py b/labtech/runners/__init__.py index 8b7fa77..955898a 100644 --- a/labtech/runners/__init__.py +++ b/labtech/runners/__init__.py @@ -1,8 +1,11 @@ +from .base import Runner, RunnerBackend from .process import ForkRunnerBackend, SpawnRunnerBackend from .serial import SerialRunnerBackend from .thread import ThreadRunnerBackend __all__ = [ + 'Runner', + 'RunnerBackend', 'ForkRunnerBackend', 'SpawnRunnerBackend', 'SerialRunnerBackend', diff --git a/labtech/runners/base.py b/labtech/runners/base.py index ba2e257..08722c1 100644 --- a/labtech/runners/base.py +++ b/labtech/runners/base.py @@ -1,17 +1,142 @@ +from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import fields from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Iterator, Optional, Sequence from frozendict import frozendict from labtech.exceptions import LabError +from labtech.params import ParamHandlerManager from labtech.tasks import is_task -from labtech.types import LabContext, ResultMeta, Storage, Task, TaskResult +from labtech.types import LabContext, ResultMeta, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import logger +class Runner(ABC): + """Manages the execution of [Tasks][labtech.types.Task], typically + by delegating to a parallel processing framework.""" + + @abstractmethod + def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: + """Submit the given task object to be run and have its result cached. + + It is up to the Runner to decide when to start running the + task (i.e. when resources become available). 
+ + The implementation of this method should run the task by + effectively calling: + + ``` + # param_handler_manager needs to be instantiated in remote processes + # that don't inherit from the main process: + param_handler_manager.instantiate() + + for dependency_task in get_direct_dependencies(task, all_identities=True): + # Where results_map is expected to contain the TaskResult for + # each dependency_task. + dependency_task._set_results_map(results_map) + + current_process = multiprocessing.current_process() + orig_process_name = current_process.name + try: + # If the thread name or similar is set instead of the process + # name, then the Runner should update the handler of the global + # labtech.utils.logger to include that instead of the process name. + current_process.name = task_name + return labtech.runners.base.run_or_load_task( + task=task, + use_cache=use_cache, + filtered_context=task.filter_context(context), + storage=storage, + ) + finally: + current_process.name = orig_process_name + ``` + + Args: + task: The task to execute. + task_name: Name to use when referring to the task in logs. + use_cache: If True, the task's result should be fetched from the + cache if it is available (fetching should still be done in a + delegated process). + + """ + + @abstractmethod + def wait(self, *, timeout_seconds: Optional[float]) -> Iterator[tuple[Task, ResultMeta | BaseException]]: + """Wait up to timeout_seconds or until at least one of the + submitted tasks is done, then return an iterator of tasks in a + done state and a list of tasks in all other states. + + Each task is returned as a pair where the first value is the + task itself, and the second value is either: + + * For a successfully completed task: Metadata of the result. + * For a task that fails with any BaseException descendant: The exception + that was raised. + + Cancelled tasks are never returned. 
+ + """ + + @abstractmethod + def cancel(self) -> None: + """Cancel all submitted tasks that have not yet been started.""" + + @abstractmethod + def stop(self) -> None: + """Stop all currently running tasks.""" + + @abstractmethod + def close(self) -> None: + """Clean up any resources used by the Runner after all tasks + are finished, cancelled, or stopped.""" + + @abstractmethod + def pending_task_count(self) -> int: + """Returns the number of tasks that have been submitted but + not yet cancelled or returned from a call to wait().""" + + @abstractmethod + def get_result(self, task: Task) -> TaskResult: + """Returns the in-memory result for a task that was + successfully run by this Runner. Raises a KeyError for a + task with no in-memory result.""" + + @abstractmethod + def remove_results(self, tasks: Sequence[Task]) -> None: + """Removes the in-memory results for tasks that were + successfully run by this Runner. Ignores tasks that have no + in-memory result.""" + + @abstractmethod + def get_task_infos(self) -> list[TaskMonitorInfo]: + """Returns a snapshot of monitoring information about each + task that is currently running.""" + + +class RunnerBackend(ABC): + """Factory class to construct [Runner][labtech.runners.Runner] objects.""" + + @abstractmethod + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, + max_workers: Optional[int]) -> Runner: + """Return a Runner prepared with the given configuration. + + Args: + context: Additional variables made available to tasks that aren't + considered when saving to/loading from the cache. + storage: Where task results should be cached to. + param_handler_manager: Custom parameter handling configuration + to be instantiated on remote processes. + max_workers: The maximum number of parallel worker processes for + running tasks. 
+ """ + + @contextmanager def optional_mlflow(task: Task): """Context manager to set mlflow "run" configuration for a task if diff --git a/labtech/runners/process.py b/labtech/runners/process.py index 0be4f7b..53ff3f6 100644 --- a/labtech/runners/process.py +++ b/labtech/runners/process.py @@ -18,11 +18,12 @@ from labtech.exceptions import RunnerError, TaskDiedError from labtech.monitor import get_process_info +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, ResultsMap, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult +from labtech.types import LabContext, ResultMeta, ResultsMap, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import LoggerFileProxy, logger -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task class FutureStateError(Exception): @@ -479,7 +480,8 @@ class SpawnRunnerBackend(RunnerBackend): """ - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> SpawnProcessRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> SpawnProcessRunner: return SpawnProcessRunner( context=context, storage=storage, @@ -557,7 +559,8 @@ class ForkRunnerBackend(RunnerBackend): """ - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> ForkProcessRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> ForkProcessRunner: return ForkProcessRunner( context=context, storage=storage, diff --git a/labtech/runners/ray.py b/labtech/runners/ray.py index 4cf8170..cd53089 100644 --- a/labtech/runners/ray.py +++ b/labtech/runners/ray.py @@ -6,11 +6,12 @@ from typing import Iterator, Optional, Sequence from labtech.exceptions import RunnerError 
+from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, ResultT, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult, is_task +from labtech.types import LabContext, ResultMeta, ResultT, Storage, Task, TaskMonitorInfo, TaskResult, is_task from labtech.utils import logger -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task try: import ray @@ -31,7 +32,10 @@ class TaskDetail: # arguments, even though they work. @ray.remote(num_returns=2) # type: ignore[arg-type] def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: bool, - context: LabContext, storage: Storage) -> tuple[ResultMeta, ResultT]: + context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager) -> tuple[ResultMeta, ResultT]: + param_handler_manager.instantiate() + # task_refs_args is expected to be a flattened list of (task, # result_meta, result_value) triples - passed this way to ensure # refs are top-level to trigger locality-aware scheduling: @@ -72,7 +76,7 @@ def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: b class RayRunner(Runner): - def __init__(self, *, context: LabContext, storage: Storage, + def __init__(self, *, context: LabContext, storage: Storage, param_handler_manager: ParamHandlerManager, monitor_interval_seconds: float, monitor_timeout_seconds: int) -> None: self.monitor_interval_seconds = monitor_interval_seconds self.monitor_timeout_seconds = monitor_timeout_seconds @@ -83,6 +87,7 @@ def __init__(self, *, context: LabContext, storage: Storage, logger.debug('Uploading context and storage objects to ray object store') self.context_ref = ray.put(context) self.storage_ref = ray.put(storage) + self.param_handler_manager_ref = ray.put(param_handler_manager) logger.debug('Uploaded context and storage objects to ray object store') self.cancelled = False @@ -137,6 +142,7 @@ 
def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: use_cache=use_cache, context=self.context_ref, storage=self.storage_ref, + param_handler_manager=self.param_handler_manager_ref, ) ) result_meta_ref, result_value_ref = result_refs @@ -313,7 +319,8 @@ def __init__(self, monitor_interval_seconds: float = 1, monitor_timeout_seconds: self.monitor_interval_seconds = monitor_interval_seconds self.monitor_timeout_seconds = monitor_timeout_seconds - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> Runner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> Runner: if max_workers is not None: raise RunnerError(( 'Remove max_workers from your Lab configuration, as RayRunnerBackend only supports max_workers=None. ' @@ -324,6 +331,7 @@ def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Op return RayRunner( context=context, storage=storage, + param_handler_manager=param_handler_manager, monitor_interval_seconds=self.monitor_interval_seconds, monitor_timeout_seconds=self.monitor_timeout_seconds, ) diff --git a/labtech/runners/serial.py b/labtech/runners/serial.py index c6c2f6f..b9515f2 100644 --- a/labtech/runners/serial.py +++ b/labtech/runners/serial.py @@ -6,11 +6,12 @@ import psutil from labtech.monitor import get_process_info +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult +from labtech.types import LabContext, ResultMeta, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import logger -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task @dataclass(frozen=True) @@ -114,7 +115,8 @@ class SerialRunnerBackend(RunnerBackend): """Runner Backend that runs each task serially 
in the main process and thread.""" - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> SerialRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> SerialRunner: return SerialRunner( context=context, storage=storage, diff --git a/labtech/runners/thread.py b/labtech/runners/thread.py index ccf5eba..699fd19 100644 --- a/labtech/runners/thread.py +++ b/labtech/runners/thread.py @@ -9,11 +9,12 @@ from labtech.exceptions import RunnerError from labtech.monitor import get_process_info +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult +from labtech.types import LabContext, ResultMeta, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import OrderedSet, logger, make_logger_handler -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task class KillThread(Exception): @@ -171,7 +172,8 @@ class ThreadRunnerBackend(RunnerBackend): """ - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> ThreadRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> ThreadRunner: return ThreadRunner( context=context, storage=storage, diff --git a/labtech/serialization.py b/labtech/serialization.py index c907f05..66ea59d 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -2,23 +2,61 @@ from dataclasses import fields from enum import Enum -from typing import Optional, Type, Union, cast +from typing import Any, Optional, Type, cast from frozendict import frozendict -from .exceptions import SerializationError -from .types import ResultMeta, Task, is_task -from .utils import 
ensure_dict_key_str +from .exceptions import SerializationError, UnregisteredParamHandlerError +from .params import ParamHandlerManager +from .types import ParamHandler, ResultMeta, Serializer, Task, is_task, jsonable +from .utils import ensure_dict_key_str, fully_qualified_class_name -# Type to represent any value that can be handled by Python's default -# json encoder and decoder. -jsonable = Union[None, str, bool, float, int, - dict[str, 'jsonable'], list['jsonable']] +class DefaultSerializer(Serializer): + """Default Serializer implementation.""" -class Serializer: + def _is_serialized_custom(self, serialized: jsonable) -> bool: + return isinstance(serialized, dict) and bool(serialized.get('_is_custom', False)) - def is_serialized_task(self, serialized: jsonable) -> bool: + def _serialize_custom(self, custom_param_handler: ParamHandler, value: Any) -> dict[str, jsonable]: + return { + '_is_custom': True, + '__class__': self.serialize_class(custom_param_handler.__class__), + 'value': custom_param_handler.serialize( + value=value, + serializer=self, + ), + } + + def _deserialize_custom(self, serialized: dict[str, jsonable]) -> Any: + if not self._is_serialized_custom(serialized): + raise SerializationError(("deserialize_custom() must be called with a " + f"serialized custom value, received: '{serialized}'")) + + try: + custom_param_handler = ParamHandlerManager.get().lookup(cast(str, serialized['__class__'])) + except UnregisteredParamHandlerError: + custom_param_handler = self.deserialize_class(serialized['__class__'])() + return custom_param_handler.deserialize( + serialized=serialized['value'], + serializer=self, + ) + + def _is_serialized_enum(self, serialized: jsonable) -> bool: + return isinstance(serialized, dict) and bool(serialized.get('_is_enum', False)) + + def _serialize_enum(self, value: Enum) -> jsonable: + return { + '_is_enum': True, + '__class__': self.serialize_class(value.__class__), + 'name': value.name, + } + + def _deserialize_enum(self, 
serialized: dict[str, jsonable]) -> Enum: + enum_cls = self.deserialize_class(serialized['__class__']) + return enum_cls[serialized['name']] + + def _is_serialized_task(self, serialized: jsonable) -> bool: return isinstance(serialized, dict) and bool(serialized.get('_is_task', False)) def serialize_task(self, task: Task) -> dict[str, jsonable]: @@ -44,7 +82,7 @@ def serialize_task(self, task: Task) -> dict[str, jsonable]: return serialized def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Optional[ResultMeta]) -> Task: - if not self.is_serialized_task(serialized): + if not self._is_serialized_task(serialized): raise SerializationError(("deserialize_task() must be called with a " f"serialized Task, received: '{serialized}'")) @@ -69,18 +107,22 @@ def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Opti task._set_result_meta(result_meta) return task - def serialize_value(self, value) -> jsonable: + def serialize_value(self, value: Any) -> jsonable: + for custom_param_handler in ParamHandlerManager.get().prioritised_handlers: + if custom_param_handler.handles(value): + return self._serialize_custom(custom_param_handler, value) + if is_task(value): return self.serialize_task(value) elif isinstance(value, tuple): return [self.serialize_value(item) for item in value] elif isinstance(value, frozendict): return { - ensure_dict_key_str(key, exception_type=SerializationError): self.serialize_value(value) - for key, value in value.items() + ensure_dict_key_str(k, exception_type=SerializationError): self.serialize_value(v) + for k, v in value.items() } elif isinstance(value, Enum): - return self.serialize_enum(value) + return self._serialize_enum(value) elif ((value is None) or isinstance(value, str) or isinstance(value, bool) @@ -91,7 +133,9 @@ def serialize_value(self, value) -> jsonable: "that your task's parameters only use supported types.")) def deserialize_value(self, value: jsonable): - if 
self.is_serialized_task(value): + if self._is_serialized_custom(value): + return self._deserialize_custom(cast(dict[str, jsonable], value)) + elif self._is_serialized_task(value): return self.deserialize_task(cast(dict[str, jsonable], value), result_meta=None) elif isinstance(value, list): return tuple([self.deserialize_value(item) for item in value]) @@ -100,26 +144,12 @@ def deserialize_value(self, value: jsonable): ensure_dict_key_str(k, exception_type=SerializationError): self.deserialize_value(v) for k, v in value.items() }) - elif self.is_serialized_enum(value): - return self.deserialize_enum(cast(dict[str, jsonable], value)) + elif self._is_serialized_enum(value): + return self._deserialize_enum(cast(dict[str, jsonable], value)) return value - def is_serialized_enum(self, serialized: jsonable) -> bool: - return isinstance(serialized, dict) and bool(serialized.get('_is_enum', False)) - - def serialize_enum(self, value: Enum) -> jsonable: - return { - '_is_enum': True, - '__class__': self.serialize_class(value.__class__), - 'name': value.name, - } - - def deserialize_enum(self, serialized: dict[str, jsonable]) -> Enum: - enum_cls = self.deserialize_class(serialized['__class__']) - return enum_cls[serialized['name']] - def serialize_class(self, cls: Type) -> jsonable: - return f'{cls.__module__}.{cls.__qualname__}' + return fully_qualified_class_name(cls) def deserialize_class(self, serialized_class: jsonable) -> Type: cls_module, cls_name = cast(str, serialized_class).rsplit('.', 1) diff --git a/labtech/tasks.py b/labtech/tasks.py index cc648d4..f52b9c5 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -1,7 +1,9 @@ """Utilities for defining tasks.""" +from collections.abc import Hashable from dataclasses import dataclass, fields from enum import Enum +from functools import partial from inspect import isclass from types import UnionType from typing import Any, Optional, Sequence, TypeAlias, Union, cast @@ -10,6 +12,7 @@ from .cache import NullCache, 
PickleCache from .exceptions import TaskError +from .params import ParamHandlerManager from .types import Cache, LabContext, ResultMeta, ResultsMap, ResultT, Task, TaskInfo, is_task, is_task_type from .utils import ensure_dict_key_str @@ -32,7 +35,20 @@ class CacheDefault: def immutable_param_value(key: str, value: Any) -> Any: - """Converts a parameter value to an immutable equivalent that is hashable.""" + """Converts a parameter value to an immutable equivalent that is + hashable (so that the task itself is hashable to be stored in + sets).""" + # Any value handled by custom_param_handlers is expected to be + # immutable and hashable. + for custom_param_handler in ParamHandlerManager.get().prioritised_handlers: + if custom_param_handler.handles(value): + if not isinstance(value, Hashable): + raise TaskError( + (f"Type '{type(value).__qualname__}' in parameter value '{key}' is handled " + f"by '{type(custom_param_handler).__qualname__}', but is not hashable.") + ) + return value + if isinstance(value, list) or isinstance(value, tuple): return tuple(immutable_param_value(f'{key}[{i}]', item) for i, item in enumerate(value)) if isinstance(value, dict) or isinstance(value, frozendict): @@ -152,6 +168,9 @@ def task(*args, * Note: Mutable `list` and `dict` collections will be converted to immutable `tuple` and [`frozendict`](https://pypi.org/project/frozendict/) collections. + * Immutable and hashable values for which a + [custom parameter handler][labtech.params.param_handler] has been + registered. The task type is expected to define a `run()` method that takes no arguments (other than `self`). The `run()` method should execute @@ -207,6 +226,10 @@ def run(self): documentation of each runner backend for supported options. The implementation may make use of the task's parameter values. + Because serialized tasks will reference the module path and class + name of the task type, you should avoid moving or renaming task + types once they are in use. 
+ Args: cache: The Cache that controls how task results are formatted for caching. Can be set to an instance of any @@ -289,6 +312,14 @@ def find_tasks_in_param(param_value: Any, searched_coll_ids: Optional[set[int]] if id(param_value) in searched_coll_ids: return [] + for custom_param_handler in ParamHandlerManager.get().prioritised_handlers: + if custom_param_handler.handles(param_value): + searched_coll_ids = searched_coll_ids | {id(param_value)} + return custom_param_handler.find_tasks( + value=param_value, + find_tasks_in_param=partial(find_tasks_in_param, searched_coll_ids=searched_coll_ids), + ) + if is_task(param_value): return [param_value] elif isinstance(param_value, list) or isinstance(param_value, tuple): diff --git a/labtech/types.py b/labtech/types.py index a2cfbcd..19434dc 100644 --- a/labtech/types.py +++ b/labtech/types.py @@ -9,15 +9,21 @@ Any, Callable, Generic, - Iterator, Literal, Optional, Protocol, Sequence, Type, TypeVar, + Union, + runtime_checkable, ) +# Type to represent any value that can be handled by Python's default +# json encoder and decoder. +jsonable = Union[None, str, bool, float, int, + dict[str, 'jsonable'], list['jsonable']] + @dataclass(frozen=True) class TaskInfo: @@ -222,121 +228,77 @@ def delete(self, storage: Storage, task: Task) -> None: `storage`.""" -TaskMonitorInfoValue = datetime | str | int | float -TaskMonitorInfoItem = TaskMonitorInfoValue | tuple[TaskMonitorInfoValue, str] -TaskMonitorInfo = dict[str, TaskMonitorInfoItem] - - -class Runner(ABC): - """Manages the execution of [Tasks][labtech.types.Task], typically - by delegating to a parallel processing framework.""" +class Serializer(ABC): + """Serializer for producing serialized JSON representations of + Task objects, and deserializing JSON back into Task objects.""" @abstractmethod - def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: - """Submit the given task object to be run and have its result cached. 
- - It is up to the Runner to decide when to start running the - task (i.e. when resources become available). - - The implementation of this method should run the task by - effectively calling: - - ``` - for dependency_task in get_direct_dependencies(task, all_identities=True): - # Where results_map is expected to contain the TaskResult for - # each dependency_task. - dependency_task._set_results_map(results_map) - - current_process = multiprocessing.current_process() - orig_process_name = current_process.name - try: - # If the thread name or similar is set instead of the process - # name, then the Runner should update the handler of the global - # labtech.utils.logger to include that instead of the process name. - current_process.name = task_name - return labtech.runners.base.run_or_load_task( - task=task, - use_cache=use_cache, - filtered_context=task.filter_context(self.context), - storage=self.storage, - ) - finally: - current_process.name = orig_process_name - ``` - - Args: - task: The task to execute. - task_name: Name to use when referring to the task in logs. - use_cache: If True, the task's result should be fetched from the - cache if it is available (fetching should still be done in a - delegated process). - - """ + def serialize_task(self, task: Task) -> dict[str, jsonable]: + """Convert the given task into a JSON-compatible + representation composed only of dictionaries, lists, strings, + numbers and `None`.""" @abstractmethod - def wait(self, *, timeout_seconds: Optional[float]) -> Iterator[tuple[Task, ResultMeta | BaseException]]: - """Wait up to timeout_seconds or until at least one of the - submitted tasks is done, then return an iterator of tasks in a - done state and a list of tasks in all other states. - - Each task is returned as a pair where the first value is the - task itself, and the second value is either: - - * For a successfully completed task: Metadata of the result. 
- * For a task that fails with any BaseException descendant: The exception - that was raised. - - Cancelled tasks are never returned. - - """ + def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Optional[ResultMeta]) -> Task: + """Convert the given serialized representation returned by + serialize_task() back into the original task.""" @abstractmethod - def cancel(self) -> None: - """Cancel all submitted tasks that have not yet been started.""" + def serialize_value(self, value: Any) -> jsonable: + """Convert the given value into a JSON-compatible + representation composed only of dictionaries, lists, strings, + numbers and `None`.""" @abstractmethod - def stop(self) -> None: - """Stop all currently running tasks.""" + def deserialize_value(self, value: jsonable): + """Convert the given serialized representation returned by + serialize_value() back into the original value.""" @abstractmethod - def close(self) -> None: - """Clean up any resources used by the Runner after all tasks - are finished, cancelled, or stopped.""" + def serialize_class(self, cls: Type) -> jsonable: + """Convert the given class into a string representation.""" @abstractmethod - def pending_task_count(self) -> int: - """Returns the number of tasks that have been submitted but - not yet cancelled or returned from a call to wait().""" + def deserialize_class(self, serialized_class: jsonable) -> Type: + """Load the class named in the given serialized representation + returned by serialize_class().""" - @abstractmethod - def get_result(self, task: Task) -> TaskResult: - """Returns the in-memory result for a task that was - successfully run by this Runner. Raises a KeyError for a - result with no in-memory result.""" - @abstractmethod - def remove_results(self, tasks: Sequence[Task]) -> None: - """Removes the in-memory results for tasks that were - sucessfully run by this Runner. 
Ignores tasks that have no - in-memory result.""" +@runtime_checkable +class ParamHandler(Protocol): + """Protocol for custom parameter handlers that can define how + Labtech should handle the processing, serialization, and + deserialization of additional parameter types.""" - @abstractmethod - def get_task_infos(self) -> list[TaskMonitorInfo]: - """Returns a snapshot of monitoring information about each - task that is currently running.""" + def handles(self, value: Any) -> bool: + """Returns True if the given parameter value should be handled + by this class.""" + def find_tasks(self, value: Any, *, find_tasks_in_param: Callable[[Any], Sequence[Task]]) -> list[Task]: + """Given a parameter value, return all tasks within it (not + including tasks within those tasks). -class RunnerBackend(ABC): - """Factory class to construct [Runner][labtech.types.Runner] objects.""" + The provided `find_tasks_in_param` should be called to find + tasks in any nested elements within the value.""" - @abstractmethod - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> Runner: - """Return a Runner prepared with the given configuration. + def serialize(self, value: Any, *, serializer: Serializer) -> jsonable: + """Convert the given parameter value into a JSON-compatible + representation composed only of dictionaries, lists, strings, + numbers and `None`. - Args: - context: Additional variables made available to tasks that aren't - considered when saving to/loading from the cache. - storage: Where task results should be cached to. - max_workers: The maximum number of parallel worker processes for - running tasks. 
- """ + Also receives the full Serializer, which can be used to call + `serializer.serialize_value()` to serialize nested elements + within the value.""" + + def deserialize(self, serialized: jsonable, *, serializer: Serializer) -> Any: + """Convert the given serialized representation returned by + serialize() back into the original parameter value. + + Also receives the full Serializer, which can be used to call + `serializer.deserialize_value()` to deserialize nested elements + within the serialized representation.""" + + +TaskMonitorInfoValue = datetime | str | int | float +TaskMonitorInfoItem = TaskMonitorInfoValue | tuple[TaskMonitorInfoValue, str] +TaskMonitorInfo = dict[str, TaskMonitorInfoItem] diff --git a/labtech/utils.py b/labtech/utils.py index 4203908..6b65bd1 100644 --- a/labtech/utils.py +++ b/labtech/utils.py @@ -117,6 +117,10 @@ def ensure_dict_key_str(value, *, exception_type: Type[Exception]) -> str: return cast(str, value) +def fully_qualified_class_name(cls: Type) -> str: + return f'{cls.__module__}.{cls.__qualname__}' + + def is_ipython() -> bool: return hasattr(builtins, '__IPYTHON__') @@ -132,10 +136,12 @@ class tqdm_notebook(base_tqdm_notebook): __all__ = [ + 'make_logger_handler', 'logger', 'OrderedSet', 'LoggerFileProxy', 'ensure_dict_key_str', + 'fully_qualified_class_name', 'is_ipython', 'tqdm', 'tqdm_notebook', diff --git a/mkdocs.yml b/mkdocs.yml index 0d47de8..bbbdc07 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,6 +10,7 @@ nav: - Multi-Machine Clusters: 'distributed.md' - Diagramming: 'diagram.md' - Caches and Storage: 'caching.md' + - Custom Parameter Handling: 'params.md' plugins: - search - mkdocstrings: diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index f302337..10de836 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -1,6 +1,7 @@ """Test a set of tasks packed with usage of features end-to-end. 
Loosely based on tasks from the tutorial.""" +from datetime import datetime from tempfile import TemporaryDirectory from typing import Any, Protocol, TypedDict @@ -8,10 +9,33 @@ import ray import labtech +from labtech.params import ParamHandlerManager from labtech.runners.ray import RayRunnerBackend from labtech.types import Task +class DatetimeParamHandler: + + def handles(self, value): + return isinstance(value, datetime) + + def find_tasks(self, value, *, find_tasks_in_param): + return [] + + def serialize(self, value, *, serializer): + return value.timestamp() + + def deserialize(self, serialized, *, serializer): + return datetime.fromtimestamp(serialized) + + +@pytest.fixture(autouse=True) +def datetime_param_handler(): + labtech.param_handler(DatetimeParamHandler) + yield + ParamHandlerManager.get().clear() + + @labtech.task(cache=None) class ClassifierTask: n_estimators: int @@ -53,6 +77,7 @@ def run(self) -> dict: @labtech.task class WrappingExperiment(ExperimentTask): experiment: ExperimentTask + dt: datetime @property def dataset_key(self): @@ -60,7 +85,8 @@ def dataset_key(self): def run(self) -> dict: return { - 'inner_experiment': self.experiment.result + 'inner_experiment': self.experiment.result, + 'dt': self.dt, } @@ -92,6 +118,7 @@ class Evaluation(TypedDict): def basic_evaluation(context: dict[str, Any]) -> Evaluation: """Evaluation of a standard setup of multiple levels of dependency.""" + now = datetime.now() classifier_tasks = [ ClassifierTask( n_estimators=n_estimators, @@ -109,6 +136,7 @@ def basic_evaluation(context: dict[str, Any]) -> Evaluation: wrapping_experiments = [ WrappingExperiment( experiment=classifier_experiment, + dt=now, ) for classifier_experiment in classifier_experiments ] @@ -125,10 +153,10 @@ def basic_evaluation(context: dict[str, Any]) -> Evaluation: '2': {'dataset': 'aaa', 'classifier': {'n_estimators': 2}}, '3': {'dataset': 'bbb', 'classifier': {'n_estimators': 1}}, '4': {'dataset': 'bbb', 'classifier': 
{'n_estimators': 2}}, - '5': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 1}}}, - '6': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 2}}}, - '7': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 1}}}, - '8': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 2}}}, + '5': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 1}}, 'dt': now}, + '6': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 2}}, 'dt': now}, + '7': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 1}}, 'dt': now}, + '8': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 2}}, 'dt': now}, }, ) @@ -196,6 +224,7 @@ def test_e2e(self, max_workers: int, runner_backend: str, evaluation_key: str, c cached_result = lab.run_task(cached_tasks[0]) assert cached_result == evaluation['expected_result'] + class TestE2ERay: def setup_method(self, method): diff --git a/tests/labtech/test_params.py b/tests/labtech/test_params.py new file mode 100644 index 0000000..3dc4b9f --- /dev/null +++ b/tests/labtech/test_params.py @@ -0,0 +1,90 @@ +import pytest + +import labtech +from labtech.exceptions import ParamHandlerError +from labtech.params import ParamHandlerManager + + +class TestParamHandler: + + def teardown_method(self, method): + ParamHandlerManager.get().clear() + + def test_register(self): + + @labtech.param_handler + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] + + def deserialize(self, serialized, *, serializer): + return frozenset([serializer.deserialize_value(item) for item in serialized]) + + assert 
[type(handler) for handler in ParamHandlerManager.get().prioritised_handlers] == [ + FrozensetParamHandler, + ] + + def test_register_priority(self): + + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] + + def deserialize(self, serialized, *, serializer): + return frozenset([serializer.deserialize_value(item) for item in serialized]) + + @labtech.param_handler(priority=2000) + class FrozensetParamHandlerOne(FrozensetParamHandler): + pass + + @labtech.param_handler + class FrozensetParamHandlerTwo(FrozensetParamHandler): + pass + + @labtech.param_handler + class FrozensetParamHandlerThree(FrozensetParamHandler): + pass + + @labtech.param_handler(priority=100) + class FrozensetParamHandlerFour(FrozensetParamHandler): + pass + + assert [type(handler) for handler in ParamHandlerManager.get().prioritised_handlers] == [ + FrozensetParamHandlerFour, + FrozensetParamHandlerTwo, + FrozensetParamHandlerThree, + FrozensetParamHandlerOne, + ] + + def test_register_noncompliant(self): + with pytest.raises( + ParamHandlerError, match=( + "Cannot register 'TestParamHandler.test_register_noncompliant..CustomParamHandler' " + "as a custom parameter handler, as it does not implement all methods of the 'ParamHandler' protocol." 
+ ), + ): + @labtech.param_handler + class CustomParamHandler: + pass diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index 763fdfc..8d6dfb4 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -9,6 +9,7 @@ import labtech.tasks from labtech.cache import BaseCache, NullCache, PickleCache from labtech.exceptions import TaskError +from labtech.params import ParamHandlerManager from labtech.tasks import _RESERVED_ATTRS, ParamScalar, find_tasks_in_param, immutable_param_value from labtech.types import ResultT, Storage, Task, TaskInfo @@ -262,6 +263,52 @@ def run(self) -> None: class TestImmutableParamValue: + + def setup_method(self, method): + + @labtech.param_handler + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] + + def deserialize(self, serialized, *, serializer): + return frozenset([serializer.deserialize_value(item) for item in serialized]) + + @labtech.param_handler + class SetParamHandler: + """This is not a valid param handler, because sets are not + hashable.""" + + def handles(self, value): + return isinstance(value, set) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] + + def deserialize(self, serialized, *, serializer): + return set([serializer.deserialize_value(item) for item in serialized]) + + def teardown_method(self, method): + ParamHandlerManager.get().clear() + def test_empty_list(self) -> None: assert immutable_param_value("hello", []) == () 
@@ -306,6 +353,20 @@ def test_nested_list_dict(self) -> None: def test_scalar(self, scalar: ParamScalar) -> None: assert immutable_param_value("hello", scalar) is scalar + def test_custom_param(self) -> None: + example_frozenset = frozenset(['one', 2, frozenset([3, 'four'])]) + assert immutable_param_value("hello", example_frozenset) is example_frozenset + + def test_custom_param_unhashable(self, scalar: ParamScalar) -> None: + example_set = set(['one', 2, frozenset([3, 'four'])]) + with pytest.raises( + TaskError, match=( + "Type 'set' in parameter value 'hello' is handled by " + "'TestImmutableParamValue.setup_method..SetParamHandler', but is not hashable." + ) + ): + immutable_param_value("hello", example_set) + def test_unhandled(self) -> None: with pytest.raises( TaskError, match="Unsupported type '_BadObject' in parameter value 'hello'." @@ -321,6 +382,31 @@ def test_multiple_nested_error(self) -> None: class TestFindTasksInParam: + + def setup_method(self, method): + + @labtech.param_handler + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return list(sorted(value, key=hash)) + + def deserialize(self, value, *, serializer): + return frozenset(value) + + def teardown_method(self, method): + ParamHandlerManager.get().clear() + def test_scalar(self, scalar: ParamScalar) -> None: assert find_tasks_in_param(scalar) == [] @@ -363,6 +449,14 @@ def test_searched_coll_ids(self) -> None: task2 ] + def test_custom_param_handler(self) -> None: + task1 = ExampleTask(1) + task2 = ExampleTask(2) + assert find_tasks_in_param(frozenset([1, task1, frozenset([task2, 2])])) == [ + task1, + task2, + ] + def test_unhandled(self) -> None: match = re.escape( "Unexpected type _BadObject encountered in task parameter value."