Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ To learn more, dive into the following resources:
* [Options for customising task execution backends](https://ben-denham.github.io/labtech/runners)
* [Diagramming tools](https://ben-denham.github.io/labtech/diagram)
* [Distributing across multiple machines](https://ben-denham.github.io/labtech/distributed)
* [Extensible parameter types](https://ben-denham.github.io/labtech/params)
* [More examples](https://github.com/ben-denham/labtech/tree/main/examples)


Expand Down
73 changes: 73 additions & 0 deletions docs/cookbook.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ object as a parameter to a task:
* Constructing the object in a dependent task
* Passing the object in an `Enum` parameter
* Passing the object in the lab context
* Defining a custom parameter handler

#### Constructing objects in dependent tasks

Expand Down Expand Up @@ -285,6 +286,78 @@ lab = labtech.Lab(
results = lab.run_tasks(experiments)
```

#### Defining a custom parameter handler

Advanced users may want to extend Labtech to support additional types
of parameters. To do so, you can declare a custom parameter type
handler class with the [`@param_handler`](/params) decorator.

The example below demonstrates defining a handler for [Scipy
probability distributions](https://docs.scipy.org/doc/scipy/reference/stats.html):

``` {.python .code}
import scipy.stats
from scipy.stats.distributions import rv_frozen


@labtech.param_handler
class DistributionParamHandler:
    """Custom parameter handler for frozen Scipy distributions.

    There are two important limitations to this implementation:

    1. Distributions with complex arguments are not supported
       (e.g. rv_histogram, which takes arrays as arguments).
    2. Equivalent distributions expressed with different arguments
       (e.g. positional vs keyword arguments) will be treated as
       different parameter values for the purposes of caching.

    """

    def handles(self, value):
        # Only frozen distributions (e.g. `scipy.stats.norm(loc=42)`)
        # are handled by this class.
        return isinstance(value, rv_frozen)

    def find_tasks(self, value, *, find_tasks_in_param):
        # This handler never reports any tasks for a distribution.
        return []

    def serialize(self, value, *, serializer):
        return {
            'name': value.dist.name,
            'args': [serializer.serialize_value(arg) for arg in value.args],
            # Keyword arguments are sorted by name so that the serialized
            # form does not depend on the order they were passed in.
            'kwds': {
                key: serializer.serialize_value(kwd) for key, kwd in
                sorted(value.kwds.items(), key=lambda pair: pair[0])
            },
        }

    def deserialize(self, serialized, *, serializer):
        dist_cls = getattr(scipy.stats, serialized['name'])
        args = [serializer.deserialize_value(arg) for arg in serialized['args']]
        # 'kwds' was serialized as a dict, so iterate over `.items()` to
        # get (name, value) pairs; iterating the dict itself yields only
        # the keys.
        kwds = {
            key: serializer.deserialize_value(kwd)
            for key, kwd in serialized['kwds'].items()
        }
        return dist_cls(*args, **kwds)


@labtech.task
class Experiment:
    """Example task parameterised by a frozen Scipy distribution."""
    distribution: rv_frozen

    def run(self):
        # The task's result is the mean of its distribution parameter.
        mean = self.distribution.mean()
        return mean


# Build one Experiment task per frozen distribution.
experiments = [
    Experiment(distribution=distribution)
    for distribution in [
        scipy.stats.norm(loc=42),
        scipy.stats.expon(loc=2),
    ]
]
# Run all of the experiments through a lab created with storage=None.
lab = labtech.Lab(storage=None)
results = lab.run_tasks(experiments)
```


### How can I control multi-processing myself within a task?

Expand Down
11 changes: 11 additions & 0 deletions docs/params.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
To extend the types of data that can be used as Labtech parameters,
you can define a class that implements the
[`ParamHandler`][labtech.types.ParamHandler] protocol and decorate it
with the [`@param_handler`][labtech.param_handler] decorator. A full
example is given [in the cookbook](/cookbook#defining-a-custom-parameter-handler).

::: labtech.param_handler

::: labtech.types.ParamHandler

::: labtech.types.Serializer
6 changes: 3 additions & 3 deletions docs/runners.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ See: [Multi-Machine Clusters](./distributed.md)
You can define your own Runner Backend to execute tasks with a
different form of parallelism or distributed computing platform by
defining an implementation of the
[`RunnerBackend`][labtech.types.RunnerBackend] abstract base class:
[`RunnerBackend`][labtech.runners.RunnerBackend] abstract base class:

::: labtech.types.RunnerBackend
::: labtech.runners.RunnerBackend
options:
heading_level: 4

::: labtech.types.Runner
::: labtech.runners.Runner
options:
heading_level: 4
2 changes: 2 additions & 0 deletions labtech/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ def run(self):
__version__ = '0.7.1'

from .lab import Lab
from .params import param_handler
from .tasks import task
from .types import is_task, is_task_type
from .utils import logger

__all__ = [
'is_task_type',
'is_task',
'param_handler',
'task',
'Lab',
'logger',
Expand Down
6 changes: 3 additions & 3 deletions labtech/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from . import __version__ as labtech_version
from .exceptions import CacheError, TaskNotFound
from .serialization import Serializer
from .types import Cache, ResultMeta, ResultT, Storage, Task, TaskResult, TaskT
from .serialization import DefaultSerializer
from .types import Cache, ResultMeta, ResultT, Serializer, Storage, Task, TaskResult, TaskT


class NullCache(Cache):
Expand Down Expand Up @@ -50,7 +50,7 @@ class BaseCache(Cache):
METADATA_FILENAME = 'metadata.json'

def __init__(self, *, serializer: Optional[Serializer] = None):
self.serializer = serializer or Serializer()
self.serializer = serializer or DefaultSerializer()

def cache_key(self, task: Task) -> str:
serialized_str = json.dumps(self.serializer.serialize_task(task)).encode('utf-8')
Expand Down
9 changes: 9 additions & 0 deletions labtech/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ class TaskError(LabtechError):
"""Raised for failures when handling Task objects."""


class ParamHandlerError(LabtechError):
    """Raised for failures in custom parameter handlers, such as
    attempting to register a class that does not implement the
    `ParamHandler` protocol."""


class UnregisteredParamHandlerError(LabtechError):
    """Raised when attempting to look up a custom parameter handler
    that is not registered."""


class StorageError(LabtechError):
"""Raised for failures when interacting with Storage objects."""

Expand Down
10 changes: 6 additions & 4 deletions labtech/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

from .exceptions import LabError, TaskNotFound
from .monitor import TaskMonitor
from .runners import ForkRunnerBackend, SerialRunnerBackend, SpawnRunnerBackend, ThreadRunnerBackend
from .params import ParamHandlerManager
from .runners import ForkRunnerBackend, RunnerBackend, SerialRunnerBackend, SpawnRunnerBackend, ThreadRunnerBackend
from .storage import LocalStorage, NullStorage
from .tasks import get_direct_dependencies
from .types import LabContext, ResultMeta, ResultT, RunnerBackend, Storage, Task, TaskT, is_task, is_task_type
from .types import LabContext, ResultMeta, ResultT, Storage, Task, TaskT, is_task, is_task_type
from .utils import OrderedSet, base_tqdm, is_ipython, logger, tqdm, tqdm_notebook


Expand Down Expand Up @@ -201,13 +202,14 @@ def run(self, tasks: Sequence[Task]) -> dict[Task, Any]:
runner = self.lab.runner_backend.build_runner(
context=self.lab.context,
max_workers=self.lab.max_workers,
param_handler_manager=ParamHandlerManager.get(),
storage=self.lab._storage,
)

task_monitor = None
if not self.disable_top:
task_monitor = TaskMonitor(
runner=runner,
get_task_infos=runner.get_task_infos,
top_format=self.top_format,
top_sort=self.top_sort,
top_n=self.top_n,
Expand Down Expand Up @@ -359,7 +361,7 @@ def __init__(self, *,
useful when troubleshooting issues running tasks on different
threads and processes.
* Any instance of a
[`RunnerBackend`][labtech.types.RunnerBackend],
[`RunnerBackend`][labtech.runners.RunnerBackend],
allowing for custom task management implementations.

For details on the differences between `'fork'` and
Expand Down
10 changes: 5 additions & 5 deletions labtech/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from datetime import datetime
from itertools import zip_longest
from string import Template
from typing import Optional, Sequence, cast
from typing import Callable, Optional, Sequence, cast

import psutil

from .exceptions import LabError
from .types import Runner, TaskMonitorInfo, TaskMonitorInfoItem, TaskMonitorInfoValue
from .types import TaskMonitorInfo, TaskMonitorInfoItem, TaskMonitorInfoValue
from .utils import tqdm


Expand Down Expand Up @@ -82,9 +82,9 @@ def show(self) -> None:

class TaskMonitor:

def __init__(self, *, runner: Runner, notebook: bool,
def __init__(self, *, get_task_infos: Callable[[], Sequence[TaskMonitorInfo]], notebook: bool,
top_format: str, top_sort: str, top_n: int):
self.runner = runner
self.get_task_infos = get_task_infos
self.top_template = Template(top_format)
self.top_sort = top_sort
self.top_sort_key = top_sort
Expand All @@ -103,7 +103,7 @@ def _top_task_lines(self) -> tuple[int, list[str]]:
# Make (shallow) copies of dictionaries to avoid mutating
# original dictionaries provided by runner.
info.copy()
for info in self.runner.get_task_infos()
for info in self.get_task_infos()
]
total_task_count = len(task_infos)

Expand Down
105 changes: 105 additions & 0 deletions labtech/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#from functools import cached_property
from inspect import isclass
from typing import Optional, Type, TypedDict

from .exceptions import ParamHandlerError, UnregisteredParamHandlerError
from .types import ParamHandler
from .utils import fully_qualified_class_name


class ParamHandlerEntry(TypedDict):
    """Registry record pairing a handler instance with its priority."""
    # Instance of the registered handler class.
    handler: ParamHandler
    # Ordering key for the handler; lower values sort first.
    priority: int


class ParamHandlerManager:
    """Registry of custom parameter handlers.

    Handlers are stored by the fully-qualified name of their class. A
    module-level instance is shared through `ParamHandlerManager.get()`.

    """

    def __init__(self) -> None:
        # Maps fully-qualified handler class name -> registry entry.
        self._entries: dict[str, ParamHandlerEntry] = {}
        # Cached list of handlers in priority order; None when stale.
        self._prioritised_handlers: Optional[list[ParamHandler]] = None

    def register(self, cls: Type[ParamHandler], *, priority: int) -> None:
        """Instantiate `cls` and store it under its fully-qualified name.

        Args:
            cls: The handler class to register.
            priority: Ordering key for the handler; lower values are
                listed first by `prioritised_handlers`.

        Raises:
            ParamHandlerError: If `cls` does not pass the `ParamHandler`
                protocol check.

        """
        # NOTE(review): the check is applied to the class object itself
        # rather than an instance; presumably `ParamHandler` is a
        # runtime-checkable protocol - confirm in labtech.types.
        if not isinstance(cls, ParamHandler):
            raise ParamHandlerError(
                (f"Cannot register '{cls.__qualname__}' as a custom parameter handler, "
                 "as it does not implement all methods of the 'ParamHandler' protocol.")
            )

        self._entries[fully_qualified_class_name(cls)] = ParamHandlerEntry(
            handler=cls(),
            priority=priority,
        )
        # Clear cache
        self._prioritised_handlers = None

    def lookup(self, fq_class_name: str) -> ParamHandler:
        """Return the handler registered under `fq_class_name`.

        Raises:
            UnregisteredParamHandlerError: If no handler is registered
                under that name.

        """
        try:
            entry = self._entries[fq_class_name]
        except KeyError:
            # Report the name that was not found (not the
            # `fully_qualified_class_name` helper function).
            raise UnregisteredParamHandlerError(fq_class_name)
        return entry['handler']

    def clear(self) -> None:
        """Remove all registered handlers."""
        self._entries = {}
        # Clear cache
        self._prioritised_handlers = None

    @property
    def prioritised_handlers(self) -> list[ParamHandler]:
        """Registered handlers ordered by ascending priority."""
        if self._prioritised_handlers is None:
            self._prioritised_handlers = [
                entry['handler'] for entry in
                # Sort param handlers by priority, keeping insertion order
                # where priorities are equal.
                sorted(self._entries.values(), key=lambda entry: entry['priority'])
            ]
        return self._prioritised_handlers

    def instantiate(self) -> None:
        """Make this instance the one returned by `ParamHandlerManager.get()`."""
        global _PARAM_HANDLER_MANAGER
        _PARAM_HANDLER_MANAGER = self

    @staticmethod
    def get() -> 'ParamHandlerManager':
        """Return the current module-level manager instance."""
        return _PARAM_HANDLER_MANAGER


# Module-level manager returned by ParamHandlerManager.get(); it can be
# replaced by calling instantiate() on another manager.
_PARAM_HANDLER_MANAGER = ParamHandlerManager()


def param_handler(*args, priority: int = 1000):
    """Class decorator for declaring custom parameter handlers that
    can define how Labtech should handle the processing,
    serialization, and deserialization of additional parameter types.

    Can be applied either bare (`@param_handler`) or with keyword
    options (`@param_handler(priority=10)`).

    Defining a custom parameter handler is an advanced feature of
    Labtech, and you are responsible for ensuring that:

    * The decorated class implements all methods of the
      [`ParamHandler`][labtech.types.ParamHandler] protocol.
    * You only define handlers for custom parameter types that are
      **immutable and composed only of immutable elements**, so that
      tasks are reproducible.
    * You only define handlers for custom parameter types that are
      **hashable and composed only of hashable elements**, because
      tasks are hashable representations of their parameters.
    * You avoid moving or renaming custom parameter handlers once
      they are in use, because serialized parameters will reference
      the module path and class name of the custom parameter handler
      that was used to serialize them.

    Args:
        priority: Determines the order in which custom parameter handlers are
            applied when processing a parameter value. Lower priority values
            are applied first.

    Raises:
        TypeError: If any positional argument other than the single
            decorated class is given (e.g. `@param_handler(10)` instead
            of `@param_handler(priority=10)`).

    """

    def decorator(cls):
        ParamHandlerManager.get().register(cls, priority=priority)
        return cls

    if not args:
        # Used as `@param_handler(...)`: return the actual decorator.
        return decorator
    if len(args) == 1 and isclass(args[0]):
        # Used as bare `@param_handler`: decorate the class directly.
        return decorator(args[0])
    # Anything else is a misuse; fail loudly rather than silently
    # ignoring the arguments and falling back to the default priority.
    raise TypeError(
        'param_handler() accepts only the decorated class as a positional '
        'argument; pass options such as priority by keyword.'
    )
3 changes: 3 additions & 0 deletions labtech/runners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from .base import Runner, RunnerBackend
from .process import ForkRunnerBackend, SpawnRunnerBackend
from .serial import SerialRunnerBackend
from .thread import ThreadRunnerBackend

__all__ = [
'Runner',
'RunnerBackend',
'ForkRunnerBackend',
'SpawnRunnerBackend',
'SerialRunnerBackend',
Expand Down
Loading