From 4d4c0c084175e0363964192422280fd90dafe9ad Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sat, 29 Jul 2023 21:13:23 -0400 Subject: [PATCH 1/8] (#64) Support tasks retry & propagate raised exception For documentation, see: 1. Docstring of `task.task` 2. Tests in `tests.test_task` e.g. `test_retry_as_per_task_definition` 3. Sample usages in `tests.apps.simple_app` e.g. `append_to_file` Changelist: * Formalize serialization and deserialization * Serialize & deserialize exceptions correctly * Encapsulate retry & retry_on in a new dict 'options' * Implement serde for AsyncResult * Ensure generated file deleted after test * Add jsonpickle to toml file * Exclude `if TYPE_CHECKING:` from coverage * Add test for singleton * Add logging for worker * Wrap all constants inside `Config` class Signed-off-by: Imran Ariffin --- .coveragerc | 6 + .pylintrc | 12 +- .pylintrc.tests | 6 +- .vscode/launch.json | 13 +- README.md | 10 +- pyproject.toml | 5 +- src/aiotaskq/__init__.py | 2 +- src/aiotaskq/__main__.py | 5 +- src/aiotaskq/constants.py | 64 ++++++++- src/aiotaskq/interfaces.py | 54 ++++++++ src/aiotaskq/serde.py | 160 ++++++++++++++++++++++ src/aiotaskq/task.py | 190 ++++++++++++++++++-------- src/aiotaskq/worker.py | 99 +++++++++----- src/tests/apps/simple_app.py | 98 ++++++++++++- src/tests/conftest.py | 12 ++ src/tests/test_cli.py | 7 +- src/tests/test_concurrency.py | 6 +- src/tests/test_concurrency_manager.py | 16 +++ src/tests/test_integration.py | 2 +- src/tests/test_serde.py | 80 +++++++++++ src/tests/test_task.py | 177 +++++++++++++++++++++++- src/tests/test_worker.py | 2 +- test.sh | 9 +- 23 files changed, 901 insertions(+), 134 deletions(-) create mode 100644 src/aiotaskq/serde.py create mode 100644 src/tests/test_serde.py diff --git a/.coveragerc b/.coveragerc index 2d31e33..435a053 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,3 +3,9 @@ source = src/ parallel = True concurrency = multiprocessing sigterm = True + +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: + if t.TYPE_CHECKING: diff --git a/.pylintrc b/.pylintrc index c3b6ead..70335b6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -257,16 +257,16 @@ ignored-parents= max-args=10 # Maximum number of attributes for a class (see R0902). -max-attributes=7 +max-attributes=10 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 # Maximum number of branch for function / method body. -max-branches=12 +max-branches=15 # Maximum number of locals for function / method body. -max-locals=15 +max-locals=20 # Maximum number of parents for a class (see R0901). max-parents=7 @@ -458,7 +458,11 @@ good-names=i, a, b, c, - n + n, + id, + e, + s, + on # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted diff --git a/.pylintrc.tests b/.pylintrc.tests index 568e043..6ab1282 100644 --- a/.pylintrc.tests +++ b/.pylintrc.tests @@ -461,8 +461,12 @@ good-names=i, a, b, c, + e, n, - ls + t, + ls, + fo, + fi # Good variable names regexes, separated by a comma. 
If names match any regex, # they will always be accepted diff --git a/.vscode/launch.json b/.vscode/launch.json index 0b6c3e1..7a073b2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,16 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Python: Attach", + "type": "python", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5678 + }, + "justMyCode": false + }, { "name": "Main", "type": "python", @@ -31,7 +41,8 @@ "-s", ], "request": "launch", - "console": "integratedTerminal" + "console": "integratedTerminal", + "justMyCode": false }, { "name": "Sample Worker (Simple App)", diff --git a/README.md b/README.md index 3d8cb67..b8f127a 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ import asyncio import aiotaskq -@aiotaskq.task +@aiotaskq.task() def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: @@ -132,22 +132,22 @@ import asyncio from aiotaskq import task -@task +@task() def task_1(*args, **kwargs): pass -@task +@task() def task_2(*args, **kwargs): pass -@task +@task() def task_3(*args, **kwargs): pass -@task +@task() def task_4(*args, **kwargs): pass diff --git a/pyproject.toml b/pyproject.toml index f06e373..f6fbbc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,11 +9,12 @@ build-backend = "setuptools.build_meta" requires-python = ">=3.9" dependencies = [ "aioredis >= 2.0.0, < 2.1.0", + "jsonpickle >= 3.0.0, < 3.1.0", "tomlkit >= 0.11.0, < 0.12.0", "typer >= 0.4.0, < 0.5.0", ] name = "aiotaskq" -version = "0.0.12" +version = "0.0.13" readme = "README.md" description = "A simple asynchronous task queue" authors = [ @@ -28,7 +29,7 @@ license = { file = "LICENSE" } [project.optional-dependencies] dev = [ - "black >= 22.1.0, < 22.2.0", + "black >= 22.2.0, < 23.0.0", "coverage >= 6.4.0, < 6.5.0", "mypy >= 0.931, < 1.0", "mypy-extensions >= 0.4.0, < 0.5.0", diff --git a/src/aiotaskq/__init__.py b/src/aiotaskq/__init__.py index b38e566..4d5fd35 100644 --- a/src/aiotaskq/__init__.py +++ b/src/aiotaskq/__init__.py @@ -7,7 +7,7 @@ import aiotaskq - @aiotaskq.task + @aiotaskq.task() def some_task(b: int) -> int: # Some task with high cpu usage def _naive_fib(n: int) -> int: diff --git a/src/aiotaskq/__main__.py b/src/aiotaskq/__main__.py index 5bdcf56..d21a489 100755 --- a/src/aiotaskq/__main__.py +++ b/src/aiotaskq/__main__.py @@ -2,15 +2,18 @@ #!/usr/bin/env python +import logging import typing as t import typer +from . import __version__ +from .constants import Config from .interfaces import ConcurrencyType from .worker import Defaults, run_worker_forever -from . import __version__ cli = typer.Typer() +logging.basicConfig(level=Config.log_level()) def _version_callback(value: bool): diff --git a/src/aiotaskq/constants.py b/src/aiotaskq/constants.py index 6609dbb..6aa6f9b 100644 --- a/src/aiotaskq/constants.py +++ b/src/aiotaskq/constants.py @@ -1,5 +1,61 @@ -"""Module to define and store all constants used across the library.""" +""" +Module to define and store all constants used across the library. -REDIS_URL = "redis://127.0.0.1:6379" -TASKS_CHANNEL = "channel:tasks" -RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" +The public object from this module is `Config`. 
This object wraps
+all the constants, which include:
+- Variables
+- Environment variables
+- Static methods that return constant values
+"""
+
+import logging
+from os import environ
+
+from .interfaces import SerializationType
+
+_REDIS_URL = "redis://127.0.0.1:6379"
+_TASKS_CHANNEL = "channel:tasks"
+_RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}"
+
+
+class Config:
+    """
+    Provide configuration values.
+
+    These include:
+    - Variables
+    - Environment variables
+    - Static methods that return constant values
+    """
+
+    @staticmethod
+    def serialization_type() -> SerializationType:
+        """Return the serialization type as provided via env var AIOTASKQ_SERIALIZATION."""
+        s: str | None = environ.get("AIOTASKQ_SERIALIZATION", SerializationType.DEFAULT.value)
+        return SerializationType[s.upper()]
+
+    @staticmethod
+    def log_level() -> int:
+        """Return the log level as provided via env var AIOTASKQ_LOG_LEVEL."""
+        level: int = int(environ.get("AIOTASKQ_LOG_LEVEL", logging.DEBUG))
+        return level
+
+    @staticmethod
+    def broker_url() -> str:
+        """
+        Return the broker url as provided via env var BROKER_URL.
+
+        Defaults to "redis://127.0.0.1:6379".
+        """
+        broker_url: str = environ.get("BROKER_URL", _REDIS_URL)
+        return broker_url
+
+    @staticmethod
+    def tasks_channel() -> str:
+        """Return the channel name used for transporting task requests on the broker."""
+        return _TASKS_CHANNEL
+
+    @staticmethod
+    def results_channel_template() -> str:
+        """Return the template channel name used for transporting task results on the broker."""
+        return _RESULTS_CHANNEL_TEMPLATE
diff --git a/src/aiotaskq/interfaces.py b/src/aiotaskq/interfaces.py
index 69584a0..bc3d2a8 100644
--- a/src/aiotaskq/interfaces.py
+++ b/src/aiotaskq/interfaces.py
@@ -137,3 +137,57 @@ class IWorkerManager(IWorker):
     """
 
     concurrency_manager: IConcurrencyManager
+
+
+class SerializationType(str, enum.Enum):
+    """Specify the types of serialization supported."""
+
+    JSON = "json"
+    DEFAULT = JSON
+
+
+T = t.TypeVar("T")
+
+
+class ISerialization(t.Protocol, t.Generic[T]):
+    """Define the interface required to serialize and deserialize a generic object."""
+
+    @classmethod
+    def serialize(cls, obj: T) -> bytes:
+        """Serialize any object into bytes."""
+
+    @classmethod
+    def deserialize(cls, klass: type[T], s: bytes) -> T:
+        """Deserialize bytes into any object."""
+
+
+class RetryOptions(t.TypedDict):
+    """
+    Specify the available retry options.
+
+    max_retries int | None: The number of times to keep retrying the execution of the task
+                            until the task executes successfully. Counting starts from
+                            0 so if max_retries = 2 for example, then the task will execute
+                            1 + 2 times (1 time for first execution, 2 times for re-try).
+    on tuple[type[Exception], ...]: The tuple of exception classes to retry on. The task
+                                    will only be retried if the exception raised during
+                                    task execution is an instance of one of the listed
+                                    exception classes.
+
+    Examples:
+
+        If on=(Exception,) then any kind of exception will trigger
+        a retry.
+
+        If on=(ExceptionA, ExceptionB,) and during task
+        execution ExceptionC was raised, then retry is not triggered.
+    """
+
+    max_retries: int | None
+    on: tuple[type[Exception], ...]
+
+
+class TaskOptions(t.TypedDict):
+    """Specify the options available for a task."""
+
+    retry: RetryOptions | None
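To make the shape of these options concrete, here is a rough usage sketch (hedged: the task and exception names are invented for illustration, mirroring the style of `tests.apps.simple_app`):

```
import aiotaskq


class TransientError(Exception):
    """Hypothetical exception used only for this illustration."""


@aiotaskq.task(options={"retry": {"max_retries": 2, "on": (TransientError,)}})
async def flaky_task(x: int) -> int:
    # Executed at most 1 + 2 times, and retried only when TransientError is raised.
    return x * 2
```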
+""" + +import importlib +import json +import types +import typing as t + +import jsonpickle + +from .constants import Config +from .interfaces import ISerialization, SerializationType, T +from .task import AsyncResult, Task + + +class Serialization(t.Generic[T]): + """Expose the JSON serialization and deserialization logic for any object behined a simple abstraction.""" + + @classmethod + def serialize(cls, obj: "T") -> bytes: + """Serialize an object of type T into bytes via an appropriate serialization logic.""" + s_klass = _get_serde_class(obj.__class__) + return s_klass.serialize(obj) + + @classmethod + def deserialize(cls, klass: type["T"], s: bytes) -> "T": + """Deserialize bytes into an object of type T via an appropriate deserialization logic.""" + s_klass = _get_serde_class(klass) + return s_klass.deserialize(klass, s) + + +def _get_serde_class(klass: type["T"]) -> type[ISerialization["T"]]: + """Get the Serializer-Deserializer class that implements `serialize` and `deserialize`.""" + map_ = { + (Task, SerializationType.JSON): JsonTaskSerialization, + (AsyncResult, SerializationType.JSON): JsonAsyncResultSerialization, + } + if (klass, Config.serialization_type()) not in map_: + assert False, "Should not reach here" # pragma: no cover + return map_[klass, Config.serialization_type()] + + +class JsonTaskSerialization(Serialization): + """Define the JSON serialization and deserialization logic for Task.""" + + class TaskOptionsRetryOnDict(t.TypedDict): + """Define the JSON structure of the retry options of a serialized Task object.""" + + max_retries: int | None + on: str + + class TaskOptionsDict(t.TypedDict): + """Define the JSON structure of the options of a serialized Task object.""" + + retry: "JsonTaskSerialization.TaskOptionsRetryOnDict | None" + + class TaskDict(t.TypedDict): + """Define the JSON structure of a serialized Task object.""" + + func: str + task_id: str + args: tuple[t.Any, ...] + kwargs: dict + options: "JsonTaskSerialization.TaskOptionsDict" + + @classmethod + def serialize(cls, obj: "Task") -> bytes: + """Serialize a Task object to JSON bytes.""" + options: JsonTaskSerialization.TaskOptionsDict = {} + retry: JsonTaskSerialization.TaskOptionsRetryOnDict | None = None + if obj.retry is not None: + retry = { + "max_retries": obj.retry["max_retries"], + "on": jsonpickle.encode(obj.retry["on"]), + } + options["retry"] = retry + d_obj: JsonTaskSerialization.TaskDict = { + "func": { + "module": obj.__module__, + "qualname": obj.__qualname__, + }, + "task_id": obj.id, + "args": obj.args, + "kwargs": obj.kwargs, + "options": options, + } + s = f"json|{json.dumps(d_obj)}" + return s.encode("utf-8") + + @classmethod + def deserialize(cls, klass: type["Task"], s: bytes) -> "Task": + """Deserialize JSON bytes to a Task object.""" + s_type, s_obj = s.decode("utf-8").split("|", 1) + assert s_type == "json" + d_obj: JsonTaskSerialization.TaskDict = json.loads(s_obj) + + d_options: JsonTaskSerialization.TaskOptionsDict = d_obj["options"] + retry: JsonTaskSerialization.TaskOptionsRetryOnDict | None = None + if d_options.get("retry") is not None: + s_retry: "JsonTaskSerialization.TaskOptionsRetryOnDict" = d_obj["options"]["retry"] + max_retries: int | None = s_retry["max_retries"] + retry_on: tuple[type[Exception], ...] 
+            retry = {
+                "max_retries": max_retries,
+                "on": retry_on,
+            }
+
+        func_module_path, func_qualname = d_obj["func"]["module"], d_obj["func"]["qualname"]
+        func_module: types.ModuleType = importlib.import_module(func_module_path)
+        func: types.FunctionType = getattr(func_module, func_qualname).func
+        assert func is not None
+
+        obj: "Task" = klass(
+            func=func,
+            task_id=d_obj["task_id"],
+            args=d_obj["args"],
+            kwargs=d_obj["kwargs"],
+            retry=retry,
+        )
+        return obj
+
+
+class JsonAsyncResultSerialization(Serialization):
+    """Define the JSON serialization and deserialization logic for AsyncResult."""
+
+    class AsyncResultDict(t.TypedDict):
+        """Define the JSON structure of a serialized AsyncResult."""
+
+        task_id: str
+        ready: bool
+        result: t.Any
+        error: str
+
+    @classmethod
+    def serialize(cls, obj: "AsyncResult") -> bytes:
+        """Serialize an AsyncResult object to JSON bytes."""
+        error_s: str = jsonpickle.encode(obj.error)
+        result_json: JsonAsyncResultSerialization.AsyncResultDict = {
+            "task_id": obj.task_id,
+            "ready": obj.ready,
+            "result": obj.result,
+            "error": error_s,
+        }
+        return f"json|{json.dumps(result_json)}".encode("utf-8")
+
+    @classmethod
+    def deserialize(cls, klass: type["AsyncResult"], s: bytes) -> "AsyncResult":
+        """Deserialize JSON bytes to an AsyncResult object."""
+        s_str: str = s.decode("utf-8")
+        s_type, s_obj = s_str.split("|", 1)
+        assert s_type == "json"
+        obj_d: JsonAsyncResultSerialization.AsyncResultDict = json.loads(s_obj)
+        result: "AsyncResult" = klass(
+            task_id=obj_d["task_id"],
+            ready=obj_d["ready"],
+            result=obj_d["result"],
+            error=jsonpickle.decode(obj_d["error"]),
+        )
+        return result
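Before moving on to the `Task` changes below, a short round-trip sketch of the wire format these classes produce (based on `tests.test_serde`; the `add` task comes from the sample app):

```
from aiotaskq.serde import JsonTaskSerialization
from aiotaskq.task import Task
from tests.apps.simple_app import add

# Payloads are tagged with the serialization format: b"json|{...}".
payload: bytes = JsonTaskSerialization.serialize(add)
assert payload.startswith(b"json|")

# Deserializing yields a Task that behaves like the original definition.
same_task = JsonTaskSerialization.deserialize(Task, payload)
assert same_task(2, 3) == add(2, 3)
```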
diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py
index 90e8b94..572d815 100644
--- a/src/aiotaskq/task.py
+++ b/src/aiotaskq/task.py
@@ -1,21 +1,25 @@
 """Module to define the main logic of the library."""
 
+# pylint: disable=cyclic-import
+
+import copy
 import inspect
-import json
 import logging
 from types import ModuleType
 import typing as t
 import uuid
 
-from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL
+from .constants import Config
 from .exceptions import InvalidArgument, ModuleInvalidForTask
-from .interfaces import IPubSub, PollResponse
+from .interfaces import IPubSub, PollResponse, TaskOptions
 from .pubsub import PubSub
 
+if t.TYPE_CHECKING:
+    from .interfaces import RetryOptions
+
 RT = t.TypeVar("RT")
 P = t.ParamSpec("P")
 
-logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
@@ -27,24 +31,42 @@ class AsyncResult(t.Generic[RT]):
     """
 
     pubsub: IPubSub
-    _result: RT
-    _completed: bool = False
-    _task_id: str
-
-    def __init__(self, task_id: str) -> None:
+    task_id: str
+    ready: bool = False
+    result: RT | None
+    error: Exception | None
+
+    def __init__(
+        self, task_id: str, result: RT | None, ready: bool, error: Exception | None
+    ) -> None:
         """Store task_id in AsyncResult instance."""
-        self._task_id = task_id
-        self.pubsub = PubSub.get(url=REDIS_URL, poll_interval_s=0.01)
+        self.task_id = task_id
+        self.ready = ready
+        self.result = result
+        self.error = error
+        self.pubsub = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)
+
+    @classmethod
+    async def from_publisher(cls, task_id: str) -> "AsyncResult":
+        """Return the result of the task once finished."""
+        from aiotaskq.serde import Serialization  # pylint: disable=import-outside-toplevel
+
+        pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)
+        async with pubsub_ as pubsub:  # pylint: disable=not-async-context-manager
+            await pubsub.subscribe(Config.results_channel_template().format(task_id=task_id))
+            message: PollResponse = await pubsub.poll()
+
+            logger.debug("Message: %s", message)
+
+            result_serialized: bytes = message["data"]
+            result_: "AsyncResult" = Serialization.deserialize(cls, result_serialized)
+            return result_
 
-    async def get(self) -> RT:
+    def get(self) -> RT | Exception:
         """Return the result of the task once finished."""
-        async with self.pubsub as pubsub:  # pylint: disable=not-async-context-manager
-            message: PollResponse
-            await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=self._task_id))
-            message = await self.pubsub.poll()
-            logger.debug("Message: %s", message)
-            _result: RT = json.loads(message["data"])
-            return _result
+        if self.error is not None:
+            return self.error
+        return self.result
 
 
 class Task(t.Generic[P, RT]):
@@ -61,7 +83,7 @@ def some_func(x: int, y: int) -> int:
         return x + y
 
     some_task = aiotaskq.task(some_func)
     # Or equivalently:
-    # @aiotaskq.task
+    # @aiotaskq.task()
     # def some_task(x: int, y: int) -> int:
     #     return x + y
 
@@ -73,23 +95,55 @@ def some_func(x: int, y: int) -> int:
     ```
     """
 
-    __qualname__: str
+    id: str
     func: t.Callable[P, RT]
-
-    def __init__(self, func: t.Callable[P, RT]) -> None:
+    retry: "RetryOptions | None"
+    args: t.Optional[tuple[t.Any, ...]]
+    kwargs: t.Optional[dict]
+
+    def __init__(
+        self,
+        func: t.Callable[P, RT],
+        *,
+        retry: "RetryOptions | None" = None,
+        task_id: t.Optional[str] = None,
+        args: t.Optional[tuple[t.Any, ...]] = None,
+        kwargs: t.Optional[dict] = None,
+    ) -> None:
         """
         Store the underlying function and an automatically generated task_id in the Task instance.
         """
         self.func = func
+        self.retry = retry
+        self.args = args
+        self.kwargs = kwargs
+        self.id = task_id
+
+        # Copy metadata from the function to mimic it as closely as possible
+
+        self.__module__ = self.func.__module__
+        self.__qualname__ = self.func.__qualname__
+        self.__name__ = self.func.__name__
 
     def __call__(self, *args, **kwargs) -> RT:
         """Call the task synchronously, by directly executing the underlying function."""
         return self.func(*args, **kwargs)
 
+    def with_retry(self, max_retries: int, on: tuple[type[Exception], ...]) -> "Task":
+        """
+        Return a **copy** of self with the provided retry options.
+
+        We return a copy so that we don't overwrite the original task definition.
+        """
+        task_ = copy.deepcopy(self)
+        retry = {"max_retries": max_retries, "on": on}
+        task_.retry = retry
+        return task_
+
     def generate_task_id(self) -> str:
         """Generate a unique id for an individual call to a task."""
         id_ = uuid.uuid4()
-        return f"{self.__qualname__}:{id_}"
+        return f"{self.__module__}.{self.__qualname__}:{id_}"
 
     async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT:
         """
@@ -106,26 +160,43 @@ async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT:
         """
         # Raise error if arguments provided are invalid, before anything
         self._validate_arguments(task_args=args, task_kwargs=kwargs)
+        task_ = copy.deepcopy(self)
+        task_.args = args
+        task_.kwargs = kwargs
+        if task_.id is None:
+            task_.id = task_.generate_task_id()
+        # pylint: disable=protected-access
+        await task_.publish()
+        return await task_._get_result()
+
+    async def publish(self) -> RT:
+        """
+        Publish the task.
+
+        At this point we expect that args & kwargs are already provided and the task_id is already generated.
+        """
+        from aiotaskq.serde import Serialization  # pylint: disable=import-outside-toplevel
+
+        assert hasattr(self, "args") and hasattr(self, "kwargs") and self.id is not None
+
+        message: bytes = Serialization.serialize(self)
 
-        task_id: str = self.generate_task_id()
-        message: str = json.dumps(
-            {
-                "task_id": task_id,
-                "args": args,
-                "kwargs": kwargs,
-            }
-        )
         pubsub_ = PubSub.get(
-            url=REDIS_URL, poll_interval_s=0.01, max_connections=10, decode_responses=True
+            url=Config.broker_url(),
+            poll_interval_s=0.01,
+            max_connections=10,
+            decode_responses=True,
         )
         async with pubsub_ as pubsub:  # pylint: disable=not-async-context-manager
-            logger.debug("Publishing task [task_id=%s, message=%s]", task_id, message)
-            await pubsub.publish(TASKS_CHANNEL, message=message)
-
-            logger.debug("Retrieving result for task [task_id=%s]", task_id)
-            async_result: AsyncResult[RT] = AsyncResult(task_id=task_id)
-            result: RT = await async_result.get()
-
+            logger.debug("Publishing task [task_id=%s, message=%s]", self.id, message)
+            await pubsub.publish(Config.tasks_channel(), message=message)
+
+    async def _get_result(self) -> RT:
+        logger.debug("Retrieving result for task [task_id=%s]", self.id)
+        async_result: AsyncResult[RT] = await AsyncResult.from_publisher(task_id=self.id)
+        result: RT | Exception = async_result.get()
+        if isinstance(result, Exception):
+            raise result
         return result
 
     def _validate_arguments(self, task_args: tuple, task_kwargs: dict):
@@ -138,23 +209,26 @@ def _validate_arguments(self, task_args: tuple, task_kwargs: dict):
         ) from exc
 
 
-def task(func: t.Callable[P, RT]) -> Task[P, RT]:
-    """Decorator to convert a callable into an aiotaskq Task instance."""
-    func_module: t.Optional[ModuleType] = inspect.getmodule(func)
+def task(*, options: TaskOptions | None = None) -> t.Callable[[t.Callable[P, RT]], Task[P, RT]]:
+    """
+    Decorator to convert a callable into an aiotaskq Task instance.
 
-    if func_module is None:
-        raise ModuleInvalidForTask(
-            f'Function "{func.__name__}" is defined in an invalid module {func_module}'
-        )
+    Args:
+        options (aiotaskq.interfaces.TaskOptions | None): Specify the options available for a task.
+    """
+
+    if options is None:
+        options = {}
+
+    def _wrapper(func: t.Callable[P, RT]) -> Task[P, RT]:
+        func_module: t.Optional[ModuleType] = inspect.getmodule(func)
+
+        if func_module is None:
+            raise ModuleInvalidForTask(
+                f'Function "{func.__name__}" is defined in an invalid module {func_module}'
+            )
+
+        task_ = Task[P, RT](func, **options)
+        return task_
 
-    module_path = ".".join(
-        [
-            p.split(".py")[0]
-            for p in func_module.__file__.strip("./").split("/")  # type: ignore
-            if p != "src"
-        ]
-    )
-    task_ = Task[P, RT](func)
-    task_.__qualname__ = f"{module_path}.{func.__name__}"
-    task_.__module__ = module_path
-    return task_
+    return _wrapper
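Retry options can also be attached per call via `Task.with_retry`, which the tests exercise like this (a sketch assuming an async context and a running worker; the file name is arbitrary):

```
from tests.apps import simple_app

# Retry up to 2 times, but only when SomeException2 is raised.
# (Run inside an asyncio event loop.)
await simple_app.append_to_file_2.with_retry(
    max_retries=2,
    on=(simple_app.SomeException2,),
).apply_async(filename="./some_file.txt")
```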
diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py
index 6d60a2d..d1b975b 100755
--- a/src/aiotaskq/worker.py
+++ b/src/aiotaskq/worker.py
@@ -5,7 +5,6 @@
 from functools import cached_property
 import importlib
 import inspect
-import json
 import logging
 import multiprocessing
 import os
@@ -14,15 +13,15 @@
 import typing as t
 import types
 
+import aioredis as redis
+
 from .concurrency_manager import ConcurrencyManagerSingleton
-from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL
+from .constants import Config
 from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub
 from .pubsub import PubSub
+from .serde import Serialization
+from .task import AsyncResult, Task
 
-if t.TYPE_CHECKING:  # pragma: no cover
-    from .task import Task
-
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -73,7 +72,7 @@ def _pid(self) -> int:
 
     @staticmethod
     def _get_child_worker_tasks_channel(pid: int) -> str:
-        return f"{TASKS_CHANNEL}:{pid}"
+        return f"{Config.tasks_channel()}:{pid}"
 
 
 class Defaults:
@@ -118,7 +117,7 @@ def __init__(
         worker_rate_limit: int,
         poll_interval_s: float,
     ) -> None:
-        self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s)
+        self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s)
         self.concurrency_manager: IConcurrencyManager = ConcurrencyManagerSingleton.get(
             concurrency_type=concurrency_type,
             concurrency=concurrency,
@@ -155,7 +154,7 @@ async def _main_loop(self):
         async with self.pubsub as pubsub:  # pylint: disable=not-async-context-manager
             counter = -1
 
-            await pubsub.subscribe(TASKS_CHANNEL)
+            await pubsub.subscribe(Config.tasks_channel())
             while True:
                 self._logger.debug("[%s] Polling for a new task until it's available", self._pid)
                 message = await pubsub.poll()
@@ -192,7 +191,7 @@ class GruntWorker(BaseWorker):
     """
 
     def __init__(self, app_import_path: str, poll_interval_s: float, worker_rate_limit: int):
-        self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s)
+        self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s)
         self._worker_rate_limit = worker_rate_limit
         super().__init__(app_import_path=app_import_path)
@@ -226,20 +225,13 @@ async def _main_loop(self):
                     "[%s] Received task from main worker [message=%s, channel=%s]",
                     *(self._pid, message, channel),
                 )
-                task_info = json.loads(message["data"])
-                task_args = task_info["args"]
-                task_kwargs = task_info["kwargs"]
-                task_id: str = task_info["task_id"]
-                task_func_name: str = task_id.split(":")[0].split(".")[-1]
-                task: "Task" = getattr(self.app, task_func_name)
+                task_serialized: bytes = message["data"]
+                task: "Task" = Serialization.deserialize(Task, task_serialized)
 
                 # Fire and forget: execute the task and publish result
                 task_asyncio: "asyncio.Task" = self._execute_task_and_publish(
                     pubsub=pubsub,
                    task=task,
-                    task_args=task_args,
-                    task_kwargs=task_kwargs,
-                    task_id=task_id,
                     semaphore=semaphore,
                 )
                 asyncio.create_task(task_asyncio)
 
@@ -248,31 +240,64 @@ async def _execute_task_and_publish(
         self,
         pubsub: IPubSub,
         task: "Task",
-        task_args: list,
-        task_kwargs: dict,
-        task_id: str,
         semaphore: t.Optional["asyncio.Semaphore"],
     ):
         self._logger.debug(
             "[%s] Executing task %s(*%s, **%s)",
-            *(self._pid, task_id, task_args, task_kwargs),
+            *(self._pid, task.id, task.args, task.kwargs),
         )
-        if inspect.iscoroutinefunction(task.func):
-            task_result = await task(*task_args, **task_kwargs)
-        else:
-            task_result = task(*task_args, **task_kwargs)
 
-        # Publish the task return value
-        self._logger.debug(
-            "[%s] Publishing task result %s(*%s, **%s)",
-            *(self._pid, task_id, task_args, task_kwargs),
-        )
-        task_result = json.dumps(task_result)
-        result_channel = RESULTS_CHANNEL_TEMPLATE.format(task_id=task_id)
-        await pubsub.publish(channel=result_channel, message=task_result)
+        retry = False
+        error = None
+        retries: int | None = None
+        retry_max: int | None = None
+        task_result: t.Any = None
+        try:
+            if inspect.iscoroutinefunction(task.func):
+                task_result = await task(*task.args, **task.kwargs)
+            else:
+                task_result = task(*task.args, **task.kwargs)
+        except Exception as e:  # pylint: disable=broad-except
+            error = e
+            if task.retry is not None:
+                retry_max = task.retry["max_retries"]
+                if isinstance(e, task.retry["on"]):
+                    retry = True
+
+        finally:
+            # Retry if still within retry limit
+            if retry:
+                async with redis.from_url(url="redis://localhost:6379") as redis_client:
+                    retries = int(await redis_client.get(f"retry:{task.id}") or 0)
+                    if retry_max is not None and retries < retry_max:
+                        retries += 1
+                        logger.debug(
+                            "Task %s[%s] failed on exception %s, will retry (%s/%s)",
+                            *(task.__qualname__, task.id, error, retries, retry_max),
+                        )
+                        asyncio.create_task(task.publish())
+                        await redis_client.set(f"retry:{task.id}", retries)
+                        if semaphore is not None:
+                            semaphore.release()
+                        return  # pylint: disable=lost-exception
+
+            if error:
+                # Publish error
+                logger.debug("Publishing error")
+                result = AsyncResult(task_id=task.id, ready=True, result=None, error=error)
+            else:
+                # Publish the task return value
+                self._logger.debug(
+                    "[%s] Publishing task result %s(*%s, **%s)",
+                    *(self._pid, task.id, task.args, task.kwargs),
+                )
+                result = AsyncResult(task_id=task.id, ready=True, result=task_result, error=None)
+            task_serialized = Serialization.serialize(obj=result)
+            result_channel = Config.results_channel_template().format(task_id=task.id)
+            await pubsub.publish(channel=result_channel, message=task_serialized)
 
-        if semaphore is not None:
-            semaphore.release()
+            if semaphore is not None:
+                semaphore.release()
 
 
 def validate_input(app_import_path: str) -> t.Optional[str]:
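The per-task retry count lives on the broker under the key `retry:{task_id}`. A minimal sketch of inspecting that counter (illustration only; the helper is hypothetical and assumes the default broker URL is reachable):

```
import aioredis as redis


async def peek_retry_count(task_id: str) -> int:
    # The grunt worker increments "retry:{task_id}" each time it republishes a failed task.
    async with redis.from_url(url="redis://127.0.0.1:6379") as client:
        raw = await client.get(f"retry:{task_id}")
        return int(raw or 0)
```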
diff --git a/src/tests/apps/simple_app.py b/src/tests/apps/simple_app.py
index 2325d6b..df85bc3 100644
--- a/src/tests/apps/simple_app.py
+++ b/src/tests/apps/simple_app.py
@@ -1,36 +1,122 @@
 import asyncio
+import logging
+import os
 
 import aiotaskq
 
+logger = logging.getLogger(__name__)
 
-@aiotaskq.task
+
+class SomeException(Exception):
+    pass
+
+
+class SomeException2(Exception):
+    pass
+
+
+@aiotaskq.task(
+    options={
+        "retry": {
+            "max_retries": 2,
+            "on": (SomeException,),
+        },
+    },
+)
+async def append_to_file(filename: str) -> None:
+    """
+    Append pid to file for testing purposes and unconditionally raise SomeException.
+
+    How is this useful for testing?
+
+    This task is defined with retry options. By appending to a file and raising an exception at
+    the end, we can check how many times this task has been applied and verify that the retry
+    logic is working.
+    """
+    _append_pid_to_file(filename=filename)
+    raise SomeException("Some error")
+
+
+@aiotaskq.task()
+async def append_to_file_2(filename: str) -> None:
+    """
+    Append pid to file for testing purposes and unconditionally raise SomeException2.
+
+    This works exactly the same as `append_to_file` except that:
+    1. We're raising a different exception.
+    2. The task is not defined with retry options
+
+    This can help us verify that:
+    1. The retry logic can also be provided during task call instead of only during task
+       definition
+    2. The retry logic is applied only against the specified exception classes
+    """
+    _append_pid_to_file(filename=filename)
+    raise SomeException2
+
+
+@aiotaskq.task(
+    options={
+        "retry": {
+            "max_retries": 2,
+            "on": (SomeException,),
+        },
+    },
+)
+async def append_to_file_first_3_times_with_error(filename: str) -> None:
+    """
+    Append pid to file for testing purposes and *conditionally* raise SomeException.
+
+    This works exactly the same as `append_to_file` except that we're raising
+    SomeException only on a certain condition.
+
+    This can help us verify that a task will only be retried while it fails -- once
+    it successfully runs without error, it will no longer be retried.
+    """
+    _append_pid_to_file(filename=filename)
+    # If func has been called <= 2 times (file has <= 2 lines), raise SomeException.
+    # Else, return without error.
+    with open(filename, mode="r", encoding="utf-8") as fi:
+        num_lines = len(fi.read().rstrip("\n").split("\n"))
+        if num_lines <= 2:
+            raise SomeException
+
+
+def _append_pid_to_file(filename: str) -> None:
+    content: str = str(os.getpid())
+    with open(filename, mode="a", encoding="utf-8") as fo:
+        fo.write(content + "\n")
+        fo.flush()
+
+
+@aiotaskq.task()
 def echo(x):
     return x
 
 
-@aiotaskq.task
+@aiotaskq.task()
 async def wait(t_s: int) -> int:
     """Wait asynchronously for `t_s` seconds."""
     await asyncio.sleep(t_s)
     return t_s
 
 
-@aiotaskq.task
+@aiotaskq.task()
 def add(x: int, y: int) -> int:
     return x + y
 
 
-@aiotaskq.task
+@aiotaskq.task()
 def power(a: int, b: int = 1) -> int:
     return a**b
 
 
-@aiotaskq.task
+@aiotaskq.task()
 def join(ls: list, delimiter: str = ",") -> str:
     return delimiter.join([str(x) for x in ls])
 
 
-@aiotaskq.task
+@aiotaskq.task()
 def some_task(b: int) -> int:
     # Some task with high cpu usage
     def _naive_fib(n: int) -> int:
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index 1814564..be8c9fa 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -7,6 +7,7 @@
 import pytest
 
 from aiotaskq.interfaces import ConcurrencyType
+from aiotaskq.concurrency_manager import ConcurrencyManagerSingleton
 from aiotaskq.worker import Defaults, run_worker_forever
 
 
@@ -21,6 +22,9 @@ async def start(
         worker_rate_limit: int = Defaults.worker_rate_limit(),
         poll_interval_s: t.Optional[float] = Defaults.poll_interval_s(),
     ) -> None:
+        # Reset singleton so each test is isolated
+        ConcurrencyManagerSingleton.reset()
+
         proc = multiprocessing.Process(
             target=lambda: run_worker_forever(
                 app_import_path=app,
@@ -66,3 +70,11 @@ def worker():
     yield worker_
     worker_.terminate()
     worker_.close()
+
+
+@pytest.fixture
+def some_file():
+    filename = "./some_file.txt"
+    yield filename
+    if os.path.exists(filename):
+        os.remove(filename)
diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py
index da32eb9..2dea7da 100644
--- a/src/tests/test_cli.py
+++ b/src/tests/test_cli.py
@@
-8,6 +8,7 @@ from aiotaskq import __version__ from aiotaskq.__main__ import _version_callback +from aiotaskq.worker import Defaults def test_root_show_proper_help_message(): @@ -59,6 +60,8 @@ def test_version(): def test_worker_show_proper_help_message(): bash_command = "aiotaskq worker --help" + default_cpu_count: int = multiprocessing.cpu_count() + default_poll_interval_s: float = Defaults.poll_interval_s() with os.popen(bash_command) as pipe: output = pipe.read() output_expected = ( @@ -70,8 +73,8 @@ def test_worker_show_proper_help_message(): " APP [required]\n" "\n" "Options:\n" - f" --concurrency INTEGER [default: {multiprocessing.cpu_count()}]\n" - " --poll-interval-s FLOAT [default: 0.01]\n" + f" --concurrency INTEGER [default: {default_cpu_count}]\n" + f" --poll-interval-s FLOAT [default: {default_poll_interval_s}]\n" " --concurrency-type [multiprocessing]\n" " [default: multiprocessing]\n" " --worker-rate-limit INTEGER [default: -1]\n" diff --git a/src/tests/test_concurrency.py b/src/tests/test_concurrency.py index b700a88..934bfe0 100644 --- a/src/tests/test_concurrency.py +++ b/src/tests/test_concurrency.py @@ -6,7 +6,7 @@ from tests.apps import simple_app -if t.TYPE_CHECKING: # pragma: no cover +if t.TYPE_CHECKING: from tests.conftest import WorkerFixture @@ -84,7 +84,9 @@ async def test_concurrent_async_tasks_return_correctly(worker: "WorkerFixture"): @pytest.mark.asyncio -async def test_async__concurrency_and_worker_rate_limit_of_1__effectively_serial(worker: "WorkerFixture"): +async def test_async__concurrency_and_worker_rate_limit_of_1__effectively_serial( + worker: "WorkerFixture", +): """Assert that if concurrency=1 & worker-rate-limit=1, tasks will effectively run serially.""" # Given that the worker cli is run with "--concurrency 1" and "--worker-rate-limit 1" options await worker.start(app=simple_app.__name__, concurrency=1, worker_rate_limit=1) diff --git a/src/tests/test_concurrency_manager.py b/src/tests/test_concurrency_manager.py index c265b94..dba7607 100644 --- a/src/tests/test_concurrency_manager.py +++ b/src/tests/test_concurrency_manager.py @@ -1,5 +1,6 @@ from aiotaskq.exceptions import ConcurrencyTypeNotSupported from aiotaskq.concurrency_manager import ConcurrencyManagerSingleton +from aiotaskq.interfaces import ConcurrencyType def test_unsupported_concurrency_type(): @@ -18,3 +19,18 @@ def test_unsupported_concurrency_type(): assert ( str(error) == 'Concurrency type "some-incorrect-concurrency-type" is not yet supported.' 
     )
+
+
+def test_singleton():
+    # When getting the concurrency_manager instance more than once
+    instance_1 = ConcurrencyManagerSingleton.get(
+        concurrency_type=ConcurrencyType.MULTIPROCESSING,
+        concurrency=4,
+    )
+    instance_2 = ConcurrencyManagerSingleton.get(
+        concurrency_type=ConcurrencyType.MULTIPROCESSING,
+        concurrency=4,
+    )
+
+    # Then both instances should be the same instance
+    assert instance_1 is instance_2
diff --git a/src/tests/test_integration.py b/src/tests/test_integration.py
index bbc5574..e799237 100644
--- a/src/tests/test_integration.py
+++ b/src/tests/test_integration.py
@@ -13,7 +13,7 @@ async def test_sync_and_async_parity__simple_app(worker: WorkerFixture):
     await worker.start(app=app.__name__, concurrency=8)
     # Then there should be parity between sync and async call of the tasks
     tests: list[tuple[Task, tuple, dict]] = [
-        (simple_app.wait, tuple() ,{"t_s": 1}),
+        (simple_app.wait, tuple(), {"t_s": 1}),
         (simple_app.echo, (42,), {}),
         (simple_app.add, tuple(), {"x": 41, "y": 1}),
         (simple_app.power, (2,), {"b": 64}),
diff --git a/src/tests/test_serde.py b/src/tests/test_serde.py
new file mode 100644
index 0000000..313a185
--- /dev/null
+++ b/src/tests/test_serde.py
@@ -0,0 +1,80 @@
+from importlib import import_module
+import json
+
+from aiotaskq.serde import JsonTaskSerialization
+from aiotaskq.task import Task, task
+
+
+@task()
+def some_task(a: int, b: int) -> int:
+    return a * b
+
+
+class SomeException(Exception):
+    pass
+
+
+@task(
+    options={
+        "retry": {"max_retries": 1, "on": (SomeException,)},
+    },
+)
+def some_task_2(a: int, b: int) -> int:
+    return a * b
+
+
+def test_serialize_task_to_json():
+    # Given a task definition
+    assert isinstance(some_task, Task)
+
+    # When the task has been serialized to json
+    task_serialized = JsonTaskSerialization.serialize(some_task)
+
+    # Then the resulting json should be in the correct type and format
+    assert isinstance(task_serialized, bytes)
+    # And it should be possible to deserialize it back into the same task
+    task_deserialized = JsonTaskSerialization.deserialize(Task, task_serialized)
+    task_format_str, task_serialized_str = task_serialized.decode("utf-8").split("|", 1)
+    assert task_format_str == "json"
+    # And the task should be serialized into the correct json
+    task_serialized_dict = json.loads(task_serialized_str)
+    assert task_serialized_dict == {
+        "func": {"module": "tests.test_serde", "qualname": "some_task"},
+        "task_id": None,
+        "args": None,
+        "kwargs": None,
+        "options": {},
+    }
+    # And should be functionally the same as the original task
+    assert task_deserialized.func(1, 2) == some_task.func(1, 2)
+    assert task_deserialized(3, 4) == some_task(3, 4)
+
+
+def test_serialize_task_to_json__with_retry_param():
+    # Given a task definition
+    assert isinstance(some_task_2, Task)
+
+    # When the task has been serialized to json
+    task_serialized = JsonTaskSerialization.serialize(some_task_2)
+
+    # Then the task should be serialized in the correct type and format
+    assert isinstance(task_serialized, bytes)
+    task_format_str, task_serialized_str = task_serialized.decode("utf-8").split("|", 1)
+    assert task_format_str == "json"
+    # And the serialized task should contain information about retry
+    task_serialized_dict = json.loads(task_serialized_str)
+    assert task_serialized_dict == {
+        "func": {"module": "tests.test_serde", "qualname": "some_task_2"},
+        "task_id": None,
+        "args": None,
+        "kwargs": None,
+        "options": {
+            "retry": {
+                "max_retries": 1,
+                "on": '{"py/tuple": [{"py/type": "tests.test_serde.SomeException"}]}',
+            },
+        },
+    }
+    # And the deserialized task should function the same as the original
+    task_deserialized = import_module(task_serialized_dict["func"]["module"]).some_task_2
+    assert task_deserialized(2, 3) == some_task_2(2, 3)
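The `"on"` value asserted above is itself a nested jsonpickle payload. In isolation, the encoding looks roughly like this (hedged: the module path inside `py/type` depends on where the exception class is defined):

```
import jsonpickle


class SomeException(Exception):
    pass


encoded = jsonpickle.encode((SomeException,))
# e.g. '{"py/tuple": [{"py/type": "__main__.SomeException"}]}'
assert jsonpickle.decode(encoded) == (SomeException,)
```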
+ }, + } + # And the deserialized task should function the same as the original + task_deserialized = import_module(task_serialized_dict["func"]["module"]).some_task_2 + assert task_deserialized(2, 3) == some_task_2(2, 3) diff --git a/src/tests/test_task.py b/src/tests/test_task.py index 171a16c..c350faf 100644 --- a/src/tests/test_task.py +++ b/src/tests/test_task.py @@ -1,3 +1,4 @@ +import os from typing import TYPE_CHECKING import pytest @@ -5,7 +6,7 @@ from aiotaskq.exceptions import InvalidArgument from tests.apps import simple_app -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: from aiotaskq.task import Task from tests.conftest import WorkerFixture @@ -50,6 +51,176 @@ async def test_invalid_argument_provided_to_apply_async( error = exc finally: assert str(error) == ( - f"These arguments are invalid: args={invalid_args}," - f" kwargs={invalid_kwargs}" + f"These arguments are invalid: args={invalid_args}," f" kwargs={invalid_kwargs}" ) + + +@pytest.mark.asyncio +async def test_retry_as_per_task_definition(worker: "WorkerFixture", some_file: str): + # Given a worker running in the background + await worker.start(simple_app.__name__, concurrency=1) + # And a task defined with retry configuration + assert simple_app.append_to_file.retry["max_retries"] == 2 + assert simple_app.append_to_file.retry["on"] == (simple_app.SomeException,) + # And the task will raise an exception when called as a function + exception = None + try: + await simple_app.append_to_file(some_file) + except simple_app.SomeException as e: + exception = e + finally: + if os.path.exists(some_file): + os.remove(some_file) + assert isinstance(exception, simple_app.SomeException) + + # When the task has been applied + exception = None + try: + await simple_app.append_to_file.apply_async(filename=some_file) + except simple_app.SomeException as exc: + exception = exc + finally: + # Then the task should be retried as many times as configured + with open(some_file, encoding="utf-8") as fi: + assert len(fi.readlines()) == 1 + 2, f"file: {fi.readlines()}" # First call + 2 retries + assert isinstance(exception, simple_app.SomeException) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "retry_on,retries_expected", + [ + ( + ( + simple_app.SomeException, + simple_app.SomeException2, + ), + 1, + ), + ( + ( + simple_app.SomeException, + simple_app.SomeException2, + ), + 2, + ), + ], +) +async def test_retry_as_per_task_call( + worker: "WorkerFixture", + retry_on: tuple[type[Exception], ...], + retries_expected: int, + some_file: str, +): + # Given a worker running in the background + await worker.start(simple_app.__name__, concurrency=1) + # And a task defined WITHOUT retry configuration + assert simple_app.append_to_file_2.retry is None + # And the task will raise an exception when called + exception = None + try: + await simple_app.append_to_file_2.func(some_file) + except simple_app.SomeException2 as e: + exception = e + finally: + assert isinstance(exception, simple_app.SomeException2) + if os.path.exists(some_file): + os.remove(some_file) + + # When the task is applied with retry option + exception = None + try: + await simple_app.append_to_file_2.with_retry( + max_retries=retries_expected, + on=retry_on, + ).apply_async(filename=some_file) + except simple_app.SomeException2 as e: + exception = e + finally: + # Then the task should be retried as many times as requested + with open(some_file, encoding="utf-8") as fi: + assert ( + len(fi.readlines()) == 1 + retries_expected + ) # First call + `retries_expected` 
+        # And the task should fail with the expected exception
+        assert isinstance(exception, simple_app.SomeException2)
+
+
+@pytest.mark.asyncio
+async def test_no_retry_as_per_task_call(worker: "WorkerFixture", some_file: str):
+    # Given a worker running in the background
+    await worker.start(simple_app.__name__, concurrency=1)
+    # And a task defined WITH retry configuration
+    assert simple_app.append_to_file.retry is not None
+    # And the task will raise an exception when called
+    exception = None
+    try:
+        await simple_app.append_to_file.func(some_file)
+    except simple_app.SomeException as e:
+        exception = e
+    finally:
+        assert isinstance(exception, simple_app.SomeException)
+        if os.path.exists(some_file):
+            os.remove(some_file)
+
+    # When the task is applied with retry option
+    # where retry["on"] doesn't include the exception that the task raises
+    exception = None
+    retry_on_new = (simple_app.SomeException2,)
+    assert all(exc not in simple_app.append_to_file.retry["on"] for exc in retry_on_new)
+    try:
+        await simple_app.append_to_file.with_retry(
+            max_retries=2,
+            on=retry_on_new,
+        ).apply_async(filename=some_file)
+    except simple_app.SomeException as e:
+        exception = e
+    finally:
+        # Then the task should NOT be retried
+        with open(some_file, encoding="utf-8") as fi:
+            assert len(fi.readlines()) == 1
+        # And the task should fail with the expected exception
+        assert isinstance(exception, simple_app.SomeException)
+
+
+@pytest.mark.asyncio
+async def test_retry_until_successful(worker: "WorkerFixture", some_file: str):
+    """Assert that task will stop being retried once it's successfully executed without error."""
+    # Given a worker running in the background
+    await worker.start(simple_app.__name__, concurrency=1)
+    # And a task defined WITH retry max_retries = 2
+    assert simple_app.append_to_file_first_3_times_with_error.retry["max_retries"] == 2
+    # And the task will raise an exception when called until it's called for the 3rd time
+    exception = None
+    try:
+        # Call for the first time
+        await simple_app.append_to_file_first_3_times_with_error.func(some_file)
+    except simple_app.SomeException as e:
+        exception = e
+    finally:
+        assert exception is not None
+    exception = None
+    try:
+        # Call for the second time
+        await simple_app.append_to_file_first_3_times_with_error.func(some_file)
+    except simple_app.SomeException as e:
+        exception = e
+    finally:
+        assert exception is not None
+    # Call for the third time (Should no longer raise error)
+    await simple_app.append_to_file_first_3_times_with_error.func(some_file)
+    with open(some_file, mode="r", encoding="utf-8") as fi:
+        num_lines = len(fi.read().rstrip("\n").split("\n"))
+        assert num_lines == 3
+    # Delete the file
+    if os.path.exists(some_file):
+        os.remove(some_file)
+
+    # When the task is applied
+    await simple_app.append_to_file_first_3_times_with_error.apply_async(filename=some_file)
+
+    # Then the task should be applied successfully with no error
+    # after having been retried 2 times
+    with open(some_file, mode="r", encoding="utf-8") as fi:
+        num_lines = len(fi.read().rstrip("\n").split("\n"))
+        assert num_lines == 3  # (first call + 2 retries)
diff --git a/src/tests/test_worker.py b/src/tests/test_worker.py
index 55d9a90..22b8d6d 100644
--- a/src/tests/test_worker.py
+++ b/src/tests/test_worker.py
@@ -8,7 +8,7 @@
 from aiotaskq.interfaces import ConcurrencyType
 from aiotaskq.worker import validate_input
 
-if TYPE_CHECKING:  # pragma: no cover
+if TYPE_CHECKING:
     from tests.conftest import WorkerFixture
 
 
diff --git a/test.sh b/test.sh
index 5fa3b6b..b703925 100755 --- a/test.sh +++ b/test.sh @@ -1,17 +1,16 @@ -pip install --upgrade pip -pip install -e .[dev] - coverage erase if [ -z $1 ]; then + pip install --upgrade pip + pip install -e .[dev] coverage run -m pytest -v else - coverage run -m pytest -v -k $1 + coverage run -m pytest --log-level=DEBUG --log-cli-level=DEBUG -s -vvv -k $1 fi failed=$? -coverage combine +coverage combine --quiet exit $failed From af063386698de3a3138ac4eb04c4b4e80dc63d28 Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sun, 10 Mar 2024 12:05:08 -0400 Subject: [PATCH 2/8] Rename constants.py to config.py --- src/aiotaskq/__main__.py | 2 +- src/aiotaskq/{constants.py => config.py} | 0 src/aiotaskq/serde.py | 2 +- src/aiotaskq/task.py | 2 +- src/aiotaskq/worker.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename src/aiotaskq/{constants.py => config.py} (100%) diff --git a/src/aiotaskq/__main__.py b/src/aiotaskq/__main__.py index d21a489..b672b8d 100755 --- a/src/aiotaskq/__main__.py +++ b/src/aiotaskq/__main__.py @@ -8,7 +8,7 @@ import typer from . import __version__ -from .constants import Config +from .config import Config from .interfaces import ConcurrencyType from .worker import Defaults, run_worker_forever diff --git a/src/aiotaskq/constants.py b/src/aiotaskq/config.py similarity index 100% rename from src/aiotaskq/constants.py rename to src/aiotaskq/config.py diff --git a/src/aiotaskq/serde.py b/src/aiotaskq/serde.py index 5782325..ecd4269 100644 --- a/src/aiotaskq/serde.py +++ b/src/aiotaskq/serde.py @@ -9,7 +9,7 @@ import jsonpickle -from .constants import Config +from .config import Config from .interfaces import ISerialization, SerializationType, T from .task import AsyncResult, Task diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py index 572d815..9e6dee1 100644 --- a/src/aiotaskq/task.py +++ b/src/aiotaskq/task.py @@ -9,7 +9,7 @@ import typing as t import uuid -from .constants import Config +from .config import Config from .exceptions import InvalidArgument, ModuleInvalidForTask from .interfaces import IPubSub, PollResponse, TaskOptions from .pubsub import PubSub diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py index d1b975b..4ccc1d8 100755 --- a/src/aiotaskq/worker.py +++ b/src/aiotaskq/worker.py @@ -16,7 +16,7 @@ import aioredis as redis from .concurrency_manager import ConcurrencyManagerSingleton -from .constants import Config +from .config import Config from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub from .pubsub import PubSub from .serde import Serialization From ce851d7b9240970fb844d4d4b058dc6eb917441f Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sun, 10 Mar 2024 12:36:06 -0400 Subject: [PATCH 3/8] Split file into config.py & constants.py --- src/aiotaskq/config.py | 14 +------------- src/aiotaskq/constants.py | 14 ++++++++++++++ src/aiotaskq/task.py | 5 +++-- src/aiotaskq/worker.py | 7 ++++--- 4 files changed, 22 insertions(+), 18 deletions(-) create mode 100644 src/aiotaskq/constants.py diff --git a/src/aiotaskq/config.py b/src/aiotaskq/config.py index 6aa6f9b..48bd399 100644 --- a/src/aiotaskq/config.py +++ b/src/aiotaskq/config.py @@ -14,8 +14,6 @@ from .interfaces import SerializationType _REDIS_URL = "redis://127.0.0.1:6379" -_TASKS_CHANNEL = "channel:tasks" -_RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" class Config: @@ -31,7 +29,7 @@ class Config: @staticmethod def serialization_type() -> SerializationType: """Return the serialization type as provided via env var AIOTASKQ_SERIALIZATION.""" - s: 
str | None = environ.get("AIOTASKQ_SERIALIZATION", SerializationType.DEFAULT.value)
+        s: str = environ.get("AIOTASKQ_SERIALIZATION", SerializationType.DEFAULT.value)
         return SerializationType[s.upper()]
 
     @staticmethod
@@ -49,13 +47,3 @@ def broker_url() -> str:
         """
         broker_url: str = environ.get("BROKER_URL", _REDIS_URL)
         return broker_url
-
-    @staticmethod
-    def tasks_channel() -> str:
-        """Return the channel name used for transporting task requests on the broker."""
-        return _TASKS_CHANNEL
-
-    @staticmethod
-    def results_channel_template() -> str:
-        """Return the template channel name used for transporting task results on the broker."""
-        return _RESULTS_CHANNEL_TEMPLATE
diff --git a/src/aiotaskq/constants.py b/src/aiotaskq/constants.py
new file mode 100644
index 0000000..95e06ea
--- /dev/null
+++ b/src/aiotaskq/constants.py
@@ -0,0 +1,14 @@
+_TASKS_CHANNEL = "channel:tasks"
+_RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}"
+
+
+class Constants:
+    @staticmethod
+    def tasks_channel() -> str:
+        """Return the channel name used for transporting task requests on the broker."""
+        return _TASKS_CHANNEL
+
+    @staticmethod
+    def results_channel_template() -> str:
+        """Return the template channel name used for transporting task results on the broker."""
+        return _RESULTS_CHANNEL_TEMPLATE
diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py
index 9e6dee1..235627e 100644
--- a/src/aiotaskq/task.py
+++ b/src/aiotaskq/task.py
@@ -10,6 +10,7 @@
 import uuid
 
 from .config import Config
+from .constants import Constants
 from .exceptions import InvalidArgument, ModuleInvalidForTask
 from .interfaces import IPubSub, PollResponse, TaskOptions
 from .pubsub import PubSub
@@ -53,7 +54,7 @@ async def from_publisher(cls, task_id: str) -> "AsyncResult":
 
         pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)
         async with pubsub_ as pubsub:  # pylint: disable=not-async-context-manager
-            await pubsub.subscribe(Config.results_channel_template().format(task_id=task_id))
+            await pubsub.subscribe(Constants.results_channel_template().format(task_id=task_id))
             message: PollResponse = await pubsub.poll()
 
             logger.debug("Message: %s", message)
@@ -189,7 +190,7 @@ async def publish(self) -> RT:
         )
         async with pubsub_ as pubsub:  # pylint: disable=not-async-context-manager
             logger.debug("Publishing task [task_id=%s, message=%s]", self.id, message)
-            await pubsub.publish(Config.tasks_channel(), message=message)
+            await pubsub.publish(Constants.tasks_channel(), message=message)
diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py
index 4ccc1d8..942821c 100755
--- a/src/aiotaskq/worker.py
+++ b/src/aiotaskq/worker.py
@@ -17,6 +17,7 @@
 from .concurrency_manager import ConcurrencyManagerSingleton
 from .config import Config
+from .constants import Constants
 from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub
 from .pubsub import PubSub
 from .serde import Serialization
@@ -72,7 +73,7 @@ def _pid(self) -> int:
 
     @staticmethod
     def _get_child_worker_tasks_channel(pid: int) -> str:
-        return f"{Config.tasks_channel()}:{pid}"
+        return f"{Constants.tasks_channel()}:{pid}"
@@ -154,7 +155,7 @@ async def _main_loop(self):
         async with self.pubsub as pubsub:  # pylint: disable=not-async-context-manager
             counter = -1
 
-            await pubsub.subscribe(Config.tasks_channel())
+            await pubsub.subscribe(Constants.tasks_channel())
             while True:
                 self._logger.debug("[%s] Polling for a new task until it's available",
self._pid) message = await pubsub.poll() @@ -293,7 +294,7 @@ async def _execute_task_and_publish( ) result = AsyncResult(task_id=task.id, ready=True, result=task_result, error=None) task_serialized = Serialization.serialize(obj=result) - result_channel = Config.results_channel_template().format(task_id=task.id) + result_channel = Constants.results_channel_template().format(task_id=task.id) await pubsub.publish(channel=result_channel, message=task_serialized) if semaphore is not None: From a4612ce09b94824176ff90fb334a62192269f36d Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sun, 10 Mar 2024 12:43:15 -0400 Subject: [PATCH 4/8] Fixup prev --- src/aiotaskq/config.py | 6 ++---- src/aiotaskq/constants.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/aiotaskq/config.py b/src/aiotaskq/config.py index 48bd399..1e5af18 100644 --- a/src/aiotaskq/config.py +++ b/src/aiotaskq/config.py @@ -1,11 +1,10 @@ """ -Module to define and store all constants used across the library. +Module to define and store all configuration values used across the library. The public object from this module is `Config`. This object wraps -all the constants, which include: +all the configuration values, which include: - Variables - Environment variables -- Static methods that return constant values """ import logging @@ -23,7 +22,6 @@ class Config: These include: - Variables - Environment variables - - Static methods that return constant values """ @staticmethod diff --git a/src/aiotaskq/constants.py b/src/aiotaskq/constants.py index 95e06ea..75372e6 100644 --- a/src/aiotaskq/constants.py +++ b/src/aiotaskq/constants.py @@ -1,8 +1,24 @@ +""" +Module to define and store all constants used across the library. + +The public object from this module is `Constants`. This object wraps +all the constants, which include: +- Static methods that return constant values +""" + + _TASKS_CHANNEL = "channel:tasks" _RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" class Constants: + """ + Provide all the constants. + + These include: + - Static methods that return constant values + """ + @staticmethod def tasks_channel() -> str: """Return the channel name used for transporting task requests on the broker.""" From a21c2513de88d4f35e085938d601d7a079fca1e8 Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sun, 10 Mar 2024 12:55:28 -0400 Subject: [PATCH 5/8] Remove unused attribute from AsyncResult --- src/aiotaskq/task.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py index 235627e..cd1d9f9 100644 --- a/src/aiotaskq/task.py +++ b/src/aiotaskq/task.py @@ -12,7 +12,7 @@ from .config import Config from .constants import Constants from .exceptions import InvalidArgument, ModuleInvalidForTask -from .interfaces import IPubSub, PollResponse, TaskOptions +from .interfaces import PollResponse, TaskOptions from .pubsub import PubSub if t.TYPE_CHECKING: @@ -31,7 +31,6 @@ class AsyncResult(t.Generic[RT]): To get the result of corresponding task, use `.get()`. 
""" - pubsub: IPubSub task_id: str ready: bool = False result: RT | None @@ -45,7 +44,6 @@ def __init__( self.ready = ready self.result = result self.error = error - self.pubsub = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01) @classmethod async def from_publisher(cls, task_id: str) -> "AsyncResult": @@ -170,7 +168,7 @@ async def apply_async(self, *args: P.args, **kwargs: P.kwargs) -> RT: await task_.publish() return await task_._get_result() - async def publish(self) -> RT: + async def publish(self) -> None: """ Publish the task. From 5bd0afa426d1897f87f5c7d3701919f7944086db Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sun, 10 Mar 2024 13:49:21 -0400 Subject: [PATCH 6/8] Avoid retrying forever --- src/aiotaskq/worker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py index 942821c..6dcc7ce 100755 --- a/src/aiotaskq/worker.py +++ b/src/aiotaskq/worker.py @@ -264,12 +264,16 @@ async def _execute_task_and_publish( retry_max = task.retry["max_retries"] if isinstance(e, task.retry["on"]): retry = True + # Set retry to 0 if first time + async with redis.from_url(url=Config.broker_url()) as redis_client: + if await redis_client.get(f"retry:{task.id}") is None: + await redis_client.set(f"retry:{task.id}", 0) finally: # Retry if still within retry limit if retry: - async with redis.from_url(url="redis://localhost:6379") as redis_client: - retries = int(await redis_client.get(f"retry:{task.id}") or 0) + async with redis.from_url(url=Config.broker_url()) as redis_client: + retries = int(await redis_client.get(f"retry:{task.id}")) if retry_max is not None and retries < retry_max: retries += 1 logger.debug( From ed611a8a8ba6ae9fe1c1f3344c68e9c1b65e5477 Mon Sep 17 00:00:00 2001 From: Imran Ariffin Date: Sun, 28 Jul 2024 09:55:16 -0400 Subject: [PATCH 7/8] Handle case when `options.retry.on` is empty --- src/aiotaskq/exceptions.py | 4 ++++ src/aiotaskq/interfaces.py | 10 ++++++--- src/aiotaskq/task.py | 12 ++++++++--- src/tests/test_task.py | 42 +++++++++++++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/aiotaskq/exceptions.py b/src/aiotaskq/exceptions.py index 2a5037c..c0fac70 100644 --- a/src/aiotaskq/exceptions.py +++ b/src/aiotaskq/exceptions.py @@ -19,3 +19,7 @@ class ConcurrencyTypeNotSupported(Exception): class InvalidArgument(Exception): """A task is applied with invalid arguments.""" + + +class InvalidRetryOptions(Exception): + """A task is defined with invalid retry options.""" diff --git a/src/aiotaskq/interfaces.py b/src/aiotaskq/interfaces.py index bc3d2a8..70f92b3 100644 --- a/src/aiotaskq/interfaces.py +++ b/src/aiotaskq/interfaces.py @@ -168,7 +168,8 @@ class RetryOptions(t.TypedDict): max_retries int | None: The number times to keep retrying the execution of the task until the task executes successfully. Counting starts from 0 so if max_retries = 2 for example, then the task will execute - 1 + 2 times (1 time for first execution, 2 times for re-try). + 1 + 2 times (1 time for first execution, 2 times for re-try) in the + worst case scenario. on tuple[type[Exception], ...]: The tuple of exception classes to retry on. The task will will only be retried if that exception that is raised during task execution is an instance of one of the listed @@ -176,11 +177,14 @@ class RetryOptions(t.TypedDict): Examples: - If on=(Exception,) then any kind of exception will trigger + If `on=(Exception,)` then any kind of exception will trigger a retry. 
-            If on=(ExceptionA, ExceptionB,) and during task
+            If `on=(ExceptionA, ExceptionB,)` and during task
             execution ExceptionC was raised, then retry is not
             triggered.
+
+            If `on=tuple()` then aiotaskq will raise `InvalidRetryOptions`
+            during task definition.
         """
 
         max_retries: int | None
diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py
index cd1d9f9..a0b2add 100644
--- a/src/aiotaskq/task.py
+++ b/src/aiotaskq/task.py
@@ -11,7 +11,7 @@
 
 from .config import Config
 from .constants import Constants
-from .exceptions import InvalidArgument, ModuleInvalidForTask
-from .interfaces import PollResponse, TaskOptions
+from .exceptions import InvalidArgument, InvalidRetryOptions, ModuleInvalidForTask
+from .interfaces import PollResponse, RetryOptions, TaskOptions
 from .pubsub import PubSub
 
@@ -113,7 +113,11 @@ def __init__(
         Store the underlying function and an automatically generated task_id in the Task instance.
         """
         self.func = func
+
+        if retry and len(retry.get("on", [])) == 0:
+            raise InvalidRetryOptions("retry.on should not be empty")
         self.retry = retry
+
         self.args = args
         self.kwargs = kwargs
         self.id = task_id
@@ -134,8 +138,10 @@ def with_retry(self, max_retries: int, on: tuple[type[Exception], ...]) -> "Task
 
         We return a copy so that we don't overwrite the original task definition.
         """
-        task_ = copy.deepcopy(self)
-        retry = {"max_retries": max_retries, "on": on}
+        if len(on) == 0:
+            raise InvalidRetryOptions("retry.on should not be empty")
+        task_: Task = copy.deepcopy(self)
+        retry: RetryOptions = {"max_retries": max_retries, "on": on}
         task_.retry = retry
         return task_
 
diff --git a/src/tests/test_task.py b/src/tests/test_task.py
index c350faf..a75304c 100644
--- a/src/tests/test_task.py
+++ b/src/tests/test_task.py
@@ -3,7 +3,8 @@
 
 import pytest
 
-from aiotaskq.exceptions import InvalidArgument
+from aiotaskq.exceptions import InvalidArgument, InvalidRetryOptions
+from aiotaskq.task import task as task_decorator
 from tests.apps import simple_app
 
 if TYPE_CHECKING:
@@ -224,3 +225,44 @@
     with open(some_file, mode="r", encoding="utf-8") as fi:
         num_lines = len(fi.read().rstrip("\n").split("\n"))
     assert num_lines == 3  # (first call + 2 retries)
+
+
+def test_empty_retry_on_during_task_definition__invalid():
+    # When a task is defined with options.retry.on = tuple()
+    exception = None
+    try:
+        @task_decorator(
+            options={
+                "retry": {
+                    "on": tuple(),
+                    "max_retries": 1,
+                },
+            },
+        )
+        def _():
+            return "Hello world"
+    except Exception as e:  # pylint: disable=broad-except
+        exception = e
+    finally:
+        # Then InvalidRetryOptions should be raised during task definition
+        assert isinstance(exception, InvalidRetryOptions), (
+            "Task definition should fail with InvalidRetryOptions"
+        )
+
+
+@pytest.mark.asyncio
+async def test_empty_retry_on_during_task_call__invalid(some_file: str):
+    # Given a task that is defined without error
+    some_task = simple_app.append_to_file
+
+    exception = None
+    try:
+        # When the task is called with options.retry.on = an empty tuple
+        await some_task.with_retry(max_retries=1, on=tuple()).apply_async(some_file)
+    except Exception as e:  # pylint: disable=broad-except
+        exception = e
+    finally:
+        # Then InvalidRetryOptions should be raised during the task call
+        assert isinstance(exception, InvalidRetryOptions), (
+            "Task call should fail with InvalidRetryOptions"
+        )

From 146f8f6cdcd7f49fc7dbc257b55e208108ac37c0 Mon Sep 17 00:00:00 2001
From: Imran Ariffin
Date: Sun, 28 Jul 2024 10:06:32 -0400
Subject: [PATCH 8/8] Move logic from `AsyncResult.from_publisher` to `Task._get_result`

---
 src/aiotaskq/task.py | 29 ++++++++++++-----------------
 1 file changed, 12 insertions(+), 17 deletions(-)

diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py
index a0b2add..cada236 100644
--- a/src/aiotaskq/task.py
+++ b/src/aiotaskq/task.py
@@ -45,22 +45,6 @@ def __init__(
         self.result = result
         self.error = error
 
-    @classmethod
-    async def from_publisher(cls, task_id: str) -> "AsyncResult":
-        """Return the result of the task once finished."""
-        from aiotaskq.serde import Serialization  # pylint: disable=import-outside-toplevel
-
-        pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)
-        async with pubsub_ as pubsub:  # pylint: disable=not-async-context-manager
-            await pubsub.subscribe(Constants.results_channel_template().format(task_id=task_id))
-            message: PollResponse = await pubsub.poll()
-
-            logger.debug("Message: %s", message)
-
-            result_serialized: bytes = message["data"]
-            result_: "AsyncResult" = Serialization.deserialize(cls, result_serialized)
-            return result_
-
     def get(self) -> RT | Exception:
         """Return the result of the task once finished."""
         if self.error is not None:
@@ -197,8 +181,19 @@ async def publish(self) -> None:
         await pubsub.publish(Constants.tasks_channel(), message=message)
 
     async def _get_result(self) -> RT:
+        from aiotaskq.serde import Serialization  # pylint: disable=import-outside-toplevel
+
         logger.debug("Retrieving result for task [task_id=%s]", self.id)
-        async_result: AsyncResult[RT] = await AsyncResult.from_publisher(task_id=self.id)
+        pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)
+        async with pubsub_ as pubsub:  # pylint: disable=not-async-context-manager
+            await pubsub.subscribe(Constants.results_channel_template().format(task_id=self.id))
+            message: PollResponse = await pubsub.poll()
+
+            logger.debug("Message: %s", message)
+
+            result_serialized: bytes = message["data"]
+            async_result: AsyncResult[RT] = Serialization.deserialize(AsyncResult, result_serialized)
+
         result: RT | Exception = async_result.get()
         if isinstance(result, Exception):
             raise result
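
A note for reviewers on the retry bookkeeping in patch 6: the worker bounds retries by keeping a per-task counter in Redis under a `retry:{task_id}` key. Below is the same counter pattern in isolation. This is a sketch only: it assumes the redis-py asyncio client (mirroring the `redis.from_url(...)` calls in the hunk), the function name and broker URL are illustrative, and the write-back of the incremented counter is an assumption, since the tail of that hunk is truncated above.

```python
from __future__ import annotations

from redis import asyncio as redis

BROKER_URL = "redis://127.0.0.1:6379"  # illustrative; the worker reads Config.broker_url()


async def should_retry(task_id: str, max_retries: int | None) -> bool:
    """Initialize the counter on the first failure, then count retries up to the limit."""
    async with redis.from_url(url=BROKER_URL) as client:
        key = f"retry:{task_id}"
        # First failure for this task id: create the counter so later reads never see None.
        if await client.get(key) is None:
            await client.set(key, 0)
        retries = int(await client.get(key))
        if max_retries is not None and retries < max_retries:
            # Record the retry we are about to schedule (assumed; not shown in the hunk).
            await client.set(key, retries + 1)
            return True
        return False
```

Because two workers could race between the `get` and the `set`, a production version would likely reach for Redis's atomic `INCR` command instead of this read-modify-write pair.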
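For reviewers trying the series end to end, here is a minimal usage sketch of the retry API that patches 6 and 7 add, based only on the decorator `options` dict and `Task.with_retry` exercised by the new tests. The task body and the choice of `TimeoutError` are hypothetical, and a worker plus a Redis broker are assumed to be running.

```python
import asyncio

import aiotaskq
from aiotaskq.exceptions import InvalidRetryOptions


# Declare retries at definition time via the `options` dict
# (mirrors test_empty_retry_on_during_task_definition__invalid).
@aiotaskq.task(options={"retry": {"max_retries": 2, "on": (TimeoutError,)}})
def flaky_task(x: int) -> int:
    # Hypothetical body; imagine it sometimes raises TimeoutError.
    return x * 2


async def main() -> None:
    # Override the retry options for a single call via `with_retry`.
    result = await flaky_task.with_retry(max_retries=5, on=(TimeoutError,)).apply_async(21)
    print(result)  # 42, once the task eventually succeeds

    # An empty `on` tuple is rejected eagerly instead of silently never retrying.
    try:
        flaky_task.with_retry(max_retries=1, on=tuple())
    except InvalidRetryOptions:
        print("empty retry.on rejected")


if __name__ == "__main__":
    asyncio.run(main())
```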
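Finally, patch 8's inlined `_get_result` relies on `AsyncResult` surviving a serialization round trip between worker and caller. A tiny round-trip sketch using only the calls visible in this series (`Serialization.serialize` from the worker hunk at the top, `Serialization.deserialize` from patch 8); the task id is made up, and the `bytes` payload type is assumed from the `result_serialized: bytes` annotation in the diff.

```python
from aiotaskq.serde import Serialization
from aiotaskq.task import AsyncResult

# Worker side: wrap a finished task's return value and serialize it for the
# results channel (mirrors _execute_task_and_publish).
result = AsyncResult(task_id="hypothetical-task-id", ready=True, result=42, error=None)
payload: bytes = Serialization.serialize(obj=result)

# Caller side: deserialize the payload back into an AsyncResult and unwrap it
# (mirrors the new Task._get_result). `.get()` returns the stored result, or
# the stored exception when `error` is set.
restored: AsyncResult = Serialization.deserialize(AsyncResult, payload)
assert restored.get() == 42
```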