From d09a7ebada0d558dcf4fc26689ab8a56b7e19793 Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sat, 29 Jul 2023 21:13:23 -0400 Subject: [PATCH] (#64) Support tasks retry & propagate raised exception For documentation, see: 1. Docstring of `task.task` 2. Tests in `tests.test_task` e.g. `test_retry_as_per_task_definition` 3. Sample usages in `tests.apps.simple_app` e.g. `append_to_file` Changelist: * Formalize serialization and deserialization * Serialize & deserialize exceptions correctly * Encapsulate retry & retry_on in a new dict 'options' * Implement serde for AsyncResult * Ensure generated file deleted after test * Add jsonpickle to toml file * Exclude `if TYPE_CHECKING:` from coverage * Add test for singleton * Add logging for worker * Wrap all constants inside `Config` class --- .coveragerc | 6 + .pylintrc | 12 +- .pylintrc.tests | 6 +- .vscode/launch.json | 13 +- README.md | 10 +- pyproject.toml | 5 +- src/aiotaskq/__init__.py | 2 +- src/aiotaskq/__main__.py | 5 +- src/aiotaskq/constants.py | 64 ++++++++- src/aiotaskq/interfaces.py | 54 ++++++++ src/aiotaskq/serde.py | 160 ++++++++++++++++++++++ src/aiotaskq/task.py | 190 ++++++++++++++++++-------- src/aiotaskq/worker.py | 99 +++++++++----- src/tests/apps/simple_app.py | 98 ++++++++++++- src/tests/conftest.py | 12 ++ src/tests/test_cli.py | 7 +- src/tests/test_concurrency.py | 6 +- src/tests/test_concurrency_manager.py | 16 +++ src/tests/test_integration.py | 2 +- src/tests/test_serde.py | 80 +++++++++++ src/tests/test_task.py | 177 +++++++++++++++++++++++- src/tests/test_worker.py | 2 +- test.sh | 9 +- 23 files changed, 901 insertions(+), 134 deletions(-) create mode 100644 src/aiotaskq/serde.py create mode 100644 src/tests/test_serde.py diff --git a/.coveragerc b/.coveragerc index 2d31e33..435a053 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,3 +3,9 @@ source = src/ parallel = True concurrency = multiprocessing sigterm = True + +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: + if t.TYPE_CHECKING: diff --git a/.pylintrc b/.pylintrc index c3b6ead..70335b6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -257,16 +257,16 @@ ignored-parents= max-args=10 # Maximum number of attributes for a class (see R0902). -max-attributes=7 +max-attributes=10 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 # Maximum number of branch for function / method body. -max-branches=12 +max-branches=15 # Maximum number of locals for function / method body. -max-locals=15 +max-locals=20 # Maximum number of parents for a class (see R0901). max-parents=7 @@ -458,7 +458,11 @@ good-names=i, a, b, c, - n + n, + id, + e, + s, + on # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted diff --git a/.pylintrc.tests b/.pylintrc.tests index 568e043..6ab1282 100644 --- a/.pylintrc.tests +++ b/.pylintrc.tests @@ -461,8 +461,12 @@ good-names=i, a, b, c, + e, n, - ls + t, + ls, + fo, + fi # Good variable names regexes, separated by a comma. 
If names match any regex, # they will always be accepted diff --git a/.vscode/launch.json b/.vscode/launch.json index 0b6c3e1..7a073b2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,16 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Python: Attach", + "type": "python", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5678 + }, + "justMyCode": false + }, { "name": "Main", "type": "python", @@ -31,7 +41,8 @@ "-s", ], "request": "launch", - "console": "integratedTerminal" + "console": "integratedTerminal", + "justMyCode": false }, { "name": "Sample Worker (Simple App)", diff --git a/README.md b/README.md index 3d8cb67..b8f127a 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ import asyncio import aiotaskq -@aiotaskq.task +@aiotaskq.task() def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: @@ -132,22 +132,22 @@ import asyncio from aiotaskq import task -@task +@task() def task_1(*args, **kwargs): pass -@task +@task() def task_2(*args, **kwargs): pass -@task +@task() def task_3(*args, **kwargs): pass -@task +@task() def task_4(*args, **kwargs): pass diff --git a/pyproject.toml b/pyproject.toml index f06e373..f6fbbc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,11 +9,12 @@ build-backend = "setuptools.build_meta" requires-python = ">=3.9" dependencies = [ "aioredis >= 2.0.0, < 2.1.0", + "jsonpickle >= 3.0.0, < 3.1.0", "tomlkit >= 0.11.0, < 0.12.0", "typer >= 0.4.0, < 0.5.0", ] name = "aiotaskq" -version = "0.0.12" +version = "0.0.13" readme = "README.md" description = "A simple asynchronous task queue" authors = [ @@ -28,7 +29,7 @@ license = { file = "LICENSE" } [project.optional-dependencies] dev = [ - "black >= 22.1.0, < 22.2.0", + "black >= 22.2.0, < 23.0.0", "coverage >= 6.4.0, < 6.5.0", "mypy >= 0.931, < 1.0", "mypy-extensions >= 0.4.0, < 0.5.0", diff --git a/src/aiotaskq/__init__.py b/src/aiotaskq/__init__.py index b38e566..4d5fd35 100644 --- a/src/aiotaskq/__init__.py +++ b/src/aiotaskq/__init__.py @@ -7,7 +7,7 @@ import aiotaskq - @aiotaskq.task + @aiotaskq.task() def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: diff --git a/src/aiotaskq/__main__.py b/src/aiotaskq/__main__.py index 5bdcf56..d21a489 100755 --- a/src/aiotaskq/__main__.py +++ b/src/aiotaskq/__main__.py @@ -2,15 +2,18 @@ #!/usr/bin/env python +import logging import typing as t import typer +from . import __version__ +from .constants import Config from .interfaces import ConcurrencyType from .worker import Defaults, run_worker_forever -from . import __version__ cli = typer.Typer() +logging.basicConfig(level=Config.log_level()) def _version_callback(value: bool): diff --git a/src/aiotaskq/constants.py b/src/aiotaskq/constants.py index 6609dbb..6aa6f9b 100644 --- a/src/aiotaskq/constants.py +++ b/src/aiotaskq/constants.py @@ -1,5 +1,61 @@ -"""Module to define and store all constants used across the library.""" +""" +Module to define and store all constants used across the library. -REDIS_URL = "redis://127.0.0.1:6379" -TASKS_CHANNEL = "channel:tasks" -RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" +The public object from this module is `Config`. 
This object wraps +all the constants, which include: +- Variables +- Environment variables +- Static methods that return constant values +""" + +import logging +from os import environ + +from .interfaces import SerializationType + +_REDIS_URL = "redis://127.0.0.1:6379" +_TASKS_CHANNEL = "channel:tasks" +_RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" + + +class Config: + """ + Provide configuration values. + + These include: + - Variables + - Environment variables + - Static methods that return constant values + """ + + @staticmethod + def serialization_type() -> SerializationType: + """Return the serialization type as provided via env var AIOTASKQ_SERIALIZATION.""" + s: str | None = environ.get("AIOTASKQ_SERIALIZATION", SerializationType.DEFAULT.value) + return SerializationType[s.upper()] + + @staticmethod + def log_level() -> int: + """Return the log level as provided via env var LOG_LEVEL.""" + level: int = int(environ.get("AIOTASKQ_LOG_LEVEL", logging.DEBUG)) + return level + + @staticmethod + def broker_url() -> str: + """ + Return the broker url as provided via env var BROKER_URL. + + Defaults to "redis://127.0.0.1:6379". + """ + broker_url: str = environ.get("BROKER_URL", _REDIS_URL) + return broker_url + + @staticmethod + def tasks_channel() -> str: + """Return the channel name used for transporting task requests on the broker.""" + return _TASKS_CHANNEL + + @staticmethod + def results_channel_template() -> str: + """Return the template channel name used for transporting task results on the broker.""" + return _RESULTS_CHANNEL_TEMPLATE diff --git a/src/aiotaskq/interfaces.py b/src/aiotaskq/interfaces.py index 69584a0..bc3d2a8 100644 --- a/src/aiotaskq/interfaces.py +++ b/src/aiotaskq/interfaces.py @@ -137,3 +137,57 @@ class IWorkerManager(IWorker): """ concurrency_manager: IConcurrencyManager + + +class SerializationType(str, enum.Enum): + """Specify the types of serialization supported.""" + + JSON = "json" + DEFAULT = JSON + + +T = t.TypeVar("T") + + +class ISerialization(t.Protocol, t.Generic[T]): + """Define the interface required to serialize and deserialize a generic object.""" + + @classmethod + def serialize(cls, obj: T) -> bytes: + """Serialize any object into bytes.""" + + @classmethod + def deserialize(cls, klass: type[T], s: bytes) -> T: + """Deserialize bytes into any object.""" + + +class RetryOptions(t.TypedDict): + """ + Specify the available retry options. + + max_retries int | None: The number of times to keep retrying the execution of the task + until the task executes successfully. Counting starts from + 0 so if max_retries = 2 for example, then the task will execute + 1 + 2 times (1 time for the first execution, 2 times for retries). + on tuple[type[Exception], ...]: The tuple of exception classes to retry on. The task will + only be retried if the exception raised + during task execution is an instance of one of the listed + exception classes. + + Examples: + + If on=(Exception,) then any kind of exception will trigger + a retry. + + If on=(ExceptionA, ExceptionB,) and during task + execution ExceptionC was raised, then retry is not triggered. + """ + + max_retries: int | None + on: tuple[type[Exception], ...] + + +class TaskOptions(t.TypedDict): + """Specify the options available for a task.""" + + retry: RetryOptions | None diff --git a/src/aiotaskq/serde.py b/src/aiotaskq/serde.py new file mode 100644 index 0000000..5782325 --- /dev/null +++ b/src/aiotaskq/serde.py @@ -0,0 +1,160 @@ +""" +Define serialization and deserialization utilities.
+""" + +import importlib +import json +import types +import typing as t + +import jsonpickle + +from .constants import Config +from .interfaces import ISerialization, SerializationType, T +from .task import AsyncResult, Task + + +class Serialization(t.Generic[T]): + """Expose the JSON serialization and deserialization logic for any object behined a simple abstraction.""" + + @classmethod + def serialize(cls, obj: "T") -> bytes: + """Serialize an object of type T into bytes via an appropriate serialization logic.""" + s_klass = _get_serde_class(obj.__class__) + return s_klass.serialize(obj) + + @classmethod + def deserialize(cls, klass: type["T"], s: bytes) -> "T": + """Deserialize bytes into an object of type T via an appropriate deserialization logic.""" + s_klass = _get_serde_class(klass) + return s_klass.deserialize(klass, s) + + +def _get_serde_class(klass: type["T"]) -> type[ISerialization["T"]]: + """Get the Serializer-Deserializer class that implements `serialize` and `deserialize`.""" + map_ = { + (Task, SerializationType.JSON): JsonTaskSerialization, + (AsyncResult, SerializationType.JSON): JsonAsyncResultSerialization, + } + if (klass, Config.serialization_type()) not in map_: + assert False, "Should not reach here" # pragma: no cover + return map_[klass, Config.serialization_type()] + + +class JsonTaskSerialization(Serialization): + """Define the JSON serialization and deserialization logic for Task.""" + + class TaskOptionsRetryOnDict(t.TypedDict): + """Define the JSON structure of the retry options of a serialized Task object.""" + + max_retries: int | None + on: str + + class TaskOptionsDict(t.TypedDict): + """Define the JSON structure of the options of a serialized Task object.""" + + retry: "JsonTaskSerialization.TaskOptionsRetryOnDict | None" + + class TaskDict(t.TypedDict): + """Define the JSON structure of a serialized Task object.""" + + func: str + task_id: str + args: tuple[t.Any, ...] + kwargs: dict + options: "JsonTaskSerialization.TaskOptionsDict" + + @classmethod + def serialize(cls, obj: "Task") -> bytes: + """Serialize a Task object to JSON bytes.""" + options: JsonTaskSerialization.TaskOptionsDict = {} + retry: JsonTaskSerialization.TaskOptionsRetryOnDict | None = None + if obj.retry is not None: + retry = { + "max_retries": obj.retry["max_retries"], + "on": jsonpickle.encode(obj.retry["on"]), + } + options["retry"] = retry + d_obj: JsonTaskSerialization.TaskDict = { + "func": { + "module": obj.__module__, + "qualname": obj.__qualname__, + }, + "task_id": obj.id, + "args": obj.args, + "kwargs": obj.kwargs, + "options": options, + } + s = f"json|{json.dumps(d_obj)}" + return s.encode("utf-8") + + @classmethod + def deserialize(cls, klass: type["Task"], s: bytes) -> "Task": + """Deserialize JSON bytes to a Task object.""" + s_type, s_obj = s.decode("utf-8").split("|", 1) + assert s_type == "json" + d_obj: JsonTaskSerialization.TaskDict = json.loads(s_obj) + + d_options: JsonTaskSerialization.TaskOptionsDict = d_obj["options"] + retry: JsonTaskSerialization.TaskOptionsRetryOnDict | None = None + if d_options.get("retry") is not None: + s_retry: "JsonTaskSerialization.TaskOptionsRetryOnDict" = d_obj["options"]["retry"] + max_retries: int | None = s_retry["max_retries"] + retry_on: tuple[type[Exception], ...] 
= jsonpickle.decode(s_retry["on"]) + retry = { + "max_retries": max_retries, + "on": retry_on, + } + + func_module_path, func_qualname = d_obj["func"]["module"], d_obj["func"]["qualname"] + func_module: types.ModuleType = importlib.import_module(func_module_path) + func: types.FunctionType = getattr(func_module, func_qualname).func + assert func is not None + + obj: "Task" = klass( + func=func, + task_id=d_obj["task_id"], + args=d_obj["args"], + kwargs=d_obj["kwargs"], + retry=retry, + ) + return obj + + +class JsonAsyncResultSerialization(Serialization): + """Define the JSON serialization and deserialization logic for AsyncResult.""" + + class AsyncResultDict(t.TypedDict): + """Define the JSON structure of a serialized AsyncResult.""" + + task_id: str + ready: bool + result: t.Any + error: str + + @classmethod + def serialize(cls, obj: "AsyncResult") -> bytes: + """Serialize AsyncResult object to JSON bytes.""" + error_s: str = jsonpickle.encode(obj.error) + result_json: JsonAsyncResultSerialization.AsyncResultDict = { + "task_id": obj.task_id, + "ready": obj.ready, + "result": obj.result, + "error": error_s, + } + return f"json|{json.dumps(result_json)}".encode("utf-8") + + @classmethod + def deserialize(cls, klass: type["AsyncResult"], s: bytes) -> "AsyncResult": + """Deserialize JSON bytes to a AsyncResult object.""" + s: str = s.decode("utf-8") + s_type, s_obj = s.split("|", 1) + assert s_type == "json" + obj_d: JsonAsyncResultSerialization.AsyncResultDict = json.loads(s_obj) + result: "AsyncResult" = klass( + task_id=obj_d["task_id"], + ready=obj_d["ready"], + result=obj_d["result"], + error=jsonpickle.decode(obj_d["error"]), + ) + return result diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py index 90e8b94..572d815 100644 --- a/src/aiotaskq/task.py +++ b/src/aiotaskq/task.py @@ -1,21 +1,25 @@ """Module to define the main logic of the library.""" +# pylint: disable=cyclic-import + +import copy import inspect -import json import logging from types import ModuleType import typing as t import uuid -from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL +from .constants import Config from .exceptions import InvalidArgument, ModuleInvalidForTask -from .interfaces import IPubSub, PollResponse +from .interfaces import IPubSub, PollResponse, TaskOptions from .pubsub import PubSub +if t.TYPE_CHECKING: + from .interfaces import RetryOptions + RT = t.TypeVar("RT") P = t.ParamSpec("P") -logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -27,24 +31,42 @@ class AsyncResult(t.Generic[RT]): """ pubsub: IPubSub - _result: RT - _completed: bool = False - _task_id: str - - def __init__(self, task_id: str) -> None: + task_id: str + ready: bool = False + result: RT | None + error: Exception | None + + def __init__( + self, task_id: str, result: RT | None, ready: bool, error: Exception | None + ) -> None: """Store task_id in AsyncResult instance.""" - self._task_id = task_id - self.pubsub = PubSub.get(url=REDIS_URL, poll_interval_s=0.01) + self.task_id = task_id + self.ready = ready + self.result = result + self.error = error + self.pubsub = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01) + + @classmethod + async def from_publisher(cls, task_id: str) -> "AsyncResult": + """Return the result of the task once finished.""" + from aiotaskq.serde import Serialization # pylint: disable=import-outside-toplevel + + pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01) + async with pubsub_ as pubsub: # pylint: 
disable=not-async-context-manager + await pubsub.subscribe(Config.results_channel_template().format(task_id=task_id)) + message: PollResponse = await pubsub.poll() + + logger.debug("Message: %s", message) + + result_serialized: bytes = message["data"] + result_: "AsyncResult" = Serialization.deserialize(cls, result_serialized) + return result_ - async def get(self) -> RT: + def get(self) -> RT | Exception: """Return the result of the task once finished.""" - async with self.pubsub as pubsub: # pylint: disable=not-async-context-manager - message: PollResponse - await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=self._task_id)) - message = await self.pubsub.poll() - logger.debug("Message: %s", message) - _result: RT = json.loads(message["data"]) - return _result + if self.error is not None: + return self.error + return self.result class Task(t.Generic[P, RT]): @@ -61,7 +83,7 @@ def some_func(x: int, y: int) -> int: return x + y some_task = aiotaskq.task(some_func) # Or equivalently: - # @aiotaskq.task + # @aiotaskq.task() # def some_task(x: int, y: int) -> int: # return x + y @@ -73,23 +95,55 @@ def some_func(x: int, y: int) -> int: ``` """ - __qualname__: str + id: str func: t.Callable[P, RT] - - def __init__(self, func: t.Callable[P, RT]) -> None: + retry: "RetryOptions | None" + args: t.Optional[tuple[t.Any, ...]] + kwargs: t.Optional[dict] + + def __init__( + self, + func: t.Callable[P, RT], + *, + retry: "RetryOptions | None" = None, + task_id: t.Optional[str] = None, + args: t.Optional[tuple[t.Any, ...]] = None, + kwargs: t.Optional[dict] = None, + ) -> None: """ Store the underlying function and an automatically generated task_id in the Task instance. """ self.func = func + self.retry = retry + self.args = args + self.kwargs = kwargs + self.id = task_id + + # Copy metadata from the function to simulate it as closely as possible + + self.__module__ = self.func.__module__ + self.__qualname__ = self.func.__qualname__ + self.__name__ = self.func.__name__ def __call__(self, *args, **kwargs) -> RT: """Call the task synchronously, by directly executing the underlying function.""" return self.func(*args, **kwargs) + def with_retry(self, max_retries: int, on: tuple[type[Exception], ...]) -> "Task": + """ + Return a **copy** of self with the provided retry options. + + We return a copy so that we don't overwrite the original task definition. + """ + task_ = copy.deepcopy(self) + retry = {"max_retries": max_retries, "on": on} + task_.retry = retry + return task_ + def generate_task_id(self) -> str: """Generate a unique id for an individual call to a task.""" id_ = uuid.uuid4() - return f"{self.__qualname__}:{id_}" + return f"{self.__module__}.{self.__qualname__}:{id_}" async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT: """ @@ -106,26 +160,43 @@ async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT: """ # Raise error if arguments provided are invalid, before enything self._validate_arguments(task_args=args, task_kwargs=kwargs) + task_ = copy.deepcopy(self) + task_.args = args + task_.kwargs = kwargs + if task_.id is None: + task_.id = task_.generate_task_id() + # pylint: disable=protected-access + await task_.publish() + return await task_._get_result() + + async def publish(self) -> RT: + """ + Publish the task. + + At this point we expect that args & kwargs are already provided and task_id is already generated.
+ """ + from aiotaskq.serde import Serialization # pylint: disable=import-outside-toplevel + + assert hasattr(self, "args") and hasattr(self, "kwargs") and self.id is not None + + message: bytes = Serialization.serialize(self) - task_id: str = self.generate_task_id() - message: str = json.dumps( - { - "task_id": task_id, - "args": args, - "kwargs": kwargs, - } - ) pubsub_ = PubSub.get( - url=REDIS_URL, poll_interval_s=0.01, max_connections=10, decode_responses=True + url=Config.broker_url(), + poll_interval_s=0.01, + max_connections=10, + decode_responses=True, ) async with pubsub_ as pubsub: # pylint: disable=not-async-context-manager - logger.debug("Publishing task [task_id=%s, message=%s]", task_id, message) - await pubsub.publish(TASKS_CHANNEL, message=message) - - logger.debug("Retrieving result for task [task_id=%s]", task_id) - async_result: AsyncResult[RT] = AsyncResult(task_id=task_id) - result: RT = await async_result.get() - + logger.debug("Publishing task [task_id=%s, message=%s]", self.id, message) + await pubsub.publish(Config.tasks_channel(), message=message) + + async def _get_result(self) -> RT: + logger.debug("Retrieving result for task [task_id=%s]", self.id) + async_result: AsyncResult[RT] = await AsyncResult.from_publisher(task_id=self.id) + result: RT | Exception = async_result.get() + if isinstance(result, Exception): + raise result return result def _validate_arguments(self, task_args: tuple, task_kwargs: dict): @@ -138,23 +209,26 @@ def _validate_arguments(self, task_args: tuple, task_kwargs: dict): ) from exc -def task(func: t.Callable[P, RT]) -> Task[P, RT]: - """Decorator to convert a callable into an aiotaskq Task instance.""" - func_module: t.Optional[ModuleType] = inspect.getmodule(func) +def task(*, options: TaskOptions | None = None) -> t.Callable[[t.Callable[P, RT]], Task[P, RT]]: + """ + Decorator to convert a callable into an aiotaskq Task instance. - if func_module is None: - raise ModuleInvalidForTask( - f'Function "{func.__name__}" is defined in an invalid module {func_module}' - ) + Args: + options (aiotaskq.interfaces.TaskOptions | None): Specify the options available for a task. 
+ """ + + if options is None: + options = {} + + def _wrapper(func: t.Callable[P, RT]) -> Task[P, RT]: + func_module: t.Optional[ModuleType] = inspect.getmodule(func) + + if func_module is None: + raise ModuleInvalidForTask( + f'Function "{func.__name__}" is defined in an invalid module {func_module}' + ) + + task_ = Task[P, RT](func, **options) + return task_ - module_path = ".".join( - [ - p.split(".py")[0] - for p in func_module.__file__.strip("./").split("/") # type: ignore - if p != "src" - ] - ) - task_ = Task[P, RT](func) - task_.__qualname__ = f"{module_path}.{func.__name__}" - task_.__module__ = module_path - return task_ + return _wrapper diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py index 6d60a2d..d1b975b 100755 --- a/src/aiotaskq/worker.py +++ b/src/aiotaskq/worker.py @@ -5,7 +5,6 @@ from functools import cached_property import importlib import inspect -import json import logging import multiprocessing import os @@ -14,15 +13,15 @@ import typing as t import types +import aioredis as redis + from .concurrency_manager import ConcurrencyManagerSingleton -from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL +from .constants import Config from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub from .pubsub import PubSub +from .serde import Serialization +from .task import AsyncResult, Task -if t.TYPE_CHECKING: # pragma: no cover - from .task import Task - -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -73,7 +72,7 @@ def _pid(self) -> int: @staticmethod def _get_child_worker_tasks_channel(pid: int) -> str: - return f"{TASKS_CHANNEL}:{pid}" + return f"{Config.tasks_channel()}:{pid}" class Defaults: @@ -118,7 +117,7 @@ def __init__( worker_rate_limit: int, poll_interval_s: float, ) -> None: - self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s) + self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s) self.concurrency_manager: IConcurrencyManager = ConcurrencyManagerSingleton.get( concurrency_type=concurrency_type, concurrency=concurrency, @@ -155,7 +154,7 @@ async def _main_loop(self): async with self.pubsub as pubsub: # pylint: disable=not-async-context-manager counter = -1 - await pubsub.subscribe(TASKS_CHANNEL) + await pubsub.subscribe(Config.tasks_channel()) while True: self._logger.debug("[%s] Polling for a new task until it's available", self._pid) message = await pubsub.poll() @@ -192,7 +191,7 @@ class GruntWorker(BaseWorker): """ def __init__(self, app_import_path: str, poll_interval_s: float, worker_rate_limit: int): - self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s) + self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s) self._worker_rate_limit = worker_rate_limit super().__init__(app_import_path=app_import_path) @@ -226,20 +225,13 @@ async def _main_loop(self): "[%s] Received task to from main worker [message=%s, channel=%s]", *(self._pid, message, channel), ) - task_info = json.loads(message["data"]) - task_args = task_info["args"] - task_kwargs = task_info["kwargs"] - task_id: str = task_info["task_id"] - task_func_name: str = task_id.split(":")[0].split(".")[-1] - task: "Task" = getattr(self.app, task_func_name) + task_serialized: str = message["data"] + task: "Task" = Serialization.deserialize(Task, task_serialized) # Fire and forget: execute the task and publish result task_asyncio: "asyncio.Task" = self._execute_task_and_publish( pubsub=pubsub, 
task=task, - task_args=task_args, - task_kwargs=task_kwargs, - task_id=task_id, semaphore=semaphore, ) asyncio.create_task(task_asyncio) @@ -248,31 +240,64 @@ async def _execute_task_and_publish( self, pubsub: IPubSub, task: "Task", - task_args: list, - task_kwargs: dict, - task_id: str, semaphore: t.Optional["asyncio.Semaphore"], ): self._logger.debug( "[%s] Executing task %s(*%s, **%s)", - *(self._pid, task_id, task_args, task_kwargs), + *(self._pid, task.id, task.args, task.kwargs), ) - if inspect.iscoroutinefunction(task.func): - task_result = await task(*task_args, **task_kwargs) - else: - task_result = task(*task_args, **task_kwargs) - # Publish the task return value - self._logger.debug( - "[%s] Publishing task result %s(*%s, **%s)", - *(self._pid, task_id, task_args, task_kwargs), - ) - task_result = json.dumps(task_result) - result_channel = RESULTS_CHANNEL_TEMPLATE.format(task_id=task_id) - await pubsub.publish(channel=result_channel, message=task_result) + retry = False + error = None + retries: int = None + retry_max: int | None = None + task_result: t.Any = None + try: + if inspect.iscoroutinefunction(task.func): + task_result = await task(*task.args, **task.kwargs) + else: + task_result = task(*task.args, **task.kwargs) + except Exception as e: # pylint: disable=broad-except + error = e + if task.retry is not None: + retry_max = task.retry["max_retries"] + if isinstance(e, task.retry["on"]): + retry = True + + finally: + # Retry if still within retry limit + if retry: + async with redis.from_url(url="redis://localhost:6379") as redis_client: + retries = int(await redis_client.get(f"retry:{task.id}") or 0) + if retry_max is not None and retries < retry_max: + retries += 1 + logger.debug( + "Task %s[%s] failed on exception %s, will retry (%s/%s)", + *(task.__qualname__, task.id, error, retries, retry_max), + ) + asyncio.create_task(task.publish()) + await redis_client.set(f"retry:{task.id}", retries) + if semaphore is not None: + semaphore.release() + return # pylint: disable=lost-exception + + if error: + # Publish error + logger.debug("Publishing error") + result = AsyncResult(task_id=task.id, ready=True, result=None, error=error) + else: + # Publish the task return value + self._logger.debug( + "[%s] Publishing task result %s(*%s, **%s)", + *(self._pid, task.id, task.args, task.kwargs), + ) + result = AsyncResult(task_id=task.id, ready=True, result=task_result, error=None) + task_serialized = Serialization.serialize(obj=result) + result_channel = Config.results_channel_template().format(task_id=task.id) + await pubsub.publish(channel=result_channel, message=task_serialized) - if semaphore is not None: - semaphore.release() + if semaphore is not None: + semaphore.release() def validate_input(app_import_path: str) -> t.Optional[str]: diff --git a/src/tests/apps/simple_app.py b/src/tests/apps/simple_app.py index 2325d6b..df85bc3 100644 --- a/src/tests/apps/simple_app.py +++ b/src/tests/apps/simple_app.py @@ -1,36 +1,122 @@ import asyncio +import logging +import os import aiotaskq +logger = logging.getLogger(__name__) -@aiotaskq.task + +class SomeException(Exception): + pass + + +class SomeException2(Exception): + pass + + +@aiotaskq.task( + options={ + "retry": { + "max_retries": 2, + "on": (SomeException,), + }, + }, +) +async def append_to_file(filename: str) -> None: + """ + Append pid to file for testing purpose and unconditionally raise SomeException. + + How is this useful for testing? + + This task is defined with retry options. 
By appending to a file and raising an exception at the + end, we can check how many times this task has been applied and verify if the retry logic is + working. + """ + _append_pid_to_file(filename=filename) + raise SomeException("Some error") + + +@aiotaskq.task() +async def append_to_file_2(filename: str) -> None: + """ + Append pid to file for testing purpose and unconditionally raise SomeException2. + + This works exactly the same as `append_to_file` except that: + 1. We're raising a different exception. + 2. The task is not defined with the retry options. + + This can help us verify that: + 1. The retry logic can also be provided during task call instead of only during task + definition + 2. The retry logic is applied only against the specified exception classes + """ + _append_pid_to_file(filename=filename) + raise SomeException2 + + +@aiotaskq.task( + options={ + "retry": { + "max_retries": 2, + "on": (SomeException,), + }, + }, +) +async def append_to_file_first_3_times_with_error(filename: str) -> None: + """ + Append pid to file for testing purpose and *conditionally* raise SomeException. + + This works exactly the same as `append_to_file` except that we're raising + SomeException only on a certain condition. + + This can help us verify that a task will only be retried while it fails -- once + it successfully runs without error, it will no longer be retried. + """ + _append_pid_to_file(filename=filename) + # If func has been called <= 2 times (file has <= 2 lines), raise SomeException. + # Else, return without error. + with open(filename, mode="r", encoding="utf-8") as fi: + num_lines = len(fi.read().rstrip("\n").split("\n")) + if num_lines <= 2: + raise SomeException + + +def _append_pid_to_file(filename: str) -> None: + content: str = str(os.getpid()) + with open(filename, mode="a", encoding="utf-8") as fo: + fo.write(content + "\n") + fo.flush() + + +@aiotaskq.task() def echo(x): return x -@aiotaskq.task +@aiotaskq.task() async def wait(t_s: int) -> int: """Wait asynchronously for `t_s` seconds.""" await asyncio.sleep(t_s) return t_s -@aiotaskq.task +@aiotaskq.task() def add(x: int, y: int) -> int: return x + y -@aiotaskq.task +@aiotaskq.task() def power(a: int, b: int = 1) -> int: return a**b -@aiotaskq.task +@aiotaskq.task() def join(ls: list, delimiter: str = ",") -> str: return delimiter.join([str(x) for x in ls]) -@aiotaskq.task +@aiotaskq.task() def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 1814564..be8c9fa 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -7,6 +7,7 @@ import pytest from aiotaskq.interfaces import ConcurrencyType +from aiotaskq.concurrency_manager import ConcurrencyManagerSingleton from aiotaskq.worker import Defaults, run_worker_forever @@ -21,6 +22,9 @@ async def start( worker_rate_limit: int = Defaults.worker_rate_limit(), poll_interval_s: t.Optional[float] = Defaults.poll_interval_s(), ) -> None: + # Reset singleton so each test is isolated + ConcurrencyManagerSingleton.reset() + proc = multiprocessing.Process( target=lambda: run_worker_forever( app_import_path=app, @@ -66,3 +70,11 @@ def worker(): yield worker_ worker_.terminate() worker_.close() + + +@pytest.fixture +def some_file(): + filename = "./some_file.txt" + yield filename + if os.path.exists(filename): + os.remove(filename) diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py index da32eb9..2dea7da 100644 --- a/src/tests/test_cli.py +++ b/src/tests/test_cli.py @@
-8,6 +8,7 @@ from aiotaskq import __version__ from aiotaskq.__main__ import _version_callback +from aiotaskq.worker import Defaults def test_root_show_proper_help_message(): @@ -59,6 +60,8 @@ def test_version(): def test_worker_show_proper_help_message(): bash_command = "aiotaskq worker --help" + default_cpu_count: int = multiprocessing.cpu_count() + default_poll_interval_s: float = Defaults.poll_interval_s() with os.popen(bash_command) as pipe: output = pipe.read() output_expected = ( @@ -70,8 +73,8 @@ def test_worker_show_proper_help_message(): " APP [required]\n" "\n" "Options:\n" - f" --concurrency INTEGER [default: {multiprocessing.cpu_count()}]\n" - " --poll-interval-s FLOAT [default: 0.01]\n" + f" --concurrency INTEGER [default: {default_cpu_count}]\n" + f" --poll-interval-s FLOAT [default: {default_poll_interval_s}]\n" " --concurrency-type [multiprocessing]\n" " [default: multiprocessing]\n" " --worker-rate-limit INTEGER [default: -1]\n" diff --git a/src/tests/test_concurrency.py b/src/tests/test_concurrency.py index b700a88..934bfe0 100644 --- a/src/tests/test_concurrency.py +++ b/src/tests/test_concurrency.py @@ -6,7 +6,7 @@ from tests.apps import simple_app -if t.TYPE_CHECKING: # pragma: no cover +if t.TYPE_CHECKING: from tests.conftest import WorkerFixture @@ -84,7 +84,9 @@ async def test_concurrent_async_tasks_return_correctly(worker: "WorkerFixture"): @pytest.mark.asyncio -async def test_async__concurrency_and_worker_rate_limit_of_1__effectively_serial(worker: "WorkerFixture"): +async def test_async__concurrency_and_worker_rate_limit_of_1__effectively_serial( + worker: "WorkerFixture", +): """Assert that if concurrency=1 & worker-rate-limit=1, tasks will effectively run serially.""" # Given that the worker cli is run with "--concurrency 1" and "--worker-rate-limit 1" options await worker.start(app=simple_app.__name__, concurrency=1, worker_rate_limit=1) diff --git a/src/tests/test_concurrency_manager.py b/src/tests/test_concurrency_manager.py index c265b94..dba7607 100644 --- a/src/tests/test_concurrency_manager.py +++ b/src/tests/test_concurrency_manager.py @@ -1,5 +1,6 @@ from aiotaskq.exceptions import ConcurrencyTypeNotSupported from aiotaskq.concurrency_manager import ConcurrencyManagerSingleton +from aiotaskq.interfaces import ConcurrencyType def test_unsupported_concurrency_type(): @@ -18,3 +19,18 @@ def test_unsupported_concurrency_type(): assert ( str(error) == 'Concurrency type "some-incorrect-concurrency-type" is not yet supported.' 
) + + +def test_singleton(): + # When getting the concurrency_manager instance more than once + instance_1 = ConcurrencyManagerSingleton.get( + concurrency_type=ConcurrencyType.MULTIPROCESSING, + concurrency=4, + ) + instance_2 = ConcurrencyManagerSingleton.get( + concurrency_type=ConcurrencyType.MULTIPROCESSING, + concurrency=4, + ) + + # Then both instances should be the same instance + assert instance_1 is instance_2 diff --git a/src/tests/test_integration.py b/src/tests/test_integration.py index bbc5574..e799237 100644 --- a/src/tests/test_integration.py +++ b/src/tests/test_integration.py @@ -13,7 +13,7 @@ async def test_sync_and_async_parity__simple_app(worker: WorkerFixture): await worker.start(app=app.__name__, concurrency=8) # Then there should be parity between sync and async call of the tasks tests: list[tuple[Task, tuple, dict]] = [ - (simple_app.wait, tuple() ,{"t_s": 1}), + (simple_app.wait, tuple(), {"t_s": 1}), (simple_app.echo, (42,), {}), (simple_app.add, tuple(), {"x": 41, "y": 1}), (simple_app.power, (2,), {"b": 64}), diff --git a/src/tests/test_serde.py b/src/tests/test_serde.py new file mode 100644 index 0000000..313a185 --- /dev/null +++ b/src/tests/test_serde.py @@ -0,0 +1,80 @@ +from importlib import import_module +import json + +from aiotaskq.serde import JsonTaskSerialization +from aiotaskq.task import Task, task + + +@task() +def some_task(a: int, b: int) -> int: + return a * b + + +class SomeException(Exception): + pass + + +@task( + options={ + "retry": {"max_retries": 1, "on": (SomeException,)}, + }, +) +def some_task_2(a: int, b: int) -> int: + return a * b + + +def test_serialize_task_to_json(): + # Given a task definition + assert isinstance(some_task, Task) + + # When the task has been serialized to json + task_serialized = JsonTaskSerialization.serialize(some_task) + + # Then the resulting json should be in correct type and format + assert isinstance(task_serialized, bytes) + # And should be deserializable back into the same task + task_deserialized = JsonTaskSerialization.deserialize(Task, task_serialized) + task_format_str, task_serialized_str = task_serialized.decode("utf-8").split("|", 1) + assert task_format_str == "json" + # And the task should be serialized into correct json + task_serialized_dict = json.loads(task_serialized_str) + assert task_serialized_dict == { + "func": {"module": "tests.test_serde", "qualname": "some_task"}, + "task_id": None, + "args": None, + "kwargs": None, + "options": {}, + } + # And should be functionally the same as the original task + assert task_deserialized.func(1, 2) == some_task.func(1, 2) + assert task_deserialized(3, 4) == some_task(3, 4) + + +def test_serialize_task_to_json__with_retry_param(): + # Given a task definition + assert isinstance(some_task_2, Task) + + # When the task has been serialized to json + task_serialized = JsonTaskSerialization.serialize(some_task_2) + + # Then the task should be serialized in the correct type and format + assert isinstance(task_serialized, bytes) + task_format_str, task_serialized_str = task_serialized.decode("utf-8").split("|", 1) + assert task_format_str == "json" + # And the serialized task should contain information about retry + task_serialized_dict = json.loads(task_serialized_str) + assert task_serialized_dict == { + "func": {"module": "tests.test_serde", "qualname": "some_task_2"}, + "task_id": None, + "args": None, + "kwargs": None, + "options": { + "retry": { + "max_retries": 1, + "on": '{"py/tuple": [{"py/type": "tests.test_serde.SomeException"}]}', + },
+ }, + } + # And the deserialized task should function the same as the original + task_deserialized = import_module(task_serialized_dict["func"]["module"]).some_task_2 + assert task_deserialized(2, 3) == some_task_2(2, 3) diff --git a/src/tests/test_task.py b/src/tests/test_task.py index 171a16c..c350faf 100644 --- a/src/tests/test_task.py +++ b/src/tests/test_task.py @@ -1,3 +1,4 @@ +import os from typing import TYPE_CHECKING import pytest @@ -5,7 +6,7 @@ from aiotaskq.exceptions import InvalidArgument from tests.apps import simple_app -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from aiotaskq.task import Task from tests.conftest import WorkerFixture @@ -50,6 +51,176 @@ async def test_invalid_argument_provided_to_apply_async( error = exc finally: assert str(error) == ( - f"These arguments are invalid: args={invalid_args}," - f" kwargs={invalid_kwargs}" + f"These arguments are invalid: args={invalid_args}," f" kwargs={invalid_kwargs}" ) + + +@pytest.mark.asyncio +async def test_retry_as_per_task_definition(worker: "WorkerFixture", some_file: str): + # Given a worker running in the background + await worker.start(simple_app.__name__, concurrency=1) + # And a task defined with retry configuration + assert simple_app.append_to_file.retry["max_retries"] == 2 + assert simple_app.append_to_file.retry["on"] == (simple_app.SomeException,) + # And the task will raise an exception when called as a function + exception = None + try: + await simple_app.append_to_file(some_file) + except simple_app.SomeException as e: + exception = e + finally: + if os.path.exists(some_file): + os.remove(some_file) + assert isinstance(exception, simple_app.SomeException) + + # When the task has been applied + exception = None + try: + await simple_app.append_to_file.apply_async(filename=some_file) + except simple_app.SomeException as exc: + exception = exc + finally: + # Then the task should be retried as many times as configured + with open(some_file, encoding="utf-8") as fi: + assert len(fi.readlines()) == 1 + 2, f"file: {fi.readlines()}" # First call + 2 retries + assert isinstance(exception, simple_app.SomeException) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "retry_on,retries_expected", + [ + ( + ( + simple_app.SomeException, + simple_app.SomeException2, + ), + 1, + ), + ( + ( + simple_app.SomeException, + simple_app.SomeException2, + ), + 2, + ), + ], +) +async def test_retry_as_per_task_call( + worker: "WorkerFixture", + retry_on: tuple[type[Exception], ...], + retries_expected: int, + some_file: str, +): + # Given a worker running in the background + await worker.start(simple_app.__name__, concurrency=1) + # And a task defined WITHOUT retry configuration + assert simple_app.append_to_file_2.retry is None + # And the task will raise an exception when called + exception = None + try: + await simple_app.append_to_file_2.func(some_file) + except simple_app.SomeException2 as e: + exception = e + finally: + assert isinstance(exception, simple_app.SomeException2) + if os.path.exists(some_file): + os.remove(some_file) + + # When the task is applied with retry option + exception = None + try: + await simple_app.append_to_file_2.with_retry( + max_retries=retries_expected, + on=retry_on, + ).apply_async(filename=some_file) + except simple_app.SomeException2 as e: + exception = e + finally: + # Then the task should be retried as many times as requested + with open(some_file, encoding="utf-8") as fi: + assert ( + len(fi.readlines()) == 1 + retries_expected + ) # First call + `retries_expected` 
retries + # And the task should fail with the expected exception + assert isinstance(exception, simple_app.SomeException2) + + +@pytest.mark.asyncio +async def test_no_retry_as_per_task_call(worker: "WorkerFixture", some_file: str): + # Given a worker running in the background + await worker.start(simple_app.__name__, concurrency=1) + # And a task defined WITH retry configuration + assert simple_app.append_to_file.retry is not None + # And the task will raise an exception when called + exception = None + try: + await simple_app.append_to_file.func(some_file) + except simple_app.SomeException as e: + exception = e + finally: + assert isinstance(exception, simple_app.SomeException) + if os.path.exists(some_file): + os.remove(some_file) + + # When the task is applied with retry option + # where retry["on"] doesn't include the exception that the task raises + exception = None + retry_on_new = (simple_app.SomeException2,) + assert all(exc not in simple_app.append_to_file.retry["on"] for exc in retry_on_new) + try: + await simple_app.append_to_file.with_retry( + max_retries=2, + on=retry_on_new, + ).apply_async(filename=some_file) + except simple_app.SomeException as e: + exception = e + finally: + # Then the task should NOT be retried + with open(some_file, encoding="utf-8") as fi: + assert len(fi.readlines()) == 1 + # And the task should fail with the expected exception + assert isinstance(exception, simple_app.SomeException) + + +@pytest.mark.asyncio +async def test_retry_until_successful(worker: "WorkerFixture", some_file: str): + """Assert that task will stop being retried once it's successfully executed without error.""" + # Given a worker running in the background + await worker.start(simple_app.__name__, concurrency=1) + # And a task defined WITH retry max_retries = 2 + assert simple_app.append_to_file_first_3_times_with_error.retry["max_retries"] == 2 + # And the task will raise an exception when called until it's called for the 3rd time + exception = None + try: + # Call for the first time + await simple_app.append_to_file_first_3_times_with_error.func(some_file) + except simple_app.SomeException as e: + exception = e + finally: + assert exception is not None + exception = None + try: + # Call for the second time + await simple_app.append_to_file_first_3_times_with_error.func(some_file) + except simple_app.SomeException as e: + exception = e + finally: + assert exception is not None + # Call for the third time (Should no longer raise error) + await simple_app.append_to_file_first_3_times_with_error.func(some_file) + with open(some_file, mode="r", encoding="utf-8") as fi: + num_lines = len(fi.read().rstrip("\n").split("\n")) + assert num_lines == 3 + # Delete the file + if os.path.exists(some_file): + os.remove(some_file) + + # When the task is applied + await simple_app.append_to_file_first_3_times_with_error.apply_async(filename=some_file) + + # Then the task should be applied successfully with no error + # After having been retried 2 times + with open(some_file, mode="r", encoding="utf-8") as fi: + num_lines = len(fi.read().rstrip("\n").split("\n")) + assert num_lines == 3 # (first call + 2 retries) diff --git a/src/tests/test_worker.py b/src/tests/test_worker.py index 55d9a90..22b8d6d 100644 --- a/src/tests/test_worker.py +++ b/src/tests/test_worker.py @@ -8,7 +8,7 @@ from aiotaskq.interfaces import ConcurrencyType from aiotaskq.worker import validate_input -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from tests.conftest import WorkerFixture diff --git a/test.sh b/test.sh
index 5fa3b6b..b703925 100755 --- a/test.sh +++ b/test.sh @@ -1,17 +1,16 @@ -pip install --upgrade pip -pip install -e .[dev] - coverage erase if [ -z $1 ]; then + pip install --upgrade pip + pip install -e .[dev] coverage run -m pytest -v else - coverage run -m pytest -v -k $1 + coverage run -m pytest --log-level=DEBUG --log-cli-level=DEBUG -s -vvv -k $1 fi failed=$? -coverage combine +coverage combine --quiet exit $failed
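
For reference, a minimal usage sketch of the retry API this patch introduces, pieced together from the `task()` decorator options, `Task.with_retry`, and the tests in `tests.test_task`. The task name, exception class, and broker availability below are illustrative assumptions, not part of the patch; a Redis broker at the default `redis://127.0.0.1:6379` (or `BROKER_URL`) is assumed to be running with a worker started via `aiotaskq worker <app>`.

import asyncio

import aiotaskq


class TransientError(Exception):
    """Illustrative exception class to retry on (not part of the patch)."""


# Retry declared at definition time via the new `options` parameter: the worker
# re-publishes the task up to `max_retries` times when it raises an exception
# listed in `on`; other exceptions are propagated back to the caller.
@aiotaskq.task(options={"retry": {"max_retries": 2, "on": (TransientError,)}})
async def flaky_task(x: int) -> int:
    return x * 2


async def main() -> None:
    # Retry can also be attached per call; `with_retry` returns a copy of the
    # task so the original definition is left untouched.
    result = await flaky_task.with_retry(max_retries=3, on=(TransientError,)).apply_async(x=21)
    print(result)


if __name__ == "__main__":
    asyncio.run(main())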