diff --git a/src/aiotaskq/constants.py b/src/aiotaskq/constants.py index d642297..6aa6f9b 100644 --- a/src/aiotaskq/constants.py +++ b/src/aiotaskq/constants.py @@ -13,11 +13,9 @@ from .interfaces import SerializationType -# TODO @imranariffin: Wrap these inside `Config` -# TODO @imranariffin: Rename REDIS_URL to BROKER_URL -REDIS_URL = "redis://127.0.0.1:6379" -TASKS_CHANNEL = "channel:tasks" -RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" +_REDIS_URL = "redis://127.0.0.1:6379" +_TASKS_CHANNEL = "channel:tasks" +_RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}" class Config: @@ -41,3 +39,23 @@ def log_level() -> int: """Return the log level as provided via env var LOG_LEVEL.""" level: int = int(environ.get("AIOTASKQ_LOG_LEVEL", logging.DEBUG)) return level + + @staticmethod + def broker_url() -> str: + """ + Return the broker url as provided via env var BROKER_URL. + + Defaults to "redis://127.0.0.1:6379". + """ + broker_url: str = environ.get("BROKER_URL", _REDIS_URL) + return broker_url + + @staticmethod + def tasks_channel() -> str: + """Return the channel name used for transporting task requests on the broker.""" + return _TASKS_CHANNEL + + @staticmethod + def results_channel_template() -> str: + """Return the template chnnale name used for transporting task results on the broker.""" + return _RESULTS_CHANNEL_TEMPLATE diff --git a/src/aiotaskq/task.py b/src/aiotaskq/task.py index 0411a04..858689f 100644 --- a/src/aiotaskq/task.py +++ b/src/aiotaskq/task.py @@ -9,7 +9,7 @@ import typing as t import uuid -from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL +from .constants import Config from .exceptions import InvalidArgument, ModuleInvalidForTask from .interfaces import IPubSub, PollResponse, TaskOptions from .pubsub import PubSub @@ -44,16 +44,16 @@ def __init__( self.ready = ready self.result = result self.error = error - self.pubsub = PubSub.get(url=REDIS_URL, poll_interval_s=0.01) + self.pubsub = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01) @classmethod async def from_publisher(cls, task_id: str) -> "AsyncResult": """Return the result of the task once finished.""" from aiotaskq.serde import Serialization # pylint: disable=import-outside-toplevel - pubsub_ = PubSub.get(url=REDIS_URL, poll_interval_s=0.01) + pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01) async with pubsub_ as pubsub: # pylint: disable=not-async-context-manager - await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=task_id)) + await pubsub.subscribe(Config.results_channel_template().format(task_id=task_id)) message: PollResponse = await pubsub.poll() logger.debug("Message: %s", message) @@ -182,11 +182,14 @@ async def publish(self) -> RT: message: bytes = Serialization.serialize(self) pubsub_ = PubSub.get( - url=REDIS_URL, poll_interval_s=0.01, max_connections=10, decode_responses=True + url=Config.broker_url(), + poll_interval_s=0.01, + max_connections=10, + decode_responses=True, ) async with pubsub_ as pubsub: # pylint: disable=not-async-context-manager logger.debug("Publishing task [task_id=%s, message=%s]", self.id, message) - await pubsub.publish(TASKS_CHANNEL, message=message) + await pubsub.publish(Config.tasks_channel(), message=message) async def _get_result(self) -> RT: logger.debug("Retrieving result for task [task_id=%s]", self.id) diff --git a/src/aiotaskq/worker.py b/src/aiotaskq/worker.py index c51e885..42bbb74 100755 --- a/src/aiotaskq/worker.py +++ b/src/aiotaskq/worker.py @@ -16,7 +16,7 @@ import aioredis as redis from .concurrency_manager import ConcurrencyManagerSingleton -from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL +from .constants import Config from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub from .pubsub import PubSub from .serde import Serialization @@ -72,7 +72,7 @@ def _pid(self) -> int: @staticmethod def _get_child_worker_tasks_channel(pid: int) -> str: - return f"{TASKS_CHANNEL}:{pid}" + return f"{Config.tasks_channel()}:{pid}" class Defaults: @@ -117,7 +117,7 @@ def __init__( worker_rate_limit: int, poll_interval_s: float, ) -> None: - self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s) + self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s) self.concurrency_manager: IConcurrencyManager = ConcurrencyManagerSingleton.get( concurrency_type=concurrency_type, concurrency=concurrency, @@ -154,7 +154,7 @@ async def _main_loop(self): async with self.pubsub as pubsub: # pylint: disable=not-async-context-manager counter = -1 - await pubsub.subscribe(TASKS_CHANNEL) + await pubsub.subscribe(Config.tasks_channel()) while True: self._logger.debug("[%s] Polling for a new task until it's available", self._pid) message = await pubsub.poll() @@ -191,7 +191,7 @@ class GruntWorker(BaseWorker): """ def __init__(self, app_import_path: str, poll_interval_s: float, worker_rate_limit: int): - self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s) + self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s) self._worker_rate_limit = worker_rate_limit super().__init__(app_import_path=app_import_path) @@ -295,7 +295,7 @@ async def _execute_task_and_publish( task_id=task.id, ready=True, result=task_result, error=None ) task_serialized = Serialization.serialize(obj=result) - result_channel = RESULTS_CHANNEL_TEMPLATE.format(task_id=task.id) + result_channel = Config.results_channel_template().format(task_id=task.id) await pubsub.publish(channel=result_channel, message=task_serialized) if semaphore is not None: