Skip to content

Commit

Permalink
Wrap all constants inside Config class
Browse files Browse the repository at this point in the history
  • Loading branch information
imranariffin committed Dec 3, 2023
1 parent 1279c5f commit aaeb820
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
28 changes: 23 additions & 5 deletions src/aiotaskq/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@

from .interfaces import SerializationType

# TODO @imranariffin: Wrap these inside `Config`
# TODO @imranariffin: Rename REDIS_URL to BROKER_URL
REDIS_URL = "redis://127.0.0.1:6379"
TASKS_CHANNEL = "channel:tasks"
RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}"
# Default broker connection url; used as the fallback when the
# BROKER_URL environment variable is not set (see Config.broker_url).
_REDIS_URL = "redis://127.0.0.1:6379"
# Name of the broker channel on which task requests are published.
_TASKS_CHANNEL = "channel:tasks"
# Template for per-task result channel names; format with task_id.
_RESULTS_CHANNEL_TEMPLATE = "channel:results:{task_id}"


class Config:
Expand All @@ -41,3 +39,23 @@ def log_level() -> int:
"""Return the log level as provided via env var LOG_LEVEL."""
level: int = int(environ.get("AIOTASKQ_LOG_LEVEL", logging.DEBUG))
return level

@staticmethod
def broker_url() -> str:
    """Return the broker url.

    Reads the ``BROKER_URL`` environment variable, falling back to the
    default redis url ("redis://127.0.0.1:6379") when it is unset.
    """
    return environ.get("BROKER_URL", _REDIS_URL)

@staticmethod
def tasks_channel() -> str:
    """Return the name of the broker channel that carries task requests."""
    channel: str = _TASKS_CHANNEL
    return channel

@staticmethod
def results_channel_template() -> str:
    """Return the template for result channel names on the broker.

    The template contains a ``{task_id}`` placeholder; callers fill it in
    via ``.format(task_id=...)`` to get the per-task results channel.
    """
    template: str = _RESULTS_CHANNEL_TEMPLATE
    return template
15 changes: 9 additions & 6 deletions src/aiotaskq/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import typing as t
import uuid

from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL
from .constants import Config
from .exceptions import InvalidArgument, ModuleInvalidForTask
from .interfaces import IPubSub, PollResponse, TaskOptions
from .pubsub import PubSub
Expand Down Expand Up @@ -44,16 +44,16 @@ def __init__(
self.ready = ready
self.result = result
self.error = error
self.pubsub = PubSub.get(url=REDIS_URL, poll_interval_s=0.01)
self.pubsub = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)

@classmethod
async def from_publisher(cls, task_id: str) -> "AsyncResult":
"""Return the result of the task once finished."""
from aiotaskq.serde import Serialization # pylint: disable=import-outside-toplevel

pubsub_ = PubSub.get(url=REDIS_URL, poll_interval_s=0.01)
pubsub_ = PubSub.get(url=Config.broker_url(), poll_interval_s=0.01)
async with pubsub_ as pubsub: # pylint: disable=not-async-context-manager
await pubsub.subscribe(RESULTS_CHANNEL_TEMPLATE.format(task_id=task_id))
await pubsub.subscribe(Config.results_channel_template().format(task_id=task_id))
message: PollResponse = await pubsub.poll()

logger.debug("Message: %s", message)
Expand Down Expand Up @@ -182,11 +182,14 @@ async def publish(self) -> RT:
message: bytes = Serialization.serialize(self)

pubsub_ = PubSub.get(
url=REDIS_URL, poll_interval_s=0.01, max_connections=10, decode_responses=True
url=Config.broker_url(),
poll_interval_s=0.01,
max_connections=10,
decode_responses=True,
)
async with pubsub_ as pubsub: # pylint: disable=not-async-context-manager
logger.debug("Publishing task [task_id=%s, message=%s]", self.id, message)
await pubsub.publish(TASKS_CHANNEL, message=message)
await pubsub.publish(Config.tasks_channel(), message=message)

async def _get_result(self) -> RT:
logger.debug("Retrieving result for task [task_id=%s]", self.id)
Expand Down
12 changes: 6 additions & 6 deletions src/aiotaskq/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import aioredis as redis

from .concurrency_manager import ConcurrencyManagerSingleton
from .constants import REDIS_URL, RESULTS_CHANNEL_TEMPLATE, TASKS_CHANNEL
from .constants import Config
from .interfaces import ConcurrencyType, IConcurrencyManager, IPubSub
from .pubsub import PubSub
from .serde import Serialization
Expand Down Expand Up @@ -72,7 +72,7 @@ def _pid(self) -> int:

@staticmethod
def _get_child_worker_tasks_channel(pid: int) -> str:
    """Return the tasks-channel name dedicated to the child worker *pid*.

    The per-child channel is namespaced under the shared tasks channel and
    suffixed with the child's process id.
    """
    # The stale pre-refactor line `return f"{TASKS_CHANNEL}:{pid}"` was
    # unreachable dead code referencing a removed constant; only the
    # Config-based lookup remains.
    return f"{Config.tasks_channel()}:{pid}"


class Defaults:
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
worker_rate_limit: int,
poll_interval_s: float,
) -> None:
self.pubsub: IPubSub = PubSub.get(url=REDIS_URL, poll_interval_s=poll_interval_s)
self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s)
self.concurrency_manager: IConcurrencyManager = ConcurrencyManagerSingleton.get(
concurrency_type=concurrency_type,
concurrency=concurrency,
Expand Down Expand Up @@ -154,7 +154,7 @@ async def _main_loop(self):

async with self.pubsub as pubsub: # pylint: disable=not-async-context-manager
counter = -1
await pubsub.subscribe(TASKS_CHANNEL)
await pubsub.subscribe(Config.tasks_channel())
while True:
self._logger.debug("[%s] Polling for a new task until it's available", self._pid)
message = await pubsub.poll()
Expand Down Expand Up @@ -191,7 +191,7 @@ class GruntWorker(BaseWorker):
"""

def __init__(self, app_import_path: str, poll_interval_s: float, worker_rate_limit: int):
    """Initialize the grunt worker.

    :param app_import_path: Dotted import path to the aiotaskq application.
    :param poll_interval_s: Interval (seconds) at which the broker is polled.
    :param worker_rate_limit: Rate limit applied to task execution
        (exact semantics defined by the consumer of ``_worker_rate_limit`` — confirm).
    """
    # NOTE: a stale duplicate assignment `PubSub.get(url=REDIS_URL, ...)` was
    # removed — it shadowed this one and referenced the deleted REDIS_URL
    # constant, which would raise NameError at runtime. The broker url now
    # always comes from Config (env var BROKER_URL).
    self.pubsub: IPubSub = PubSub.get(url=Config.broker_url(), poll_interval_s=poll_interval_s)
    self._worker_rate_limit = worker_rate_limit
    super().__init__(app_import_path=app_import_path)

Expand Down Expand Up @@ -295,7 +295,7 @@ async def _execute_task_and_publish(
task_id=task.id, ready=True, result=task_result, error=None
)
task_serialized = Serialization.serialize(obj=result)
result_channel = RESULTS_CHANNEL_TEMPLATE.format(task_id=task.id)
result_channel = Config.results_channel_template().format(task_id=task.id)
await pubsub.publish(channel=result_channel, message=task_serialized)

if semaphore is not None:
Expand Down

0 comments on commit aaeb820

Please sign in to comment.