From ea6679d8051fa6f5588992e2b3bee087491c85ef Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Sun, 21 Jul 2024 15:11:01 +0000 Subject: [PATCH 1/8] Switch from Celery to Procrastinate (wip) --- Dockerfile | 8 +- TODO.md | 2 - compose/docker-compose.base.yml | 50 +- compose/docker-compose.dev.yml | 22 +- compose/docker-compose.prod.yml | 31 +- compose/rabbitmq/rabbitmq.conf | 13 - poetry.lock | 474 +++--------------- pyproject.toml | 10 +- radis/__init__.py | 7 - radis/celery.py | 41 -- radis/conftest.py | 18 - radis/core/management/commands/celery_beat.py | 7 - .../core/management/commands/celery_worker.py | 7 - radis/core/models.py | 8 +- radis/core/processors.py | 73 +++ radis/core/tasks.py | 137 +---- radis/core/templates/core/admin_section.html | 6 - .../core/mail/analysis_job_failed.html | 1 - radis/core/urls.py | 4 - radis/core/views.py | 23 +- .../0012_switch_to_procrastinate.py | 44 ++ radis/rag/models.py | 25 +- radis/rag/processors.py | 120 +++++ radis/rag/tasks.py | 230 +++------ radis/rag/templates/rag/rag_task_detail.html | 4 - radis/rag/tests/unit/test_task.py | 12 +- radis/settings/base.py | 49 +- radis/settings/development.py | 3 - radis/settings/production.py | 7 - radis/settings/test.py | 10 +- tasks.py | 14 - 31 files changed, 436 insertions(+), 1024 deletions(-) delete mode 100644 compose/rabbitmq/rabbitmq.conf delete mode 100644 radis/celery.py delete mode 100644 radis/core/management/commands/celery_beat.py delete mode 100644 radis/core/management/commands/celery_worker.py create mode 100644 radis/core/processors.py create mode 100644 radis/rag/migrations/0012_switch_to_procrastinate.py create mode 100644 radis/rag/processors.py diff --git a/Dockerfile b/Dockerfile index f192c000..17cc44e3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ ENV PYTHONUNBUFFERED=1 \ # poetry # https://python-poetry.org/docs/#installing-with-the-official-installer # https://python-poetry.org/docs/configuration/#using-environment-variables - POETRY_VERSION=1.8.2 \ + POETRY_VERSION=1.8.3 \ # make poetry install to this location POETRY_HOME="/opt/poetry" \ # make poetry create the virtual environment in the project's root @@ -74,8 +74,7 @@ RUN playwright install --with-deps chromium # Required folders for RADIS RUN mkdir -p /var/www/radis/logs \ /var/www/radis/static \ - /var/www/radis/ssl \ - /var/www/radis/celery + /var/www/radis/ssl # will become mountpoint of our code WORKDIR /app @@ -89,7 +88,6 @@ COPY . /app/ # Required folders for RADIS RUN mkdir -p /var/www/radis/logs \ /var/www/radis/static \ - /var/www/radis/ssl \ - /var/www/radis/celery + /var/www/radis/ssl WORKDIR /app diff --git a/TODO.md b/TODO.md index 1fe38121..5669e78f 100644 --- a/TODO.md +++ b/TODO.md @@ -14,7 +14,6 @@ - Reference: name (unique), match (unique) - Remove unneeded templatetags - Are pandas and openpyxl needed as deps?! 
-- Remove Redis if not needed anymore ## Fix @@ -91,5 +90,4 @@ - Delete reset_dev_db and add reset option to populate_dev_db - globals.d.ts - rename all Alpine components to Uppercase -- Turn off debug logging in Celery - Add metaclass=ABCMeta to abstract core/models and core/views (also core/tables and core/filters even in RADIS) diff --git a/compose/docker-compose.base.yml b/compose/docker-compose.base.yml index 7ba913c4..0efa8201 100644 --- a/compose/docker-compose.base.yml +++ b/compose/docker-compose.base.yml @@ -7,20 +7,14 @@ x-app: &default-app USE_DOCKER: 1 DJANGO_STATIC_ROOT: "/var/www/radis/static/" DATABASE_URL: "psql://postgres:postgres@postgres.local:5432/postgres" - RABBITMQ_URL: "amqp://rabbit" - RABBIT_MANAGEMENT_HOST: "rabbit" - RABBIT_MANAGEMENT_PORT: "15672" - REDIS_URL: "redis://redis.local:6379/0" LLAMACPP_URL: "http://llamacpp.local:8080" - FLOWER_HOST: "flower.local" - FLOWER_PORT: "5555" services: init: <<: *default-app hostname: init.local volumes: - - radis_data:/var/www/radis + - web_data:/var/www/radis - /mnt:/mnt web: @@ -29,7 +23,7 @@ services: build: context: .. volumes: - - radis_data:/var/www/radis + - web_data:/var/www/radis - /mnt:/mnt worker_default: @@ -40,50 +34,12 @@ services: <<: *default-app hostname: worker_llm.local - celery_beat: - <<: *default-app - hostname: celery_beat.local - - flower: - <<: *default-app - hostname: flower.local - command: > - bash -c " - wait-for-it -s rabbit:5672 -t 100 && - celery --broker=amqp://rabbit/ flower --url_prefix=flower - " - postgres: image: postgres:16 hostname: postgres.local volumes: - postgres_data:/var/lib/postgresql/data - # RabbitMQ authentication can't be disabled. So when we try to log into - # the management console we have to use "guest" as username and password. - # The real authentication happens by ADIT itself, because the management - # console is behind a ProxyView. 
- rabbit: - image: rabbitmq:3.12.2-management - configs: - - source: rabbit_config - target: /etc/rabbitmq/rabbitmq.conf - volumes: - - rabbit_data:/var/lib/rabbitmq - - redis: - image: redis:7.2 - hostname: redis.local - volumes: - - redis_data:/data - -configs: - rabbit_config: - file: ./rabbitmq/rabbitmq.conf - volumes: - radis_data: - flower_data: + web_data: postgres_data: - rabbit_data: - redis_data: diff --git a/compose/docker-compose.dev.yml b/compose/docker-compose.dev.yml index 736354f7..d274a0de 100644 --- a/compose/docker-compose.dev.yml +++ b/compose/docker-compose.dev.yml @@ -50,21 +50,19 @@ services: worker_default: <<: *default-app - command: | - ./manage.py celery_worker -c 1 -Q default_queue --autoreload + command: > + bash -c " + wait-for-it -s postgres.local:5432 -t 60 && + ./manage.py worker -l debug -q default --autoreload + " worker_llm: <<: *default-app - command: | - ./manage.py celery_worker -c 1 -Q llm_queue --autoreload - - celery_beat: - <<: *default-app - command: | - ./manage.py celery_beat --autoreload - - flower: - <<: *default-app + command: > + bash -c " + wait-for-it -s postgres.local:5432 -t 60 && + ./manage.py worker -l debug -q llm --autoreload + " llamacpp_cpu: <<: *llamacpp diff --git a/compose/docker-compose.prod.yml b/compose/docker-compose.prod.yml index cda94297..af6cce4c 100644 --- a/compose/docker-compose.prod.yml +++ b/compose/docker-compose.prod.yml @@ -52,24 +52,21 @@ services: worker_default: <<: *default-app - command: ./manage.py celery_worker -Q default_queue + command: > + bash -c " + wait-for-it -s postgres.local:5432 -t 60 && + ./manage.py worker -q default + " deploy: <<: *deploy worker_llm: <<: *default-app - command: ./manage.py celery_worker -c 1 -Q llm_queue - deploy: - <<: *deploy - - celery_beat: - <<: *default-app - command: ./manage.py celery_beat - deploy: - <<: *deploy - - flower: - <<: *default-app + command: > + bash -c " + wait-for-it -s postgres.local:5432 -t 60 && + ./manage.py worker -q llm + " deploy: <<: *deploy @@ -101,13 +98,5 @@ services: deploy: <<: *deploy - rabbit: - deploy: - <<: *deploy - - redis: - deploy: - <<: *deploy - volumes: models_data: diff --git a/compose/rabbitmq/rabbitmq.conf b/compose/rabbitmq/rabbitmq.conf deleted file mode 100644 index 72048a07..00000000 --- a/compose/rabbitmq/rabbitmq.conf +++ /dev/null @@ -1,13 +0,0 @@ -# https://github.com/rabbitmq/rabbitmq-server/blob/main/deps/rabbit/docs/rabbitmq.conf.example - -# The defaults used in the docker image -loopback_users.guest = false -listeners.tcp.default = 5672 -management.tcp.port = 15672 - -# Extend the consumer timeout (the default is 30 minutes) as otherwise workers get killed -# that take longer to acknowledge a task. This timeout starts when a task is fetched -# by a worker. A worker can fetch multiple tasks which can be configured by -# CELERY_WORKER_PREFETCH_MULTIPLIER in our settings file. 
-# https://www.rabbitmq.com/consumers.html#acknowledgement-timeout -consumer_timeout = 86400000 # 24 hours diff --git a/poetry.lock b/poetry.lock index fedae9e6..84343667 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,7 +2,7 @@ [[package]] name = "adit-radis-shared" -version = "0.3.9" +version = "0.5.0" description = "Shared Django apps between ADIT and RADIS" optional = false python-versions = "^3.11" @@ -25,6 +25,7 @@ django-registration-redux = "^2.10" django-revproxy = "^0.12.0" django-tables2 = "^2.3.3" djangorestframework = "^3.13.1" +procrastinate = {version = "^2.8.0", extras = ["django"]} psycopg = {version = "^3.1.12", extras = ["binary"]} pytz = "^2024.1" toml = "^0.10.2" @@ -36,8 +37,8 @@ whitenoise = "^6.0.0" [package.source] type = "git" url = "https://github.com/openradx/adit-radis-shared.git" -reference = "v0.3.9" -resolved_reference = "e82cb07c3387f9adab20514e135082a8dd75e5b7" +reference = "v0.5.0" +resolved_reference = "303374246d1af2e1d66957a13a874a4195ed3877" [[package]] name = "adrf" @@ -66,20 +67,6 @@ files = [ {file = "aiofiles-23.2.1.tar.gz", hash = "sha256:84ec2218d8419404abcb9f0c02df3f34c6e0a68ed41072acfb1cef5cbc29051a"}, ] -[[package]] -name = "amqp" -version = "5.2.0" -description = "Low-level AMQP client for Python (fork of amqplib)." -optional = false -python-versions = ">=3.6" -files = [ - {file = "amqp-5.2.0-py3-none-any.whl", hash = "sha256:827cb12fb0baa892aad844fd95258143bce4027fdac4fccddbc43330fd281637"}, - {file = "amqp-5.2.0.tar.gz", hash = "sha256:a1ecff425ad063ad42a486c902807d1482311481c8ad95a72694b2975e75f7fd"}, -] - -[package.dependencies] -vine = ">=5.0.0,<6.0.0" - [[package]] name = "annotated-types" version = "0.7.0" @@ -165,17 +152,6 @@ files = [ {file = "async_property-0.2.2.tar.gz", hash = "sha256:17d9bd6ca67e27915a75d92549df64b5c7174e9dc806b30a3934dc4ff0506380"}, ] -[[package]] -name = "async-timeout" -version = "4.0.3" -description = "Timeout context manager for asyncio programs" -optional = false -python-versions = ">=3.7" -files = [ - {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, - {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, -] - [[package]] name = "asyncinotify" version = "4.0.9" @@ -252,88 +228,6 @@ six = "*" [package.extras] visualize = ["Twisted (>=16.1.1)", "graphviz (>0.5.1)"] -[[package]] -name = "billiard" -version = "4.2.0" -description = "Python multiprocessing fork with improvements and bugfixes" -optional = false -python-versions = ">=3.7" -files = [ - {file = "billiard-4.2.0-py3-none-any.whl", hash = "sha256:07aa978b308f334ff8282bd4a746e681b3513db5c9a514cbdd810cbbdc19714d"}, - {file = "billiard-4.2.0.tar.gz", hash = "sha256:9a3c3184cb275aa17a732f93f65b20c525d3d9f253722d26a82194803ade5a2c"}, -] - -[[package]] -name = "celery" -version = "5.4.0" -description = "Distributed Task Queue." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "celery-5.4.0-py3-none-any.whl", hash = "sha256:369631eb580cf8c51a82721ec538684994f8277637edde2dfc0dacd73ed97f64"}, - {file = "celery-5.4.0.tar.gz", hash = "sha256:504a19140e8d3029d5acad88330c541d4c3f64c789d85f94756762d8bca7e706"}, -] - -[package.dependencies] -billiard = ">=4.2.0,<5.0" -click = ">=8.1.2,<9.0" -click-didyoumean = ">=0.3.0" -click-plugins = ">=1.1.1" -click-repl = ">=0.2.0" -kombu = ">=5.3.4,<6.0" -python-dateutil = ">=2.8.2" -redis = {version = ">=4.5.2,<4.5.5 || >4.5.5,<6.0.0", optional = true, markers = "extra == \"redis\""} -tzdata = ">=2022.7" -vine = ">=5.1.0,<6.0" - -[package.extras] -arangodb = ["pyArango (>=2.0.2)"] -auth = ["cryptography (==42.0.5)"] -azureblockblob = ["azure-storage-blob (>=12.15.0)"] -brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"] -cassandra = ["cassandra-driver (>=3.25.0,<4)"] -consul = ["python-consul2 (==0.1.5)"] -cosmosdbsql = ["pydocumentdb (==2.3.5)"] -couchbase = ["couchbase (>=3.0.0)"] -couchdb = ["pycouchdb (==1.14.2)"] -django = ["Django (>=2.2.28)"] -dynamodb = ["boto3 (>=1.26.143)"] -elasticsearch = ["elastic-transport (<=8.13.0)", "elasticsearch (<=8.13.0)"] -eventlet = ["eventlet (>=0.32.0)"] -gcs = ["google-cloud-storage (>=2.10.0)"] -gevent = ["gevent (>=1.5.0)"] -librabbitmq = ["librabbitmq (>=2.0.0)"] -memcache = ["pylibmc (==1.6.3)"] -mongodb = ["pymongo[srv] (>=4.0.2)"] -msgpack = ["msgpack (==1.0.8)"] -pymemcache = ["python-memcached (>=1.61)"] -pyro = ["pyro4 (==4.82)"] -pytest = ["pytest-celery[all] (>=1.0.0)"] -redis = ["redis (>=4.5.2,!=4.5.5,<6.0.0)"] -s3 = ["boto3 (>=1.26.143)"] -slmq = ["softlayer-messaging (>=1.0.3)"] -solar = ["ephem (==4.1.5)"] -sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] -sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.4)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] -tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"] -yaml = ["PyYAML (>=3.10)"] -zookeeper = ["kazoo (>=1.3.1)"] -zstd = ["zstandard (==0.22.0)"] - -[[package]] -name = "celery-types" -version = "0.22.0" -description = "Type stubs for Celery and its related packages" -optional = false -python-versions = ">=3.9,<4.0" -files = [ - {file = "celery_types-0.22.0-py3-none-any.whl", hash = "sha256:79a66637d1d6af5992d1dc80259d9538869941325e966006f1e795220519b9ac"}, - {file = "celery_types-0.22.0.tar.gz", hash = "sha256:0ecad2fa5a6eded0a1f919e5e1e381cc2ff0635fe4b21db53b4661b6876d5b30"}, -] - -[package.dependencies] -typing-extensions = ">=4.9.0,<5.0.0" - [[package]] name = "certifi" version = "2024.7.4" @@ -428,27 +322,6 @@ Django = ">=4.2" daphne = ["daphne (>=4.0.0)"] tests = ["async-timeout", "coverage (>=4.5,<5.0)", "pytest", "pytest-asyncio", "pytest-django"] -[[package]] -name = "channels-redis" -version = "4.2.0" -description = "Redis-backed ASGI channel layer implementation" -optional = false -python-versions = ">=3.8" -files = [ - {file = "channels_redis-4.2.0-py3-none-any.whl", hash = "sha256:2c5b944a39bd984b72aa8005a3ae11637bf29b5092adeb91c9aad4ab819a8ac4"}, - {file = "channels_redis-4.2.0.tar.gz", hash = "sha256:01c26c4d5d3a203f104bba9e5585c0305a70df390d21792386586068162027fd"}, -] - -[package.dependencies] -asgiref = ">=3.2.10,<4" -channels = "*" -msgpack = ">=1.0,<2.0" -redis = ">=4.6" - -[package.extras] -cryptography = ["cryptography (>=1.3.0)"] -tests = ["async-timeout", "cryptography (>=1.3.0)", "pytest", "pytest-asyncio", "pytest-timeout"] - [[package]] name = "charset-normalizer" version = "3.3.2" @@ -562,55 +435,6 @@ files = [ [package.dependencies] 
colorama = {version = "*", markers = "platform_system == \"Windows\""} -[[package]] -name = "click-didyoumean" -version = "0.3.1" -description = "Enables git-like *did-you-mean* feature in click" -optional = false -python-versions = ">=3.6.2" -files = [ - {file = "click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c"}, - {file = "click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463"}, -] - -[package.dependencies] -click = ">=7" - -[[package]] -name = "click-plugins" -version = "1.1.1" -description = "An extension module for click to enable registering CLI commands via setuptools entry-points." -optional = false -python-versions = "*" -files = [ - {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"}, - {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"}, -] - -[package.dependencies] -click = ">=4.0" - -[package.extras] -dev = ["coveralls", "pytest (>=3.6)", "pytest-cov", "wheel"] - -[[package]] -name = "click-repl" -version = "0.3.0" -description = "REPL plugin for Click" -optional = false -python-versions = ">=3.6" -files = [ - {file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"}, - {file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"}, -] - -[package.dependencies] -click = ">=7.0" -prompt-toolkit = ">=3.0.36" - -[package.extras] -testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] - [[package]] name = "colorama" version = "0.4.6" @@ -732,6 +556,21 @@ django-crispy-forms = ">=2" [package.extras] test = ["pytest", "pytest-django"] +[[package]] +name = "croniter" +version = "2.0.7" +description = "croniter provides iteration for datetime object with cron like format" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.6" +files = [ + {file = "croniter-2.0.7-py2.py3-none-any.whl", hash = "sha256:f15e80828d23920c4bb7f4d9340b932c9dcabecafc7775703c8b36d1253ed526"}, + {file = "croniter-2.0.7.tar.gz", hash = "sha256:1041b912b4b1e03751a0993531becf77851ae6e8b334c9c76ffeffb8f055f53f"}, +] + +[package.dependencies] +python-dateutil = "*" +pytz = ">2021.1" + [[package]] name = "cryptography" version = "42.0.8" @@ -1190,28 +1029,6 @@ PyYAML = ">=6.0,<7.0" regex = ">=2023.0.0,<2024.0.0" tqdm = ">=4.62.2,<5.0.0" -[[package]] -name = "docker" -version = "7.1.0" -description = "A Python library for the Docker Engine API." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, - {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, -] - -[package.dependencies] -pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} -requests = ">=2.26.0" -urllib3 = ">=1.26.0" - -[package.extras] -dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] -docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] -ssh = ["paramiko (>=2.4.3)"] -websockets = ["websocket-client (>=1.3.0)"] - [[package]] name = "docopt" version = "0.6.2" @@ -1289,24 +1106,6 @@ files = [ [package.dependencies] python-dateutil = ">=2.4" -[[package]] -name = "flower" -version = "2.0.1" -description = "Celery Flower" -optional = false -python-versions = ">=3.7" -files = [ - {file = "flower-2.0.1-py2.py3-none-any.whl", hash = "sha256:9db2c621eeefbc844c8dd88be64aef61e84e2deb29b271e02ab2b5b9f01068e2"}, - {file = "flower-2.0.1.tar.gz", hash = "sha256:5ab717b979530770c16afb48b50d2a98d23c3e9fe39851dcf6bc4d01845a02a0"}, -] - -[package.dependencies] -celery = ">=5.0.5" -humanize = "*" -prometheus-client = ">=0.8.0" -pytz = "*" -tornado = ">=5.0.0,<7.0.0" - [[package]] name = "greenlet" version = "3.0.3" @@ -1725,38 +1524,6 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] -[[package]] -name = "kombu" -version = "5.3.7" -description = "Messaging library for Python." -optional = false -python-versions = ">=3.8" -files = [ - {file = "kombu-5.3.7-py3-none-any.whl", hash = "sha256:5634c511926309c7f9789f1433e9ed402616b56836ef9878f01bd59267b4c7a9"}, - {file = "kombu-5.3.7.tar.gz", hash = "sha256:011c4cd9a355c14a1de8d35d257314a1d2456d52b7140388561acac3cf1a97bf"}, -] - -[package.dependencies] -amqp = ">=5.1.1,<6.0.0" -vine = "*" - -[package.extras] -azureservicebus = ["azure-servicebus (>=7.10.0)"] -azurestoragequeues = ["azure-identity (>=1.12.0)", "azure-storage-queue (>=12.6.0)"] -confluentkafka = ["confluent-kafka (>=2.2.0)"] -consul = ["python-consul2"] -librabbitmq = ["librabbitmq (>=2.0.0)"] -mongodb = ["pymongo (>=4.1.1)"] -msgpack = ["msgpack"] -pyro = ["pyro4"] -qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] -redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] -slmq = ["softlayer-messaging (>=1.0.3)"] -sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] -sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] -yaml = ["PyYAML (>=3.10)"] -zookeeper = ["kazoo (>=2.8.0)"] - [[package]] name = "markdown" version = "3.6" @@ -1786,71 +1553,6 @@ files = [ [package.dependencies] traitlets = "*" -[[package]] -name = "msgpack" -version = "1.0.8" -description = "MessagePack serializer" -optional = false -python-versions = ">=3.8" -files = [ - {file = "msgpack-1.0.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:505fe3d03856ac7d215dbe005414bc28505d26f0c128906037e66d98c4e95868"}, - {file = "msgpack-1.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b7842518a63a9f17107eb176320960ec095a8ee3b4420b5f688e24bf50c53c"}, - {file = "msgpack-1.0.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:376081f471a2ef24828b83a641a02c575d6103a3ad7fd7dade5486cad10ea659"}, - {file = 
"msgpack-1.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e390971d082dba073c05dbd56322427d3280b7cc8b53484c9377adfbae67dc2"}, - {file = "msgpack-1.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00e073efcba9ea99db5acef3959efa45b52bc67b61b00823d2a1a6944bf45982"}, - {file = "msgpack-1.0.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82d92c773fbc6942a7a8b520d22c11cfc8fd83bba86116bfcf962c2f5c2ecdaa"}, - {file = "msgpack-1.0.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9ee32dcb8e531adae1f1ca568822e9b3a738369b3b686d1477cbc643c4a9c128"}, - {file = "msgpack-1.0.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e3aa7e51d738e0ec0afbed661261513b38b3014754c9459508399baf14ae0c9d"}, - {file = "msgpack-1.0.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:69284049d07fce531c17404fcba2bb1df472bc2dcdac642ae71a2d079d950653"}, - {file = "msgpack-1.0.8-cp310-cp310-win32.whl", hash = "sha256:13577ec9e247f8741c84d06b9ece5f654920d8365a4b636ce0e44f15e07ec693"}, - {file = "msgpack-1.0.8-cp310-cp310-win_amd64.whl", hash = "sha256:e532dbd6ddfe13946de050d7474e3f5fb6ec774fbb1a188aaf469b08cf04189a"}, - {file = "msgpack-1.0.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9517004e21664f2b5a5fd6333b0731b9cf0817403a941b393d89a2f1dc2bd836"}, - {file = "msgpack-1.0.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d16a786905034e7e34098634b184a7d81f91d4c3d246edc6bd7aefb2fd8ea6ad"}, - {file = "msgpack-1.0.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2872993e209f7ed04d963e4b4fbae72d034844ec66bc4ca403329db2074377b"}, - {file = "msgpack-1.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c330eace3dd100bdb54b5653b966de7f51c26ec4a7d4e87132d9b4f738220ba"}, - {file = "msgpack-1.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b5c044f3eff2a6534768ccfd50425939e7a8b5cf9a7261c385de1e20dcfc85"}, - {file = "msgpack-1.0.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1876b0b653a808fcd50123b953af170c535027bf1d053b59790eebb0aeb38950"}, - {file = "msgpack-1.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:dfe1f0f0ed5785c187144c46a292b8c34c1295c01da12e10ccddfc16def4448a"}, - {file = "msgpack-1.0.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3528807cbbb7f315bb81959d5961855e7ba52aa60a3097151cb21956fbc7502b"}, - {file = "msgpack-1.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e2f879ab92ce502a1e65fce390eab619774dda6a6ff719718069ac94084098ce"}, - {file = "msgpack-1.0.8-cp311-cp311-win32.whl", hash = "sha256:26ee97a8261e6e35885c2ecd2fd4a6d38252246f94a2aec23665a4e66d066305"}, - {file = "msgpack-1.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:eadb9f826c138e6cf3c49d6f8de88225a3c0ab181a9b4ba792e006e5292d150e"}, - {file = "msgpack-1.0.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:114be227f5213ef8b215c22dde19532f5da9652e56e8ce969bf0a26d7c419fee"}, - {file = "msgpack-1.0.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d661dc4785affa9d0edfdd1e59ec056a58b3dbb9f196fa43587f3ddac654ac7b"}, - {file = "msgpack-1.0.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d56fd9f1f1cdc8227d7b7918f55091349741904d9520c65f0139a9755952c9e8"}, - {file = "msgpack-1.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0726c282d188e204281ebd8de31724b7d749adebc086873a59efb8cf7ae27df3"}, - {file = 
"msgpack-1.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8db8e423192303ed77cff4dce3a4b88dbfaf43979d280181558af5e2c3c71afc"}, - {file = "msgpack-1.0.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99881222f4a8c2f641f25703963a5cefb076adffd959e0558dc9f803a52d6a58"}, - {file = "msgpack-1.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b5505774ea2a73a86ea176e8a9a4a7c8bf5d521050f0f6f8426afe798689243f"}, - {file = "msgpack-1.0.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:ef254a06bcea461e65ff0373d8a0dd1ed3aa004af48839f002a0c994a6f72d04"}, - {file = "msgpack-1.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1dd7839443592d00e96db831eddb4111a2a81a46b028f0facd60a09ebbdd543"}, - {file = "msgpack-1.0.8-cp312-cp312-win32.whl", hash = "sha256:64d0fcd436c5683fdd7c907eeae5e2cbb5eb872fafbc03a43609d7941840995c"}, - {file = "msgpack-1.0.8-cp312-cp312-win_amd64.whl", hash = "sha256:74398a4cf19de42e1498368c36eed45d9528f5fd0155241e82c4082b7e16cffd"}, - {file = "msgpack-1.0.8-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0ceea77719d45c839fd73abcb190b8390412a890df2f83fb8cf49b2a4b5c2f40"}, - {file = "msgpack-1.0.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1ab0bbcd4d1f7b6991ee7c753655b481c50084294218de69365f8f1970d4c151"}, - {file = "msgpack-1.0.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1cce488457370ffd1f953846f82323cb6b2ad2190987cd4d70b2713e17268d24"}, - {file = "msgpack-1.0.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3923a1778f7e5ef31865893fdca12a8d7dc03a44b33e2a5f3295416314c09f5d"}, - {file = "msgpack-1.0.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a22e47578b30a3e199ab067a4d43d790249b3c0587d9a771921f86250c8435db"}, - {file = "msgpack-1.0.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd739c9251d01e0279ce729e37b39d49a08c0420d3fee7f2a4968c0576678f77"}, - {file = "msgpack-1.0.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d3420522057ebab1728b21ad473aa950026d07cb09da41103f8e597dfbfaeb13"}, - {file = "msgpack-1.0.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5845fdf5e5d5b78a49b826fcdc0eb2e2aa7191980e3d2cfd2a30303a74f212e2"}, - {file = "msgpack-1.0.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6a0e76621f6e1f908ae52860bdcb58e1ca85231a9b0545e64509c931dd34275a"}, - {file = "msgpack-1.0.8-cp38-cp38-win32.whl", hash = "sha256:374a8e88ddab84b9ada695d255679fb99c53513c0a51778796fcf0944d6c789c"}, - {file = "msgpack-1.0.8-cp38-cp38-win_amd64.whl", hash = "sha256:f3709997b228685fe53e8c433e2df9f0cdb5f4542bd5114ed17ac3c0129b0480"}, - {file = "msgpack-1.0.8-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f51bab98d52739c50c56658cc303f190785f9a2cd97b823357e7aeae54c8f68a"}, - {file = "msgpack-1.0.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:73ee792784d48aa338bba28063e19a27e8d989344f34aad14ea6e1b9bd83f596"}, - {file = "msgpack-1.0.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f9904e24646570539a8950400602d66d2b2c492b9010ea7e965025cb71d0c86d"}, - {file = "msgpack-1.0.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e75753aeda0ddc4c28dce4c32ba2f6ec30b1b02f6c0b14e547841ba5b24f753f"}, - {file = "msgpack-1.0.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dbf059fb4b7c240c873c1245ee112505be27497e90f7c6591261c7d3c3a8228"}, - {file = 
"msgpack-1.0.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4916727e31c28be8beaf11cf117d6f6f188dcc36daae4e851fee88646f5b6b18"}, - {file = "msgpack-1.0.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7938111ed1358f536daf311be244f34df7bf3cdedb3ed883787aca97778b28d8"}, - {file = "msgpack-1.0.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:493c5c5e44b06d6c9268ce21b302c9ca055c1fd3484c25ba41d34476c76ee746"}, - {file = "msgpack-1.0.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fbb160554e319f7b22ecf530a80a3ff496d38e8e07ae763b9e82fadfe96f273"}, - {file = "msgpack-1.0.8-cp39-cp39-win32.whl", hash = "sha256:f9af38a89b6a5c04b7d18c492c8ccf2aee7048aff1ce8437c4683bb5a1df893d"}, - {file = "msgpack-1.0.8-cp39-cp39-win_amd64.whl", hash = "sha256:ed59dd52075f8fc91da6053b12e8c89e37aa043f8986efd89e61fae69dc1b011"}, - {file = "msgpack-1.0.8.tar.gz", hash = "sha256:95c02b0e27e706e48d0e5426d1710ca78e0f0628d6e89d5b5a5b91a5f12274f3"}, -] - [[package]] name = "nest-asyncio" version = "1.6.0" @@ -2073,6 +1775,17 @@ files = [ {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, ] +[[package]] +name = "pebble" +version = "5.0.7" +description = "Threading and multiprocessing eye-candy." +optional = false +python-versions = ">=3.6" +files = [ + {file = "Pebble-5.0.7-py3-none-any.whl", hash = "sha256:f1742f2a62e8544e722c7b387211fb1a06038ca8cda322e5d55c84c793fd8d7d"}, + {file = "Pebble-5.0.7.tar.gz", hash = "sha256:2784c147766f06388cea784084b14bec93fdbaa793830f1983155aa330a2a6e4"}, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2150,18 +1863,30 @@ files = [ ] [[package]] -name = "prometheus-client" -version = "0.20.0" -description = "Python client for the Prometheus monitoring system." 
+name = "procrastinate" +version = "2.8.0" +description = "Postgres-based distributed task processing library" optional = false -python-versions = ">=3.8" +python-versions = "<4.0,>=3.8" files = [ - {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"}, - {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"}, + {file = "procrastinate-2.8.0-py3-none-any.whl", hash = "sha256:8dd785ea8afc04d72a1e45f13a480e46745a559ae62b3ef46cf8165b88a69a64"}, + {file = "procrastinate-2.8.0.tar.gz", hash = "sha256:c5c9c1ed3d95148139dd41e0a046e7bd3f1e712cc5111b023b699d82f209d981"}, ] +[package.dependencies] +anyio = "*" +asgiref = "*" +attrs = "*" +croniter = "*" +django = {version = ">=2.2", optional = true, markers = "extra == \"django\""} +psycopg = {version = ">=3.1.13,<4.0.0", extras = ["pool"]} +python-dateutil = "*" + [package.extras] -twisted = ["twisted"] +aiopg = ["aiopg", "psycopg2-binary"] +django = ["django (>=2.2)"] +psycopg2 = ["psycopg2-binary"] +sqlalchemy = ["sqlalchemy (>=2.0,<3.0)"] [[package]] name = "prompt-toolkit" @@ -2219,6 +1944,7 @@ files = [ [package.dependencies] psycopg-binary = {version = "3.2.1", optional = true, markers = "implementation_name != \"pypy\" and extra == \"binary\""} +psycopg-pool = {version = "*", optional = true, markers = "extra == \"pool\""} typing-extensions = ">=4.4" tzdata = {version = "*", markers = "sys_platform == \"win32\""} @@ -2292,6 +2018,20 @@ files = [ {file = "psycopg_binary-3.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:921f0c7f39590763d64a619de84d1b142587acc70fd11cbb5ba8fa39786f3073"}, ] +[[package]] +name = "psycopg-pool" +version = "3.2.2" +description = "Connection Pool for Psycopg" +optional = false +python-versions = ">=3.8" +files = [ + {file = "psycopg_pool-3.2.2-py3-none-any.whl", hash = "sha256:273081d0fbfaced4f35e69200c89cb8fbddfe277c38cc86c235b90a2ec2c8153"}, + {file = "psycopg_pool-3.2.2.tar.gz", hash = "sha256:9e22c370045f6d7f2666a5ad1b0caf345f9f1912195b0b25d0d3bcc4f3a7389c"}, +] + +[package.dependencies] +typing-extensions = ">=4.4" + [[package]] name = "ptyprocess" version = "0.7.0" @@ -2627,31 +2367,6 @@ requests = ">=2.9" [package.extras] test = ["black (>=22.1.0)", "flake8 (>=4.0.1)", "pre-commit (>=2.17.0)", "pytest-localserver (>=0.7.1)", "tox (>=3.24.5)"] -[[package]] -name = "pytest-celery" -version = "1.0.1" -description = "Pytest plugin for Celery" -optional = false -python-versions = "<4.0,>=3.8" -files = [ - {file = "pytest_celery-1.0.1-py3-none-any.whl", hash = "sha256:8f0068f0b5deb3123c76ae56327d40ece488c622daee54b3c5ff968c503df841"}, - {file = "pytest_celery-1.0.1.tar.gz", hash = "sha256:8ab12f2f16946e131c315efce2d71fa3b74a05269077fde04f96a6048b249377"}, -] - -[package.dependencies] -celery = "*" -debugpy = ">=1.8.1,<2.0.0" -docker = ">=7.1.0,<8.0.0" -psutil = ">=5.9.7" -pytest-docker-tools = ">=3.1.3" -setuptools = ">=69.1.0" -tenacity = ">=8.5.0" - -[package.extras] -all = ["python-memcached", "redis"] -memcached = ["python-memcached"] -redis = ["redis"] - [[package]] name = "pytest-cov" version = "5.0.0" @@ -2688,21 +2403,6 @@ pytest = ">=7.0.0" docs = ["sphinx", "sphinx-rtd-theme"] testing = ["Django", "django-configurations (>=2.0)"] -[[package]] -name = "pytest-docker-tools" -version = "3.1.3" -description = "Docker integration tests for pytest" -optional = false -python-versions = ">=3.7.0,<4.0.0" -files = [ - {file = "pytest_docker_tools-3.1.3-py3-none-any.whl", 
hash = "sha256:63e659043160f41d89f94ea42616102594bcc85682aac394fcbc14f14cd1b189"}, - {file = "pytest_docker_tools-3.1.3.tar.gz", hash = "sha256:c7e28841839d67b3ac80ad7b345b953701d5ae61ffda97586114244292aeacc0"}, -] - -[package.dependencies] -docker = ">=4.3.1" -pytest = ">=6.0.1" - [[package]] name = "pytest-mock" version = "3.14.0" @@ -3049,24 +2749,6 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} -[[package]] -name = "redis" -version = "5.0.7" -description = "Python client for Redis database and key-value store" -optional = false -python-versions = ">=3.7" -files = [ - {file = "redis-5.0.7-py3-none-any.whl", hash = "sha256:0e479e24da960c690be5d9b96d21f7b918a98c0cf49af3b6fafaa0753f93a0db"}, - {file = "redis-5.0.7.tar.gz", hash = "sha256:8f611490b93c8109b50adc317b31bfd84fff31def3475b92e7e80bf39f48175b"}, -] - -[package.dependencies] -async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} - -[package.extras] -hiredis = ["hiredis (>=1.0.0)"] -ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] - [[package]] name = "regex" version = "2023.12.25" @@ -3312,21 +2994,6 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] -[[package]] -name = "tenacity" -version = "8.5.0" -description = "Retry code until it succeeds" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687"}, - {file = "tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78"}, -] - -[package.extras] -doc = ["reno", "sphinx"] -test = ["pytest", "tornado (>=4.5)", "typeguard"] - [[package]] name = "text-unidecode" version = "1.3" @@ -3659,17 +3326,6 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] -[[package]] -name = "vine" -version = "5.1.0" -description = "Python promises." 
-optional = false -python-versions = ">=3.6" -files = [ - {file = "vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc"}, - {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, -] - [[package]] name = "wait-for-it" version = "2.2.2" @@ -3899,4 +3555,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<4.0" -content-hash = "0511449d40b44bf2ce5a964bbf02bda7e9558a728e0d8b74a4021b8d464cd2df" +content-hash = "50fd8cdea72e1e9c433b305bddf1b0e6426f5783c4949df6f4dc80c42ead0160" diff --git a/pyproject.toml b/pyproject.toml index 7b2f80d7..6675657a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,13 +6,11 @@ authors = ["medihack "] license = "GPL-3.0-or-later" [tool.poetry.dependencies] -adit-radis-shared = { git = "https://github.com/openradx/adit-radis-shared.git", tag = "v0.3.9" } +adit-radis-shared = { git = "https://github.com/openradx/adit-radis-shared.git", tag = "v0.5.0" } adrf = "^0.1.4" aiofiles = "^23.1.0" asyncinotify = "^4.0.1" -celery = { extras = ["redis"], version = "^5.3.5" } channels = "^4.0.0" -channels-redis = "^4.0.0" crispy-bootstrap5 = "^2024.2" cryptography = "^42.0.2" daphne = "^4.1.0" @@ -29,16 +27,16 @@ django-registration-redux = "^2.10" django-revproxy = { git = "https://github.com/jazzband/django-revproxy.git" } django-tables2 = "^2.3.3" djangorestframework = "^3.13.1" -flower = "^2.0.0" humanize = "^4.0.0" Markdown = "^3.3.7" openai = "^1.12.0" openpyxl = "^3.1.2" pandas = "^2.0.1" +pebble = "^5.0.7" +procrastinate = { extras = ["django"], version = "^2.8.0" } psycopg = { extras = ["binary"], version = "^3.1.13" } pyparsing = "^3.1.2" python = ">=3.11,<4.0" -redis = "^5.0.3" toml = "^0.10.2" Twisted = { extras = ["tls", "http2"], version = "^24.3.0" } wait-for-it = "^2.2.2" @@ -46,7 +44,6 @@ watchfiles = "^0.22.0" whitenoise = "^6.0.0" [tool.poetry.group.dev.dependencies] -celery-types = "^0.22.0" debugpy = "^1.8.1" django-browser-reload = "^1.11.0" django-debug-permissions = "^1.0.0" @@ -66,7 +63,6 @@ pydicom = "^2.4.3" pyright = "^1.1.336" pytest = "^8.1.1" pytest-asyncio = "^0.23.5" -pytest-celery = "^1.0.0" pytest-cov = "^5.0.0" pytest-django = "^4.5.2" pytest-mock = "^3.10.0" diff --git a/radis/__init__.py b/radis/__init__.py index 0165ba0d..e69de29b 100644 --- a/radis/__init__.py +++ b/radis/__init__.py @@ -1,7 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -# This will make sure the app is always imported when -# Django starts so that shared_task will use this app. -from .celery import app as celery_app - -__all__ = ("celery_app",) diff --git a/radis/celery.py b/radis/celery.py deleted file mode 100644 index c64980af..00000000 --- a/radis/celery.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import os - -from celery import Celery - -# set the default Django settings module for the 'celery' program. -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "radis.settings.development") - -app = Celery("radis") - -# Using a string here means the worker doesn't have to serialize -# the configuration object to child processes. -# - namespace='CELERY' means all celery-related configuration keys -# should have a `CELERY_` prefix. 
-app.config_from_object("django.conf:settings", namespace="CELERY") - -# If priority queues do not work this might be another option, see -# https://stackoverflow.com/a/47980598/166229 -# app.conf.task_queue_max_priority = 10 - -from celery.signals import setup_logging # noqa: E402 - - -# When setting up logging like this then CELERY_WORKER_HIJACK_ROOT_LOGGER -# has no effect anymore as it Celery will never try to hijack it. -# It will be forced to take the logger configuration from our Django -# settings. The downside is that we don't get the Celery logger -# (get_task_logger) anymore that prints out the Celery task ID with -# the log message, so we you the default Python logger instead. -@setup_logging.connect -def config_loggers(*args, **kwargs): - from logging.config import dictConfig - - from django.conf import settings - - dictConfig(settings.LOGGING) - - -# Load task modules from all registered Django app configs. -app.autodiscover_tasks() diff --git a/radis/conftest.py b/radis/conftest.py index 99477421..64c56afb 100644 --- a/radis/conftest.py +++ b/radis/conftest.py @@ -1,12 +1,5 @@ -from multiprocessing import Process - import nest_asyncio -import pytest from adit_radis_shared.conftest import * # noqa: F403 -from django.core.management import call_command -from faker import Faker - -fake = Faker() def pytest_configure(): @@ -17,14 +10,3 @@ def pytest_configure(): # https://github.com/pytest-dev/pytest-asyncio/issues/543 # https://github.com/microsoft/playwright-pytest/issues/167 nest_asyncio.apply() - - -@pytest.fixture -def radis_celery_worker(): - def start_worker(): - call_command("celery_worker", "-Q", "test_queue") - - p = Process(target=start_worker) - p.start() - yield - p.terminate() diff --git a/radis/core/management/commands/celery_beat.py b/radis/core/management/commands/celery_beat.py deleted file mode 100644 index 381bb0d3..00000000 --- a/radis/core/management/commands/celery_beat.py +++ /dev/null @@ -1,7 +0,0 @@ -from adit_radis_shared.common.management.base.celery_beat import CeleryBeatCommand -from django.conf import settings - - -class Command(CeleryBeatCommand): - project = "radis" - paths_to_watch = [settings.BASE_DIR / "radis"] diff --git a/radis/core/management/commands/celery_worker.py b/radis/core/management/commands/celery_worker.py deleted file mode 100644 index 794171e7..00000000 --- a/radis/core/management/commands/celery_worker.py +++ /dev/null @@ -1,7 +0,0 @@ -from adit_radis_shared.common.management.base.celery_worker import CeleryWorkerCommand -from django.conf import settings - - -class Command(CeleryWorkerCommand): - project = "radis" - paths_to_watch = [settings.BASE_DIR / "radis"] diff --git a/radis/core/models.py b/radis/core/models.py index 1b95d261..cf9a37a4 100644 --- a/radis/core/models.py +++ b/radis/core/models.py @@ -3,6 +3,7 @@ from django.conf import settings from django.db import models from django.utils import timezone +from procrastinate.contrib.django.models import ProcrastinateJob from radis.core.utils.model_utils import reset_tasks @@ -196,14 +197,17 @@ class Status(models.TextChoices): id: int job_id: int job = models.ForeignKey(AnalysisJob, on_delete=models.CASCADE, related_name="tasks") - celery_task_id = models.CharField(max_length=255) + queued_job_id: int | None + queued_job = models.OneToOneField( + ProcrastinateJob, null=True, on_delete=models.SET_NULL, related_name="+" + ) status = models.CharField( max_length=2, choices=Status.choices, default=Status.PENDING, ) get_status_display: Callable[[], str] - retries = 
models.PositiveSmallIntegerField(default=0) + attempts = models.PositiveSmallIntegerField(default=0) message = models.TextField(blank=True, default="") log = models.TextField(blank=True, default="") created_at = models.DateTimeField(auto_now_add=True) diff --git a/radis/core/processors.py b/radis/core/processors.py new file mode 100644 index 00000000..76ff6879 --- /dev/null +++ b/radis/core/processors.py @@ -0,0 +1,73 @@ +import logging +import traceback + +from django.utils import timezone + +from .models import AnalysisJob, AnalysisTask + +logger = logging.getLogger(__name__) + + +class AnalysisTaskProcessor: + def start(self, task: AnalysisTask) -> None: + job = task.job + + logger.info("Start processing task %s", task) + + # Jobs are canceled by the AnalysisJobCancelView and tasks are also revoked there, + # but it could happen that the task was already picked up by a worker or under rare + # circumstances will nevertheless get picked up by a worker (e.g. the worker crashes + # and forgot its revoked tasks). We then just ignore that task. + if ( + job.status == AnalysisJob.Status.CANCELING + or job.status == AnalysisJob.Status.CANCELED + or task.status == AnalysisTask.Status.CANCELED + ): + task.status = task.Status.CANCELED + task.started_at = timezone.now() + task.ended_at = timezone.now() + task.save() + job.update_job_state() + return + + assert task.status == task.Status.PENDING + + # When the first task is going to be processed then the + # status of the job switches from PENDING to IN_PROGRESS + if job.status == job.Status.PENDING: + job.status = job.Status.IN_PROGRESS + job.started_at = timezone.now() + job.save() + + assert job.status == job.Status.IN_PROGRESS + + # Prepare the task itself + task.status = AnalysisTask.Status.IN_PROGRESS + task.started_at = timezone.now() + task.save() + + try: + self.process_task(task) + + # If the overwritten process_task method changes the status of the + # task itself then we leave it as it is. Otherwise if the status is + # still in progress we set it to success. + if task.status == AnalysisTask.Status.IN_PROGRESS: + task.status = AnalysisTask.Status.SUCCESS + except Exception as err: + logger.exception("Task %s failed.", task) + + task.status = AnalysisTask.Status.FAILURE + task.message = str(err) + if task.log: + task.log += "\n---\n" + task.log += traceback.format_exc() + finally: + logger.info("Task %s ended", task) + task.ended_at = timezone.now() + task.save() + job.update_job_state() + + def process_task(self, task: AnalysisTask) -> None: + """The derived class should process the task here.""" + ... 
diff --git a/radis/core/tasks.py b/radis/core/tasks.py index 6ff0390f..a476ddc6 100644 --- a/radis/core/tasks.py +++ b/radis/core/tasks.py @@ -1,21 +1,15 @@ import logging -import traceback from adit_radis_shared.accounts.models import User -from celery import Task as CeleryTask -from celery import shared_task -from celery.exceptions import Retry from django.conf import settings from django.core.mail import send_mail from django.core.management import call_command -from django.utils import timezone - -from radis.core.models import AnalysisJob, AnalysisTask +from procrastinate.contrib.django import app logger = logging.getLogger(__name__) -@shared_task +@app.task def broadcast_mail(subject: str, message: str): recipients = [] for user in User.objects.all(): @@ -27,128 +21,7 @@ def broadcast_mail(subject: str, message: str): logger.info("Successfully sent an Email to %d recipients.", len(recipients)) -@shared_task -def backup_db(): +@app.periodic(cron="0 3 * * * ") # every day at 3am +@app.task +def backup_db(*args, **kwargs): call_command("backup_db") - - -class ProcessAnalysisTask(CeleryTask): - analysis_task_class: type[AnalysisTask] - - def run(self, task_id: int) -> None: - task = self.analysis_task_class.objects.get(id=task_id) - job = task.job - - logger.info("Start processing task %s", task) - - # Jobs are canceled by the AnalysisJobCancelView and tasks are also revoked there, - # but it could happen that the task was already picked up by a worker or under rare - # circumstances will nevertheless get picked up by a worker (e.g. the worker crashes - # and forgot its revoked tasks). We then just ignore that task. - if ( - job.status == AnalysisJob.Status.CANCELING - or job.status == AnalysisJob.Status.CANCELED - or task.status == AnalysisTask.Status.CANCELED - ): - task.status = task.Status.CANCELED - task.started_at = timezone.now() - task.ended_at = timezone.now() - task.save() - job.update_job_state() - return - - assert task.status == task.Status.PENDING - - # When the first task is going to be processed then the - # status of the job switches from PENDING to IN_PROGRESS - if job.status == job.Status.PENDING: - job.status = job.Status.IN_PROGRESS - job.started_at = timezone.now() - job.save() - - assert job.status == job.Status.IN_PROGRESS - - # Prepare the task itself - task.status = AnalysisTask.Status.IN_PROGRESS - task.started_at = timezone.now() - task.save() - - try: - self.process_task(task) - - # If the overwritten process_task method changes the status of the - # task itself then we leave it as it is. Otherwise if the status is - # still in progress we set it to success. - if task.status == AnalysisTask.Status.IN_PROGRESS: - task.status = AnalysisTask.Status.SUCCESS - except Retry as err: - # Subclasses can raise Retry to indicate that the task should be retried. - # This must be passed through to the Celery worker. - - # TODO: How do we handle max retries?! - - if task.status != AnalysisTask.Status.PENDING: - task.status = AnalysisTask.Status.PENDING - - logger.info("Task %s will be retried.", task) - - raise err - except Exception as err: - logger.exception("Task %s failed.", task) - - task.status = AnalysisTask.Status.FAILURE - task.message = str(err) - if task.log: - task.log += "\n---\n" - task.log += traceback.format_exc() - finally: - logger.info("Task %s ended", task) - task.ended_at = timezone.now() - task.save() - job.update_job_state() - - def process_task(self, task: AnalysisTask) -> None: - """The derived class should process the task here.""" - ... 
- - -class ProcessAnalysisJob(CeleryTask): - analysis_job_class: type[AnalysisJob] - process_analysis_task: ProcessAnalysisTask - task_queue: str - - def run(self, job_id: int) -> None: - job = self.analysis_job_class.objects.get(id=job_id) - logger.info("Start processing job %s", job) - assert job.status == AnalysisJob.Status.PREPARING - - priority = job.default_priority - if job.urgent: - priority = job.urgent_priority - - logger.debug("Collecting tasks for job %s", job) - tasks: list[AnalysisTask] = [] - for task in self.collect_tasks(job): - assert task.status == task.Status.PENDING - tasks.append(task) - - logger.debug("Found %d tasks for job %s", len(tasks), job) - - job.status = AnalysisJob.Status.PENDING - job.save() - - for task in tasks: - result = ( - self.process_analysis_task.s(task_id=task.id) - .set(priority=priority) - .apply_async(queue=self.task_queue) - ) - # Save Celery task ID to analysis task (for revoking it later if necessary). - # Only works when not in eager mode (which is used to debug Celery stuff). - if not getattr(settings, "CELERY_TASK_ALWAYS_EAGER", False): - task.celery_task_id = result.id - task.save() - - def collect_tasks(self, job: AnalysisJob) -> list[AnalysisTask]: - """The derived class should collect the tasks to process here.""" - ... diff --git a/radis/core/templates/core/admin_section.html b/radis/core/templates/core/admin_section.html index 8dbb5dc2..b48c4893 100644 --- a/radis/core/templates/core/admin_section.html +++ b/radis/core/templates/core/admin_section.html @@ -19,12 +19,6 @@
Admin Tools
  • Django Admin
-  • Flower (Celery Monitoring)
-  • RabbitMQ Management Console
  • {% endblock content %} diff --git a/radis/core/templates/core/mail/analysis_job_failed.html b/radis/core/templates/core/mail/analysis_job_failed.html index 5c0a1ba5..879c73e7 100644 --- a/radis/core/templates/core/mail/analysis_job_failed.html +++ b/radis/core/templates/core/mail/analysis_job_failed.html @@ -1,2 +1 @@ {{ job }} failed unexpectedly. -Celery Task ID {{ celery_task_id }}. diff --git a/radis/core/urls.py b/radis/core/urls.py index c2a6be79..8404fc9a 100644 --- a/radis/core/urls.py +++ b/radis/core/urls.py @@ -1,10 +1,8 @@ -from adit_radis_shared.common.views import FlowerProxyView from django.urls import path from .views import ( BroadcastView, HomeView, - RabbitManagementProxyView, UpdatePreferencesView, admin_section, ) @@ -29,6 +27,4 @@ HomeView.as_view(), name="home", ), - FlowerProxyView.as_url(), - RabbitManagementProxyView.as_url(), ] diff --git a/radis/core/views.py b/radis/core/views.py index e8b4c2d8..7c4097c7 100644 --- a/radis/core/views.py +++ b/radis/core/views.py @@ -4,12 +4,10 @@ from adit_radis_shared.common.site import THEME_PREFERENCE_KEY from adit_radis_shared.common.types import AuthenticatedHttpRequest from adit_radis_shared.common.views import ( - AdminProxyView, BaseBroadcastView, BaseHomeView, BaseUpdatePreferencesView, ) -from django.conf import settings from django.contrib import messages from django.contrib.admin.views.decorators import staff_member_required from django.contrib.auth.mixins import ( @@ -29,8 +27,8 @@ from django_filters.filterset import FilterSet from django_filters.views import FilterView from django_tables2 import SingleTableMixin, Table +from procrastinate.contrib.django import app -from radis.celery import app as celery_app from radis.core.utils.model_utils import reset_tasks from .models import AnalysisJob, AnalysisTask @@ -46,7 +44,7 @@ class BroadcastView(BaseBroadcastView): success_url = reverse_lazy("broadcast") def send_mails(self, subject: str, message: str) -> None: - broadcast_mail.delay(subject, message) + broadcast_mail.defer(subject, message) class HomeView(BaseHomeView): @@ -204,12 +202,11 @@ def post(self, request: AuthenticatedHttpRequest, *args, **kwargs) -> HttpRespon ) tasks = job.tasks.filter(status=AnalysisTask.Status.PENDING) - tasks.update(status=AnalysisTask.Status.CANCELED) for task in tasks.only("celery_task_id"): - if task.celery_task_id: - # Cave, can only revoke tasks that are not already fetched by a worker. - # So the worker will check again each task if it was cancelled. 
- celery_app.control.revoke(task.celery_task_id) + queued_job_id = task.queued_job_id + if queued_job_id is not None: + app.job_manager.cancel_job_by_id(queued_job_id, delete_job=True) + tasks.update(status=AnalysisTask.Status.CANCELED) if job.tasks.filter(status=AnalysisTask.Status.IN_PROGRESS).exists(): job.status = AnalysisJob.Status.CANCELING @@ -384,11 +381,3 @@ def post(self, request: AuthenticatedHttpRequest, *args, **kwargs) -> HttpRespon messages.success(request, self.success_message % task.__dict__) return redirect(task) - - -class RabbitManagementProxyView(AdminProxyView): - upstream = ( - f"http://{settings.RABBIT_MANAGEMENT_HOST}:" f"{settings.RABBIT_MANAGEMENT_PORT}" # type: ignore - ) - url_prefix = "rabbit" - rewrite = ((rf"^/{url_prefix}$", r"/"),) diff --git a/radis/rag/migrations/0012_switch_to_procrastinate.py b/radis/rag/migrations/0012_switch_to_procrastinate.py new file mode 100644 index 00000000..fc37a922 --- /dev/null +++ b/radis/rag/migrations/0012_switch_to_procrastinate.py @@ -0,0 +1,44 @@ +# Generated by Django 5.0.7 on 2024-07-21 12:54 + +import django.db.models.deletion +from django.db import migrations, models + +from adit_radis_shared.common.utils.migration_utils import procrastinate_on_delete_sql + + +class Migration(migrations.Migration): + + dependencies = [ + ('procrastinate', '0028_add_cancel_states'), + ('rag', '0011_remove_questionresult_task_and_more'), + ] + + operations = [ + migrations.RenameField( + model_name='ragtask', + old_name='retries', + new_name='attempts', + ), + migrations.RemoveField( + model_name='ragtask', + name='celery_task_id', + ), + migrations.AddField( + model_name='ragjob', + name='queued_job', + field=models.OneToOneField(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='procrastinate.procrastinatejob'), + ), + migrations.AddField( + model_name='ragtask', + name='queued_job', + field=models.OneToOneField(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='procrastinate.procrastinatejob'), + ), + migrations.RunSQL( + sql=procrastinate_on_delete_sql("rag", "ragjob"), + reverse_sql=procrastinate_on_delete_sql("rag", "ragjob", reverse=True), + ), + migrations.RunSQL( + sql=procrastinate_on_delete_sql("rag", "ragtask"), + reverse_sql=procrastinate_on_delete_sql("rag", "ragtask", reverse=True), + ), + ] diff --git a/radis/rag/models.py b/radis/rag/models.py index 88181de8..e8dbdd74 100644 --- a/radis/rag/models.py +++ b/radis/rag/models.py @@ -1,11 +1,12 @@ from typing import TYPE_CHECKING, Callable from adit_radis_shared.common.models import AppSettings -from celery import current_app from django.conf import settings from django.contrib.auth.models import Group from django.db import models from django.urls import reverse +from procrastinate.contrib.django import app +from procrastinate.contrib.django.models import ProcrastinateJob from radis.core.models import AnalysisJob, AnalysisTask from radis.reports.models import Language, Modality, Report @@ -25,9 +26,11 @@ class RagJob(AnalysisJob): urgent_priority = settings.RAG_URGENT_PRIORITY continuous_job = False + queued_job_id: int | None + queued_job = models.OneToOneField( + ProcrastinateJob, null=True, on_delete=models.SET_NULL, related_name="+" + ) title = models.CharField(max_length=100) - "The title of the job that is shown in the job list" - provider = models.CharField(max_length=100) group = models.ForeignKey(Group, on_delete=models.CASCADE) query = models.CharField(max_length=200) @@ -53,7 +56,13 @@ def 
get_absolute_url(self) -> str: return reverse("rag_job_detail", args=[self.id]) def delay(self) -> None: - current_app.send_task("radis.rag.tasks.ProcessRagJob", args=[self.id]) + queued_job_id = app.configure_task( + "radis.rag.tasks.process_rag_job", + allow_unknown=False, + priority=self.urgent_priority if self.urgent else self.default_priority, + ).defer(job_id=self.id) + self.queued_job_id = queued_job_id + self.save() class Answer(models.TextChoices): @@ -82,7 +91,13 @@ def get_absolute_url(self) -> str: return reverse("rag_task_detail", args=[self.id]) def delay(self) -> None: - current_app.send_task("radis.rag.tasks.ProcessRagTask", args=[self.id]) + queued_job_id = app.configure_task( + "radis.rag.tasks.process_rag_task", + allow_unknown=False, + priority=self.job.urgent_priority if self.job.urgent else self.job.default_priority, + ).defer(task_id=self.id) + self.queued_job_id = queued_job_id + self.save() class RagInstance(models.Model): diff --git a/radis/rag/processors.py b/radis/rag/processors.py new file mode 100644 index 00000000..84f20606 --- /dev/null +++ b/radis/rag/processors.py @@ -0,0 +1,120 @@ +import asyncio +import logging +from asyncio import Semaphore + +from channels.db import database_sync_to_async +from django import db +from django.conf import settings +from django.db.models.query import QuerySet +from pebble import concurrent + +from radis.core.processors import AnalysisTaskProcessor +from radis.core.utils.chat_client import AsyncChatClient +from radis.reports.models import Report + +from .models import Answer, Question, QuestionResult, RagInstance, RagTask + +logger = logging.getLogger(__name__) + + +class RagTaskProcessor(AnalysisTaskProcessor): + def process_task(self, task: RagTask) -> None: + future = self.process_task_in_thread(task) + future.result() + + @concurrent.thread + def process_task_in_thread(self, task: RagTask) -> None: + asyncio.run(self.process_rag_task(task)) + + async def process_rag_task(self, task: RagTask) -> None: + client = AsyncChatClient() + sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT) + + await asyncio.gather( + *[ + self.process_rag_instance(rag_instance, client, sem) + async for rag_instance in task.rag_instances.prefetch_related("reports") + ] + ) + await database_sync_to_async(db.close_old_connections)() + + async def combine_reports(self, reports: QuerySet[Report]) -> Report: + count = await reports.acount() + if count > 1: + raise ValueError("Multiple reports is not yet supported") + + report = await reports.afirst() + if report is None: + raise ValueError("No reports to combine") + + return report + + async def process_yes_or_no_question( + self, + rag_instance: RagInstance, + body: str, + language: str, + question: Question, + client: AsyncChatClient, + ) -> RagInstance.Result: + llm_answer = await client.ask_yes_no_question(body, language, question.question) + + if llm_answer == "yes": + answer = Answer.YES + elif llm_answer == "no": + answer = Answer.NO + else: + raise ValueError(f"Unexpected answer: {llm_answer}") + + result = ( + RagInstance.Result.ACCEPTED + if question.accepted_answer == answer + else RagInstance.Result.REJECTED + ) + + await QuestionResult.objects.aupdate_or_create( + rag_instance=rag_instance, + question=question, + defaults={ + "original_answer": answer, + "current_answer": answer, + "result": result, + }, + ) + + logger.debug("RAG result for question %s: %s", question, answer) + + return result + + async def process_rag_instance( + self, rag_instance: RagInstance, client: AsyncChatClient, 
sem: Semaphore + ) -> None: + report = await self.combine_reports(rag_instance.reports.prefetch_related("language")) + language = report.language + + if language.code not in settings.SUPPORTED_LANGUAGES: + raise ValueError(f"Language '{language}' is not supported.") + + async with sem: + results = await asyncio.gather( + *[ + self.process_yes_or_no_question( + rag_instance, report.body, language.code, question, client + ) + async for question in rag_instance.task.job.questions.all() + ] + ) + + if all([result == RagInstance.Result.ACCEPTED for result in results]): + overall_result = RagInstance.Result.ACCEPTED + else: + overall_result = RagInstance.Result.REJECTED + + rag_instance.overall_result = overall_result + await rag_instance.asave() + + logger.info( + "Overall RAG result for for report %s: %s", + rag_instance, + rag_instance.get_overall_result_display(), + ) diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py index d3126547..1f0029aa 100644 --- a/radis/rag/tasks.py +++ b/radis/rag/tasks.py @@ -1,192 +1,74 @@ -import asyncio import logging -from asyncio import Semaphore from itertools import batched -from typing import Iterator, override -from channels.db import database_sync_to_async -from django import db from django.conf import settings -from django.db.models.query import QuerySet +from procrastinate.contrib.django import app -from radis.celery import app as celery_app -from radis.core.tasks import ProcessAnalysisJob, ProcessAnalysisTask -from radis.core.utils.chat_client import AsyncChatClient from radis.reports.models import Report from radis.search.site import Search, SearchFilters from radis.search.utils.query_parser import QueryParser -from .models import Answer, Question, QuestionResult, RagInstance, RagJob, RagTask +from .models import RagInstance, RagJob, RagTask +from .processors import RagTaskProcessor from .site import retrieval_providers logger = logging.getLogger(__name__) -class ProcessRagTask(ProcessAnalysisTask): - analysis_task_class = RagTask - - def __init__(self) -> None: - super().__init__() - - @override - def process_task(self, task: RagTask) -> None: - asyncio.run(self.process_rag_task(task)) - - async def process_rag_task(self, task: RagTask) -> None: - client = AsyncChatClient() - sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT) - - await asyncio.gather( - *[ - self.process_rag_instance(rag_instance, client, sem) - async for rag_instance in task.rag_instances.prefetch_related("reports") - ] - ) - await database_sync_to_async(db.close_old_connections)() - - async def process_rag_instance( - self, rag_instance: RagInstance, client: AsyncChatClient, sem: Semaphore - ) -> None: - report = await self.combine_reports(rag_instance.reports.prefetch_related("language")) - language = report.language - - if language.code not in settings.SUPPORTED_LANGUAGES: - raise ValueError(f"Language '{language}' is not supported.") - - async with sem: - results = await asyncio.gather( - *[ - self.process_yes_or_no_question( - rag_instance, report.body, language.code, question, client - ) - async for question in rag_instance.task.job.questions.all() - ] - ) - - if all([result == RagInstance.Result.ACCEPTED for result in results]): - overall_result = RagInstance.Result.ACCEPTED - else: - overall_result = RagInstance.Result.REJECTED - - rag_instance.overall_result = overall_result - await rag_instance.asave() - - logger.info( - "Overall RAG result for for report %s: %s", - rag_instance, - rag_instance.get_overall_result_display(), - ) - - async def process_yes_or_no_question( 
- self, - rag_instance: RagInstance, - body: str, - language: str, - question: Question, - client: AsyncChatClient, - ) -> RagInstance.Result: - llm_answer = await client.ask_yes_no_question(body, language, question.question) - - if llm_answer == "yes": - answer = Answer.YES - elif llm_answer == "no": - answer = Answer.NO - else: - raise ValueError(f"Unexpected answer: {llm_answer}") - - result = ( - RagInstance.Result.ACCEPTED - if question.accepted_answer == answer - else RagInstance.Result.REJECTED - ) - - await QuestionResult.objects.aupdate_or_create( - rag_instance=rag_instance, - question=question, - defaults={ - "original_answer": answer, - "current_answer": answer, - "result": result, - }, - ) - - logger.debug("RAG result for question %s: %s", question, answer) - - return result - - async def combine_reports(self, reports: QuerySet[Report]) -> Report: - count = await reports.acount() - if count > 1: - raise ValueError("Multiple reports is not yet supported") - - report = await reports.afirst() - if report is None: - raise ValueError("No reports to combine") - - return report - - -process_rag_task = ProcessRagTask() - - -celery_app.register_task(process_rag_task) - - -class ProcessRagJob(ProcessAnalysisJob): - analysis_job_class = RagJob - process_analysis_task = process_rag_task - task_queue = "llm_queue" - - @override - def collect_tasks(self, job: RagJob) -> Iterator[RagTask]: - patient_sex = None - if job.patient_sex == "M": - patient_sex = "M" - elif job.patient_sex == "F": - patient_sex = "F" - - provider = job.provider - retrieval_provider = retrieval_providers[provider] - - query_node, fixes = QueryParser().parse(job.query) - - if query_node is None: - raise ValueError(f"Not a valid query (evaluated as empty): {job.query}") - - if len(fixes) > 0: - logger.info(f"The following fixes were applied to the query:\n{"\n - ".join(fixes)}") - - search = Search( - query=query_node, - offset=0, - limit=retrieval_provider.max_results, - filters=SearchFilters( - group=job.group.pk, - language=job.language.code, - modalities=list(job.modalities.values_list("code", flat=True)), - study_date_from=job.study_date_from, - study_date_till=job.study_date_till, - study_description=job.study_description, - patient_sex=patient_sex, - patient_age_from=job.age_from, - patient_age_till=job.age_till, - ), - ) - - logger.debug("Searching reports for task with search: %s", search) - - for document_ids in batched( - retrieval_provider.retrieve(search), settings.RAG_TASK_BATCH_SIZE - ): - logger.debug("Creating RAG task for document IDs: %s", document_ids) - task = RagTask.objects.create(job=job) - for document_id in document_ids: - rag_instance = RagInstance.objects.create(task=task) - rag_instance.reports.add(Report.objects.get(document_id=document_id)) +@app.task(queue="llm") +def process_rag_task(task_id: int) -> None: + task = RagTask.objects.get(id=task_id) + processor = RagTaskProcessor() + processor.start(task) - yield task +@app.task +def process_rag_job(job_id: int) -> None: + job = RagJob.objects.get(id=job_id) -process_rag_job = ProcessRagJob() + logger.info("Start processing job %s", job) + assert job.status == RagJob.Status.PREPARING -celery_app.register_task(process_rag_job) + provider = job.provider + retrieval_provider = retrieval_providers[provider] + + logger.debug("Collecting tasks for job %s", job) + + query_node, fixes = QueryParser().parse(job.query) + + if query_node is None: + raise ValueError(f"Not a valid query (evaluated as empty): {job.query}") + + if len(fixes) > 0: + 
logger.info(f"The following fixes were applied to the query:\n{"\n - ".join(fixes)}") + + search = Search( + query=query_node, + offset=0, + limit=retrieval_provider.max_results, + filters=SearchFilters( + group=job.group.pk, + language=job.language.code, + modalities=list(job.modalities.values_list("code", flat=True)), + study_date_from=job.study_date_from, + study_date_till=job.study_date_till, + study_description=job.study_description, + patient_sex=job.patient_sex, # type: ignore + patient_age_from=job.age_from, + patient_age_till=job.age_till, + ), + ) + + logger.debug("Searching reports for task with search: %s", search) + + for document_ids in batched(retrieval_provider.retrieve(search), settings.RAG_TASK_BATCH_SIZE): + logger.debug("Creating RAG task for document IDs: %s", document_ids) + task = RagTask.objects.create(job=job, status=RagTask.Status.PENDING) + for document_id in document_ids: + rag_instance = RagInstance.objects.create(task=task) + rag_instance.reports.add(Report.objects.get(document_id=document_id)) + + task.delay() + + job.status = RagJob.Status.PENDING + job.save() diff --git a/radis/rag/templates/rag/rag_task_detail.html b/radis/rag/templates/rag/rag_task_detail.html index 1a8edb3b..b8ea4c6a 100644 --- a/radis/rag/templates/rag/rag_task_detail.html +++ b/radis/rag/templates/rag/rag_task_detail.html @@ -56,10 +56,6 @@

    RAG Task

    {{ task.ended_at|default:"—" }}
    {% if user.is_staff %}
-    Celery Task ID
-
-    {{ task.celery_task_id|default:"—" }}
-
    Log
    {{ task.log|default:"—" }}
    diff --git a/radis/rag/tests/unit/test_task.py b/radis/rag/tests/unit/test_task.py index d9d85618..dd2c3354 100644 --- a/radis/rag/tests/unit/test_task.py +++ b/radis/rag/tests/unit/test_task.py @@ -3,11 +3,11 @@ import pytest from radis.rag.models import Answer, RagInstance -from radis.rag.tasks import ProcessRagTask +from radis.rag.processors import RagTaskProcessor @pytest.mark.django_db(transaction=True) -def test_process_rag_task(create_rag_task, openai_chat_completions_mock, mocker): +def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker): num_rag_instances = 5 num_questions = 5 rag_task = create_rag_task( @@ -18,12 +18,12 @@ def test_process_rag_task(create_rag_task, openai_chat_completions_mock, mocker) ) openai_mock = openai_chat_completions_mock("Yes") - process_rag_task_spy = mocker.spy(ProcessRagTask, "process_rag_task") - process_rag_instance_spy = mocker.spy(ProcessRagTask, "process_rag_instance") - process_yes_or_no_question_spy = mocker.spy(ProcessRagTask, "process_yes_or_no_question") + process_rag_task_spy = mocker.spy(RagTaskProcessor, "process_rag_task") + process_rag_instance_spy = mocker.spy(RagTaskProcessor, "process_rag_instance") + process_yes_or_no_question_spy = mocker.spy(RagTaskProcessor, "process_yes_or_no_question") with patch("openai.AsyncOpenAI", return_value=openai_mock): - ProcessRagTask().process_task(rag_task) + RagTaskProcessor().start(rag_task) rag_instances = rag_task.rag_instances.all() for instance in rag_instances: diff --git a/radis/settings/base.py b/radis/settings/base.py index f430bebd..fa77854f 100644 --- a/radis/settings/base.py +++ b/radis/settings/base.py @@ -20,6 +20,9 @@ # The base directory of the project (the root of the repository) BASE_DIR = Path(__file__).resolve(strict=True).parent.parent.parent +# Used to monitor for autoreload +SOURCE_FOLDERS = [BASE_DIR / "radis"] + # Read pyproject.toml to fetch current version. We do this conditionally as the # RADIS client library uses RADIS for integration tests installed as a package # (where no pyproject.toml is available). @@ -57,6 +60,7 @@ "django.contrib.humanize", "django.contrib.postgres", "django_extensions", + "procrastinate.contrib.django", "dbbackup", "revproxy", "loginas", @@ -198,11 +202,6 @@ "level": "INFO", "propagate": False, }, - "celery": { - "handlers": ["console", "mail_admins"], - "level": "INFO", - "propagate": False, - }, "django": { "handlers": ["console"], "level": "WARNING", @@ -298,46 +297,6 @@ # Channels ASGI_APPLICATION = "radis.asgi.application" -# RabbitMQ is used as Celery message broker -RABBITMQ_URL = env.str("RABBITMQ_URL", default="amqp://localhost") # type: ignore - -# Rabbit Management console is integrated in RADIS by using a reverse -# proxy (django-revproxy).This allows us to use the authentication of RADIS. -# But as RabbitMQ authentication can't be disabled, so we have to login -# there with "guest" as username and password again. -RABBIT_MANAGEMENT_HOST = env.str("RABBIT_MANAGEMENT_HOST", default="localhost") # type: ignore -RABBIT_MANAGEMENT_PORT = env.int("RABBIT_MANAGEMENT_PORT", default=15672) # type: ignore - -# Celery -# see https://github.com/celery/celery/issues/5026 for how to name configs -if USE_TZ: - CELERY_TIMEZONE = TIME_ZONE -CELERY_BROKER_URL = RABBITMQ_URL -CELERY_WORKER_HIJACK_ROOT_LOGGER = False -CELERY_IGNORE_RESULT = True -CELERY_TASK_DEFAULT_QUEUE = "default_queue" -CELERY_BEAT_SCHEDULE = {} - -# Settings for priority queues, see also apply_async calls in the models. 
-# Requires RabbitMQ as the message broker! -CELERY_TASK_QUEUE_MAX_PRIORITY = 10 -CELERY_TASK_DEFAULT_PRIORITY = 5 - -# Only non prefetched tasks can be sorted by their priority. So we prefetch -# only one task for each availalbe child process. The number of child processes -# can be set with the -c parameter when starting the worker. -CELERY_WORKER_PREFETCH_MULTIPLIER = 1 - -# Only acknowledge the Celery task when it was finished by the worker. -# If the worker crashed while executing the task it will be re-executed -# when the worker is up again -CELERY_TASK_ACKS_LATE = True - -# Flower is integrated in RADIS by using a reverse proxy (django-revproxy). -# This allows to use the authentication of RADIS. -FLOWER_HOST = env.str("FLOWER_HOST", default="localhost") # type: ignore -FLOWER_PORT = env.int("FLOWER_PORT", default=5555) # type: ignore - # Used by django-filter FILTERS_EMPTY_CHOICE_LABEL = "Show All" diff --git a/radis/settings/development.py b/radis/settings/development.py index 74662074..c3de5c87 100644 --- a/radis/settings/development.py +++ b/radis/settings/development.py @@ -40,9 +40,6 @@ DEBUG_TOOLBAR_CONFIG = {"SHOW_TOOLBAR_CALLBACK": lambda request: settings.DEBUG} -CELERY_TASK_ALWAYS_EAGER = False -CELERY_TASK_EAGER_PROPAGATES = False - LOGGING["loggers"]["radis"]["level"] = "DEBUG" # noqa: F405 INTERNAL_IPS = env.list("DJANGO_INTERNAL_IPS", default=["127.0.0.1"]) # type: ignore diff --git a/radis/settings/production.py b/radis/settings/production.py index abf6adf8..92fb51c8 100644 --- a/radis/settings/production.py +++ b/radis/settings/production.py @@ -1,5 +1,3 @@ -from celery.schedules import crontab - from .base import * # noqa: F403 from .base import env @@ -23,8 +21,3 @@ EMAIL_HOST_USER = env.str("DJANGO_EMAIL_HOST_USER", default="") # type: ignore EMAIL_HOST_PASSWORD = env.str("DJANGO_EMAIL_HOST_PASSWORD", default="") # type: ignore EMAIL_USE_TLS = env.bool("DJANGO_EMAIL_USE_TLS", default=False) # type: ignore - -CELERY_BEAT_SCHEDULE["backup-db"] = { # noqa: F405 - "task": "radis.core.tasks.backup_db", - "schedule": crontab(minute=0, hour=3), # execute daily at 3 o'clock UTC -} diff --git a/radis/settings/test.py b/radis/settings/test.py index 70ec6344..5668b14b 100644 --- a/radis/settings/test.py +++ b/radis/settings/test.py @@ -1,16 +1,10 @@ from .development import * # noqa: F403 -# We must force the Celery test worker (in a subprocess started inside a test) to -# use the test database. +# We must force our background worker that is started while testing +# in a subprocess to use the test database. if not DATABASES["default"]["NAME"].startswith("test_"): # noqa: F405 test_database = "test_" + DATABASES["default"]["NAME"] # noqa: F405 DATABASES["default"]["NAME"] = test_database # noqa: F405 DATABASES["default"]["TEST"] = {"NAME": test_database} # noqa: F405 -# This test worker uses a "test_queue" (see radis_celery_worker fixture). In contrast -# to development and production system we only use one worker that handles all -# Celery tasks. 
-CELERY_TASK_DEFAULT_QUEUE = "test_queue" -CELERY_TASK_ROUTES = {} - DEBUG_TOOLBAR_CONFIG = {"SHOW_TOOLBAR_CALLBACK": lambda request: False} diff --git a/tasks.py b/tasks.py index 42def502..4b6b3c9f 100644 --- a/tasks.py +++ b/tasks.py @@ -393,20 +393,6 @@ def try_github_actions(ctx: Context): ctx.run(f"{act_path} -P ubuntu-latest=catthehacker/ubuntu:act-latest", pty=True) -@task -def purge_celery( - ctx: Context, - env: Environments = "dev", - queues: str = "default_queue", - force=False, -): - """Purge Celery queues""" - cmd = f"{build_compose_cmd(env)} exec web celery -A radis purge -Q {queues}" - if force: - cmd += " -f" - ctx.run(cmd, pty=True) - - @task def backup_db(ctx: Context, env: Environments = "prod"): """Backup database From 1c7b5daee3f5ff13caf874ceec302826a335db05 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Sun, 21 Jul 2024 15:15:17 +0000 Subject: [PATCH 2/8] Rename web folder --- Dockerfile | 12 ++++++------ compose/docker-compose.base.yml | 6 +++--- compose/docker-compose.prod.yml | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Dockerfile b/Dockerfile index 17cc44e3..c3c50d24 100644 --- a/Dockerfile +++ b/Dockerfile @@ -72,9 +72,9 @@ RUN poetry install RUN playwright install --with-deps chromium # Required folders for RADIS -RUN mkdir -p /var/www/radis/logs \ - /var/www/radis/static \ - /var/www/radis/ssl +RUN mkdir -p /var/www/web/logs \ + /var/www/web/static \ + /var/www/web/ssl # will become mountpoint of our code WORKDIR /app @@ -86,8 +86,8 @@ COPY --from=builder-base $PYSETUP_PATH $PYSETUP_PATH COPY . /app/ # Required folders for RADIS -RUN mkdir -p /var/www/radis/logs \ - /var/www/radis/static \ - /var/www/radis/ssl +RUN mkdir -p /var/www/web/logs \ + /var/www/web/static \ + /var/www/web/ssl WORKDIR /app diff --git a/compose/docker-compose.base.yml b/compose/docker-compose.base.yml index 0efa8201..25e5820a 100644 --- a/compose/docker-compose.base.yml +++ b/compose/docker-compose.base.yml @@ -5,7 +5,7 @@ x-app: &default-app - postgres environment: USE_DOCKER: 1 - DJANGO_STATIC_ROOT: "/var/www/radis/static/" + DJANGO_STATIC_ROOT: "/var/www/web/static/" DATABASE_URL: "psql://postgres:postgres@postgres.local:5432/postgres" LLAMACPP_URL: "http://llamacpp.local:8080" @@ -14,7 +14,7 @@ services: <<: *default-app hostname: init.local volumes: - - web_data:/var/www/radis + - web_data:/var/www/web - /mnt:/mnt web: @@ -23,7 +23,7 @@ services: build: context: .. volumes: - - web_data:/var/www/radis + - web_data:/var/www/web - /mnt:/mnt worker_default: diff --git a/compose/docker-compose.prod.yml b/compose/docker-compose.prod.yml index af6cce4c..401b36eb 100644 --- a/compose/docker-compose.prod.yml +++ b/compose/docker-compose.prod.yml @@ -5,8 +5,8 @@ x-app: &default-app environment: ENABLE_REMOTE_DEBUGGING: 0 DJANGO_SETTINGS_MODULE: "radis.settings.production" - SSL_CERT_FILE: "/var/www/radis/ssl/cert.pem" - SSL_KEY_FILE: "/var/www/radis/ssl/key.pem" + SSL_CERT_FILE: "/var/www/web/ssl/cert.pem" + SSL_KEY_FILE: "/var/www/web/ssl/key.pem" x-deploy: &deploy replicas: 1 @@ -44,7 +44,7 @@ services: bash -c " wait-for-it -s init.local:8000 -t 300 && echo 'Starting web server ...' 
- daphne -b 0.0.0.0 -p 80 -e ssl:443:privateKey=/var/www/radis/ssl/key.pem:certKey=/var/www/radis/ssl/cert.pem radis.asgi:application + daphne -b 0.0.0.0 -p 80 -e ssl:443:privateKey=/var/www/web/ssl/key.pem:certKey=/var/www/web/ssl/cert.pem radis.asgi:application " deploy: <<: *deploy From 04e23b21ce7e8c912b6e3036fb7b11782e15411c Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Mon, 22 Jul 2024 22:24:07 +0000 Subject: [PATCH 3/8] Use attempts instead of retries --- radis/rag/templates/rag/rag_task_detail.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/radis/rag/templates/rag/rag_task_detail.html b/radis/rag/templates/rag/rag_task_detail.html index b8ea4c6a..4810cdff 100644 --- a/radis/rag/templates/rag/rag_task_detail.html +++ b/radis/rag/templates/rag/rag_task_detail.html @@ -29,9 +29,9 @@

    RAG Task

    {{ task.get_status_display }}
-    Retries
+    Attempts
-    {{ task.retries|default:"—" }}
+    {{ task.attempts|default:"—" }}
    Message
    From 1b263652edf63d46f99ac0eed0dbf5ece3aa395e Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Mon, 22 Jul 2024 22:36:21 +0000 Subject: [PATCH 4/8] Fix RAG pipeline --- radis/core/processors.py | 19 ++++++---- radis/rag/processors.py | 77 ++++++++++++++++++---------------------- radis/rag/tasks.py | 8 ++--- 3 files changed, 50 insertions(+), 54 deletions(-) diff --git a/radis/core/processors.py b/radis/core/processors.py index 76ff6879..99331c17 100644 --- a/radis/core/processors.py +++ b/radis/core/processors.py @@ -1,6 +1,7 @@ import logging import traceback +from channels.db import database_sync_to_async from django.utils import timezone from .models import AnalysisJob, AnalysisTask @@ -9,7 +10,11 @@ class AnalysisTaskProcessor: - def start(self, task: AnalysisTask) -> None: + def __init__(self, task: AnalysisTask) -> None: + self.task = task + + async def start(self) -> None: + task = self.task job = task.job logger.info("Start processing task %s", task) @@ -37,17 +42,17 @@ def start(self, task: AnalysisTask) -> None: if job.status == job.Status.PENDING: job.status = job.Status.IN_PROGRESS job.started_at = timezone.now() - job.save() + await job.asave() assert job.status == job.Status.IN_PROGRESS # Prepare the task itself task.status = AnalysisTask.Status.IN_PROGRESS task.started_at = timezone.now() - task.save() + await task.asave() try: - self.process_task(task) + await self.process_task(task) # If the overwritten process_task method changes the status of the # task itself then we leave it as it is. Otherwise if the status is @@ -65,9 +70,9 @@ def start(self, task: AnalysisTask) -> None: finally: logger.info("Task %s ended", task) task.ended_at = timezone.now() - task.save() - job.update_job_state() + await task.asave() + await database_sync_to_async(job.update_job_state)() - def process_task(self, task: AnalysisTask) -> None: + async def process_task(self, task: AnalysisTask) -> None: """The derived class should process the task here.""" ... 
diff --git a/radis/rag/processors.py b/radis/rag/processors.py index 84f20606..14fff2c6 100644 --- a/radis/rag/processors.py +++ b/radis/rag/processors.py @@ -6,7 +6,6 @@ from django import db from django.conf import settings from django.db.models.query import QuerySet -from pebble import concurrent from radis.core.processors import AnalysisTaskProcessor from radis.core.utils.chat_client import AsyncChatClient @@ -18,15 +17,7 @@ class RagTaskProcessor(AnalysisTaskProcessor): - def process_task(self, task: RagTask) -> None: - future = self.process_task_in_thread(task) - future.result() - - @concurrent.thread - def process_task_in_thread(self, task: RagTask) -> None: - asyncio.run(self.process_rag_task(task)) - - async def process_rag_task(self, task: RagTask) -> None: + async def process_task(self, task: RagTask) -> None: client = AsyncChatClient() sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT) @@ -38,6 +29,39 @@ async def process_rag_task(self, task: RagTask) -> None: ) await database_sync_to_async(db.close_old_connections)() + async def process_rag_instance( + self, rag_instance: RagInstance, client: AsyncChatClient, sem: Semaphore + ) -> None: + report = await self.combine_reports(rag_instance.reports.prefetch_related("language")) + language = report.language + + if language.code not in settings.SUPPORTED_LANGUAGES: + raise ValueError(f"Language '{language}' is not supported.") + + async with sem: + results = await asyncio.gather( + *[ + self.process_yes_or_no_question( + rag_instance, report.body, language.code, question, client + ) + async for question in rag_instance.task.job.questions.all() + ] + ) + + if all([result == RagInstance.Result.ACCEPTED for result in results]): + overall_result = RagInstance.Result.ACCEPTED + else: + overall_result = RagInstance.Result.REJECTED + + rag_instance.overall_result = overall_result + await rag_instance.asave() + + logger.info( + "Overall RAG result for for report %s: %s", + rag_instance, + rag_instance.get_overall_result_display(), + ) + async def combine_reports(self, reports: QuerySet[Report]) -> Report: count = await reports.acount() if count > 1: @@ -85,36 +109,3 @@ async def process_yes_or_no_question( logger.debug("RAG result for question %s: %s", question, answer) return result - - async def process_rag_instance( - self, rag_instance: RagInstance, client: AsyncChatClient, sem: Semaphore - ) -> None: - report = await self.combine_reports(rag_instance.reports.prefetch_related("language")) - language = report.language - - if language.code not in settings.SUPPORTED_LANGUAGES: - raise ValueError(f"Language '{language}' is not supported.") - - async with sem: - results = await asyncio.gather( - *[ - self.process_yes_or_no_question( - rag_instance, report.body, language.code, question, client - ) - async for question in rag_instance.task.job.questions.all() - ] - ) - - if all([result == RagInstance.Result.ACCEPTED for result in results]): - overall_result = RagInstance.Result.ACCEPTED - else: - overall_result = RagInstance.Result.REJECTED - - rag_instance.overall_result = overall_result - await rag_instance.asave() - - logger.info( - "Overall RAG result for for report %s: %s", - rag_instance, - rag_instance.get_overall_result_display(), - ) diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py index 1f0029aa..be8e98e1 100644 --- a/radis/rag/tasks.py +++ b/radis/rag/tasks.py @@ -16,10 +16,10 @@ @app.task(queue="llm") -def process_rag_task(task_id: int) -> None: - task = RagTask.objects.get(id=task_id) - processor = RagTaskProcessor() - 
processor.start(task) +async def process_rag_task(task_id: int) -> None: + task = await RagTask.objects.prefetch_related("job").aget(id=task_id) + processor = RagTaskProcessor(task) + await processor.start() @app.task From 1ffa0c18d42dacd79517c64b8d84fddeac1b24c9 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Tue, 23 Jul 2024 20:14:50 +0000 Subject: [PATCH 5/8] Cleanup pyinvoke tasks --- poetry.lock | 8 +- pyproject.toml | 2 +- radis/conftest.py | 3 +- tasks.py | 444 +++------------------------------------------- 4 files changed, 35 insertions(+), 422 deletions(-) diff --git a/poetry.lock b/poetry.lock index 84343667..7ea70183 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,7 +2,7 @@ [[package]] name = "adit-radis-shared" -version = "0.5.0" +version = "0.6.1" description = "Shared Django apps between ADIT and RADIS" optional = false python-versions = "^3.11" @@ -37,8 +37,8 @@ whitenoise = "^6.0.0" [package.source] type = "git" url = "https://github.com/openradx/adit-radis-shared.git" -reference = "v0.5.0" -resolved_reference = "303374246d1af2e1d66957a13a874a4195ed3877" +reference = "v0.6.1" +resolved_reference = "f6454b54b5233fba5dba82e53d1ed429fa21ee69" [[package]] name = "adrf" @@ -3555,4 +3555,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<4.0" -content-hash = "50fd8cdea72e1e9c433b305bddf1b0e6426f5783c4949df6f4dc80c42ead0160" +content-hash = "0a50389203dfa6ef8ba40e8ada2a79a31a12fe29cc74395bef94b5937b459c8e" diff --git a/pyproject.toml b/pyproject.toml index 6675657a..8bf5ad13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["medihack "] license = "GPL-3.0-or-later" [tool.poetry.dependencies] -adit-radis-shared = { git = "https://github.com/openradx/adit-radis-shared.git", tag = "v0.5.0" } +adit-radis-shared = {git = "https://github.com/openradx/adit-radis-shared.git", rev = "v0.6.1"} adrf = "^0.1.4" aiofiles = "^23.1.0" asyncinotify = "^4.0.1" diff --git a/radis/conftest.py b/radis/conftest.py index 64c56afb..600eece8 100644 --- a/radis/conftest.py +++ b/radis/conftest.py @@ -1,5 +1,6 @@ import nest_asyncio -from adit_radis_shared.conftest import * # noqa: F403 + +pytest_plugins = ["adit_radis_shared.pytest_fixtures"] def pytest_configure(): diff --git a/tasks.py b/tasks.py index 4b6b3c9f..2b3ca28b 100644 --- a/tasks.py +++ b/tasks.py @@ -1,439 +1,51 @@ -import os -import shutil -import sys -from os import environ from pathlib import Path -from typing import Literal -from dotenv import set_key +from adit_radis_shared import invoke_tasks +from adit_radis_shared.invoke_tasks import ( # noqa: F401 + backup_db, + bump_version, + format, + init_workspace, + lint, + reset_dev, + restore_db, + show_outdated, + stack_deploy, + stack_rm, + test, + try_github_actions, + upgrade_adit_radis_shared, + upgrade_postgresql, + web_shell, +) from invoke.context import Context from invoke.tasks import task -Environments = Literal["dev", "prod"] - -stack_name_dev = "radis_dev" -stack_name_prod = "radis_prod" - -postgres_dev_volume = f"{stack_name_dev}_postgres_data" -postgres_prod_volume = f"{stack_name_prod}_postgres_data" - -project_dir = Path(__file__).resolve().parent -compose_dir = project_dir / "compose" -models_dir = project_dir / "models" - -compose_file_base = compose_dir / "docker-compose.base.yml" -compose_file_dev = compose_dir / "docker-compose.dev.yml" -compose_file_prod = compose_dir / "docker-compose.prod.yml" - -### -# Helper functions -### - - -def get_stack_name(env: 
Environments): - if env == "dev": - return stack_name_dev - elif env == "prod": - return stack_name_prod - else: - raise ValueError(f"Unknown environment: {env}") - - -def get_postgres_volume(env: Environments): - if env == "dev": - return postgres_dev_volume - elif env == "prod": - return postgres_prod_volume - else: - raise ValueError(f"Unknown environment: {env}") - - -def build_compose_cmd(env: Environments): - base_compose_cmd = f"docker compose -f '{compose_file_base}'" - stack_name = get_stack_name(env) - if env == "dev": - return f"{base_compose_cmd} -f '{compose_file_dev}' -p {stack_name}" - elif env == "prod": - return f"{base_compose_cmd} -f '{compose_file_prod}' -p {stack_name}" - else: - raise ValueError(f"Unknown environment: {env}") - - -def check_compose_up(ctx: Context, env: Environments): - stack_name = get_stack_name(env) - result = ctx.run("docker compose ls", hide=True, warn=True) - assert result and result.ok - for line in result.stdout.splitlines(): - if line.startswith(stack_name) and line.find("running") != -1: - return True - return False - - -def find_running_container_id(ctx: Context, env: Environments, name: str): - stack_name = get_stack_name(env) - sep = "-" if env == "dev" else "_" - cmd = f"docker ps -q -f name={stack_name}{sep}{name} -f status=running" - cmd += " | head -n1" - result = ctx.run(cmd, hide=True, warn=True) - if result and result.ok: - container_id = result.stdout.strip() - if container_id: - return container_id - return None - - -def confirm(question: str) -> bool: - valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} - while True: - sys.stdout.write(f"{question} [y/N] ") - choice = input().lower() - if choice == "": - return False - elif choice in valid: - return valid[choice] - else: - sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n") - - -### -# Tasks -### - - -@task -def compose_build(ctx: Context, env: Environments = "dev"): - """Build RADIS image for specified environment""" - cmd = f"{build_compose_cmd(env)} build" - ctx.run(cmd, pty=True) +invoke_tasks.PROJECT_NAME = "radis" +invoke_tasks.PROJECT_DIR = Path(__file__).resolve().parent @task def compose_up( ctx: Context, - env: Environments = "dev", + env: invoke_tasks.Environments = "dev", no_build: bool = False, gpu: bool = False, ): - """Start RADIS containers in specified environment""" - profiles: list[str] = [] - + """Start containers in specified environment""" if gpu: - profiles.append("gpu") + profiles = ["gpu"] else: - profiles.append("cpu") + profiles = ["cpu"] - cmd = build_compose_cmd(env) - cmd += "".join(f" --profile {profile}" for profile in profiles) - - build_opt = "--no-build" if no_build else "--build" - cmd += f" up {build_opt} --detach" - - ctx.run(cmd, pty=True) + invoke_tasks.compose_up(ctx=ctx, env=env, no_build=no_build, profile=profiles) @task def compose_down( ctx: Context, - env: Environments = "dev", + env: invoke_tasks.Environments = "dev", cleanup: bool = False, ): - """Stop RADIS containers in specified environment""" - - cmd = "" - cmd += build_compose_cmd(env) - - profiles = ["cpu", "gpu"] - cmd += "".join(f" --profile {profile}" for profile in profiles) - - cmd += " down" - - if cleanup: - cmd += " --remove-orphans --volumes" - - if env == "prod": - cancelled_msg = "Compose down cancelled!" - - response = input("Are you sure to stop the production containers? (yes to proceed) ") - if response != "yes": - print(cancelled_msg) - return - - response = input("Are you sure to delete all production volumes? 
(yes to proceed) ") - if response != "yes": - print(cancelled_msg) - return - - ctx.run(cmd, pty=True) - - -@task -def compose_restart(ctx: Context, env: Environments = "dev", service: str | None = None): - """Restart RADIS containers in specified environment""" - cmd = f"{build_compose_cmd(env)} restart" - if service: - cmd += f" {service}" - ctx.run(cmd, pty=True) - - -@task -def compose_logs( - ctx: Context, - env: Environments = "dev", - service: str | None = None, - follow: bool = False, - since: str | None = None, - until: str | None = None, - tail: int | None = None, -): - """Show logs of RADIS containers in specified environment""" - cmd = f"{build_compose_cmd(env)} logs" - if service: - cmd += f" {service}" - if follow: - cmd += " --follow" - if since: - cmd += f" --since {since}" - if until: - cmd += f" --until {until}" - if tail: - cmd += f" --tail {tail}" - ctx.run(cmd, pty=True) - - -@task -def stack_deploy(ctx: Context, env: Environments = "prod", build: bool = False): - """Deploy the stack to Docker Swarm (prod by default!). Optional build it before.""" - if build: - compose_build(ctx, env) - - stack_name = get_stack_name(env) - suffix = f"-c {compose_file_base}" - if env == "dev": - suffix += f" -c {compose_file_dev} {stack_name}" - elif env == "prod": - suffix += f" -c {compose_file_prod} {stack_name}" - else: - raise ValueError(f"Unknown environment: {env}") - - cmd = f"docker stack deploy {suffix}" - ctx.run(cmd, pty=True) - - -@task -def stack_rm(ctx: Context, env: Environments = "prod"): - """Remove the stack from Docker Swarm (prod by default!).""" - stack_name = get_stack_name(env) - cmd = f"docker stack rm {stack_name}" - ctx.run(cmd, pty=True) - - -@task -def format(ctx: Context): - """Format the source code with ruff and djlint""" - # Format Python code - format_code_cmd = "poetry run ruff format ." - ctx.run(format_code_cmd, pty=True) - # Sort Python imports - sort_imports_cmd = "poetry run ruff check . --fix --select I" - ctx.run(sort_imports_cmd, pty=True) - # Format Django templates - format_templates_cmd = "poetry run djlint . --reformat" - ctx.run(format_templates_cmd, pty=True) - - -@task -def lint(ctx: Context): - """Lint the source code (ruff, djlint, pyright)""" - cmd_ruff = "poetry run ruff check ." - ctx.run(cmd_ruff, pty=True) - cmd_djlint = "poetry run djlint . --lint" - ctx.run(cmd_djlint, pty=True) - cmd_pyright = "poetry run pyright" - ctx.run(cmd_pyright, pty=True) - - -@task -def test( - ctx: Context, - path: str | None = None, - cov: bool | str = False, - html: bool = False, - keyword: str | None = None, - mark: str | None = None, - stdout: bool = False, - failfast: bool = False, -): - """Run the test suite""" - if not check_compose_up(ctx, "dev"): - sys.exit( - "Integration tests need RADIS dev containers running.\nRun 'invoke compose-up' first." 
- ) - - cmd = ( - f"{build_compose_cmd('dev')} exec " - "--env DJANGO_SETTINGS_MODULE=radis.settings.test web pytest " - ) - if cov: - cmd += "--cov " - if isinstance(cov, str): - cmd += f"={cov} " - if html: - cmd += "--cov-report=html" - if keyword: - cmd += f"-k {keyword} " - if mark: - cmd += f"-m {mark} " - if stdout: - cmd += "-s " - if failfast: - cmd += "-x " - if path: - cmd += path - ctx.run(cmd, pty=True) - - -@task -def ci(ctx: Context): - """Run the continuous integration (linting and tests)""" - lint(ctx) - test(ctx, cov=True) - - -@task -def reset_dev(ctx: Context): - """Reset dev container environment""" - # Wipe the database - flush_cmd = f"{build_compose_cmd('dev')} exec web python manage.py flush --noinput" - ctx.run(flush_cmd, pty=True) - # Re-populate the database with users and groups - populate_cmd = f"{build_compose_cmd('dev')} exec web python manage.py populate_users_and_groups" - populate_cmd += " --users 20 --groups 3" - ctx.run(populate_cmd, pty=True) - # Re-populate the database with example reports - populate_cmd = f"{build_compose_cmd('dev')} exec web python manage.py populate_reports" - populate_cmd += " --report-language de" - ctx.run(populate_cmd, pty=True) - - -@task -def radis_web_shell(ctx: Context, env: Environments = "dev"): - """Open Python shell in RADIS web container of specified environment""" - cmd = f"{build_compose_cmd(env)} exec web python manage.py shell_plus" - ctx.run(cmd, pty=True) - - -@task -def init_workspace(ctx: Context): - """Initialize workspace for Github Codespaces or Gitpod""" - env_dev_file = f"{project_dir}/.env.dev" - if os.path.isfile(env_dev_file): - print("Workspace already initialized (.env.dev file exists).") - return - - shutil.copy(f"{project_dir}/example.env", env_dev_file) - - def modify_env_file(domain: str | None = None): - if domain: - url = f"https://{domain}" - hosts = f".localhost,127.0.0.1,[::1],{domain}" - set_key(env_dev_file, "DJANGO_CSRF_TRUSTED_ORIGINS", url, quote_mode="never") - set_key(env_dev_file, "DJANGO_ALLOWED_HOSTS", hosts, quote_mode="never") - set_key(env_dev_file, "DJANGO_INTERNAL_IPS", hosts, quote_mode="never") - set_key(env_dev_file, "SITE_BASE_URL", url, quote_mode="never") - set_key(env_dev_file, "SITE_DOMAIN", domain, quote_mode="never") - - set_key(env_dev_file, "FORCE_DEBUG_TOOLBAR", "true", quote_mode="never") - - if environ.get("CODESPACE_NAME"): - # Inside GitHub Codespaces - domain = f"{environ['CODESPACE_NAME']}-8000.preview.app.github.dev" - modify_env_file(domain) - elif environ.get("GITPOD_WORKSPACE_ID"): - # Inside Gitpod - result = ctx.run("gp url 8000", hide=True, pty=True) - assert result and result.ok - domain = result.stdout.strip().removeprefix("https://") - modify_env_file(domain) - else: - # Inside some local environment - modify_env_file() - - -@task -def show_outdated(ctx: Context): - """Show outdated dependencies""" - print("### Outdated Python dependencies ###") - poetry_cmd = "poetry show --outdated --top-level" - result = ctx.run(poetry_cmd, pty=True) - assert result and result.ok - print(result.stderr.strip()) - - print("### Outdated NPM dependencies ###") - npm_cmd = "npm outdated" - ctx.run(npm_cmd, pty=True) - - -@task -def upgrade(ctx: Context): - """Upgrade Python and JS packages""" - ctx.run("poetry update", pty=True) - - -@task -def try_github_actions(ctx: Context): - """Try Github Actions locally using Act""" - act_path = project_dir / "bin" / "act" - if not act_path.exists(): - print("Installing act...") - ctx.run( - "curl 
https://raw.githubusercontent.com/nektos/act/master/install.sh | sudo bash", - hide=True, - pty=True, - ) - ctx.run(f"{act_path} -P ubuntu-latest=catthehacker/ubuntu:act-latest", pty=True) - - -@task -def backup_db(ctx: Context, env: Environments = "prod"): - """Backup database - - For backup location see setting DBBACKUP_STORAGE_OPTIONS - For possible commands see: - https://django-dbbackup.readthedocs.io/en/master/commands.html - """ - settings = "radis.settings.production" if env == "prod" else "radis.settings.development" - web_container_id = find_running_container_id(ctx, env, "web") - cmd = ( - f"docker exec --env DJANGO_SETTINGS_MODULE={settings} " - f"{web_container_id} ./manage.py dbbackup --clean -v 2" - ) - ctx.run(cmd, pty=True) - - -@task -def restore_db(ctx: Context, env: Environments = "prod"): - """Restore database from backup""" - settings = "radis.settings.production" if env == "prod" else "radis.settings.development" - web_container_id = find_running_container_id(ctx, env, "web") - cmd = ( - f"docker exec --env DJANGO_SETTINGS_MODULE={settings} " - f"{web_container_id} ./manage.py dbrestore" - ) - ctx.run(cmd, pty=True) - - -@task -def upgrade_postgresql(ctx: Context, env: Environments = "dev", version: str = "latest"): - print(f"Upgrading PostgreSQL database in {env} environment to {version}.") - print("Cave, make sure the whole stack is not stopped. Otherwise this will corrupt data!") - if confirm("Are you sure you want to proceed?"): - print("Starting docker container that upgrades the database files.") - print("Watch the output if everything went fine or if any further steps are necessary.") - volume = get_postgres_volume(env) - ctx.run( - f"docker run -e POSTGRES_PASSWORD=postgres -v {volume}:/var/lib/postgresql/data " - f"pgautoupgrade/pgautoupgrade:{version}", - pty=True, - ) - else: - print("Cancelled") + """Stop containers in specified environment""" + invoke_tasks.compose_down(ctx=ctx, env=env, cleanup=cleanup, profile=["cpu", "gpu"]) From 99d59c0c51e9d6722b844c07f4ac7c3cd4044534 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Tue, 23 Jul 2024 20:17:06 +0000 Subject: [PATCH 6/8] Use RelatedPaginationMixin from adit-radis-shared --- radis/core/mixins.py | 58 -------------------------------------------- radis/rag/views.py | 9 ++++--- 2 files changed, 6 insertions(+), 61 deletions(-) delete mode 100644 radis/core/mixins.py diff --git a/radis/core/mixins.py b/radis/core/mixins.py deleted file mode 100644 index c16a0087..00000000 --- a/radis/core/mixins.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Any, Protocol - -from adit_radis_shared.common.mixins import ViewProtocol -from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator -from django.db.models.query import QuerySet -from django.http import HttpRequest - - -# TODO: Move this to adit_radis_shared package. PR: https://github.com/openradx/adit-radis-shared/pull/5 -class RelatedPaginationMixinProtocol(ViewProtocol, Protocol): - request: HttpRequest - object_list: QuerySet - paginate_by: int - - def get_object(self) -> Any: ... - - def get_context_data(self, **kwargs) -> dict[str, Any]: ... - - def get_related_queryset(self) -> QuerySet: ... - - -class RelatedPaginationMixin: - """This mixin provides pagination for a related queryset. This makes it possible to - paginate a related queryset in a DetailView. The related queryset is obtained by - the `get_related_queryset()` method that must be implemented by the subclass. 
- If used in combination with `RelatedFilterMixin`, the `RelatedPaginationMixin` must be - inherited first.""" - - def get_related_queryset(self: RelatedPaginationMixinProtocol) -> QuerySet: - raise NotImplementedError("You must implement this method") - - def get_context_data(self: RelatedPaginationMixinProtocol, **kwargs): - context = super().get_context_data(**kwargs) - - if "object_list" in context: - queryset = context["object_list"] - else: - queryset = self.get_related_queryset() - - paginator = Paginator(queryset, self.paginate_by) - page = self.request.GET.get("page") - - if page is None: - page = 1 - - try: - paginated_queryset = paginator.page(page) - except PageNotAnInteger: - paginated_queryset = paginator.page(1) - except EmptyPage: - paginated_queryset = paginator.page(paginator.num_pages) - - context["object_list"] = paginated_queryset - context["paginator"] = paginator - context["is_paginated"] = paginated_queryset.has_other_pages() - context["page_obj"] = paginated_queryset - - return context diff --git a/radis/rag/views.py b/radis/rag/views.py index 29f452c3..a22e631d 100644 --- a/radis/rag/views.py +++ b/radis/rag/views.py @@ -1,6 +1,11 @@ from typing import Any, Type, cast -from adit_radis_shared.common.mixins import HtmxOnlyMixin, PageSizeSelectMixin, RelatedFilterMixin +from adit_radis_shared.common.mixins import ( + HtmxOnlyMixin, + PageSizeSelectMixin, + RelatedFilterMixin, + RelatedPaginationMixin, +) from adit_radis_shared.common.types import AuthenticatedHttpRequest from django.conf import settings from django.contrib.auth.mixins import ( @@ -18,8 +23,6 @@ from django_tables2 import SingleTableMixin from formtools.wizard.views import SessionWizardView -# TODO: Change to adit_radis_shared.common.mixins.RelatedPaginationMixin -from radis.core.mixins import RelatedPaginationMixin from radis.core.views import ( AnalysisJobCancelView, AnalysisJobDeleteView, From 057df294e8eaa04e3992e9cb09ff26da1d5ba403 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Wed, 24 Jul 2024 23:31:07 +0000 Subject: [PATCH 7/8] Fix compose up and down invoke tasks --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- tasks.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7ea70183..a313f616 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,7 +2,7 @@ [[package]] name = "adit-radis-shared" -version = "0.6.1" +version = "0.6.2" description = "Shared Django apps between ADIT and RADIS" optional = false python-versions = "^3.11" @@ -37,8 +37,8 @@ whitenoise = "^6.0.0" [package.source] type = "git" url = "https://github.com/openradx/adit-radis-shared.git" -reference = "v0.6.1" -resolved_reference = "f6454b54b5233fba5dba82e53d1ed429fa21ee69" +reference = "v0.6.2" +resolved_reference = "932355e6bf7d5bb574ed011aa2967a43a9be8ca8" [[package]] name = "adrf" @@ -3555,4 +3555,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<4.0" -content-hash = "0a50389203dfa6ef8ba40e8ada2a79a31a12fe29cc74395bef94b5937b459c8e" +content-hash = "e03399c87981a03eb18d21b5b32fe4b14d1ca966932e2d9fa9d43184466f3db3" diff --git a/pyproject.toml b/pyproject.toml index 8bf5ad13..4500bc45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["medihack "] license = "GPL-3.0-or-later" [tool.poetry.dependencies] -adit-radis-shared = {git = "https://github.com/openradx/adit-radis-shared.git", rev = "v0.6.1"} +adit-radis-shared = {git = "https://github.com/openradx/adit-radis-shared.git", 
rev = "v0.6.2"} adrf = "^0.1.4" aiofiles = "^23.1.0" asyncinotify = "^4.0.1" diff --git a/tasks.py b/tasks.py index 2b3ca28b..80d2375e 100644 --- a/tasks.py +++ b/tasks.py @@ -38,7 +38,7 @@ def compose_up( else: profiles = ["cpu"] - invoke_tasks.compose_up(ctx=ctx, env=env, no_build=no_build, profile=profiles) + invoke_tasks.compose_up(ctx, env=env, no_build=no_build, profile=profiles) @task @@ -48,4 +48,4 @@ def compose_down( cleanup: bool = False, ): """Stop containers in specified environment""" - invoke_tasks.compose_down(ctx=ctx, env=env, cleanup=cleanup, profile=["cpu", "gpu"]) + invoke_tasks.compose_down(ctx, env=env, cleanup=cleanup, profile=["cpu", "gpu"]) From 82421e8057f682e7eed6fadc5b86077e2c73db17 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Sat, 27 Jul 2024 22:54:44 +0000 Subject: [PATCH 8/8] Fix RAG tests --- radis/rag/factories.py | 30 +++++++++++-------- radis/rag/tests/unit/conftest.py | 3 +- .../unit/{test_task.py => test_processors.py} | 13 +++++--- 3 files changed, 28 insertions(+), 18 deletions(-) rename radis/rag/tests/unit/{test_task.py => test_processors.py} (81%) diff --git a/radis/rag/factories.py b/radis/rag/factories.py index c61a1b16..deef3aed 100644 --- a/radis/rag/factories.py +++ b/radis/rag/factories.py @@ -6,11 +6,14 @@ from radis.reports.factories import ModalityFactory from .models import Answer, Question, RagInstance, RagJob, RagTask +from .site import retrieval_providers T = TypeVar("T") fake = Faker() +MODALITIES = ("CT", "MR", "DX", "PT", "US") + class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory): @classmethod @@ -18,39 +21,40 @@ def create(cls, *args, **kwargs) -> T: return super().create(*args, **kwargs) -SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch") -PatientSexes = ["", "M", "F"] - - class RagJobFactory(BaseDjangoModelFactory): class Meta: model = RagJob title = factory.Faker("sentence", nb_words=3) - provider = factory.Faker("random_element", elements=SearchProviders) + provider = factory.Faker("random_element", elements=list(retrieval_providers.keys())) group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory") query = factory.Faker("word") language = factory.SubFactory("radis.reports.factories.LanguageFactory") study_date_from = factory.Faker("date") study_date_till = factory.Faker("date") study_description = factory.Faker("sentence", nb_words=5) - patient_sex = factory.Faker("random_element", elements=PatientSexes) + patient_sex = factory.Faker("random_element", elements=["M", "F", "O"]) age_from = factory.Faker("random_int", min=0, max=100) age_till = factory.Faker("random_int", min=0, max=100) @factory.post_generation def modalities(self, create, extracted, **kwargs): + """ + If called like: ReportFactory.create(modalities=["CT", "PT"]) it generates + a report with 2 modalities. If called without `modalities` argument, it + generates a random amount of modalities for the report. 
+ """ if not create: return - self = cast(RagJob, self) + modalities = extracted + if modalities is None: + modalities = fake.random_elements(elements=MODALITIES, unique=True) - if extracted: - for modality in extracted: - self.modalities.add(modality) - else: - modality = ModalityFactory() - self.modalities.add(modality) + for modality in modalities: + # We can't call the create method of the factory as + # django_get_or_create would not be respected then + self.modalities.add(ModalityFactory(code=modality)) # type: ignore class QuestionFactory(BaseDjangoModelFactory[Question]): diff --git a/radis/rag/tests/unit/conftest.py b/radis/rag/tests/unit/conftest.py index c35236ee..81ad5f52 100644 --- a/radis/rag/tests/unit/conftest.py +++ b/radis/rag/tests/unit/conftest.py @@ -4,7 +4,7 @@ from radis.core.tests.unit.conftest import openai_chat_completions_mock # noqa from radis.rag.factories import QuestionFactory, RagInstanceFactory, RagJobFactory, RagTaskFactory -from radis.rag.models import RagTask +from radis.rag.models import RagJob, RagTask from radis.reports.factories import LanguageFactory, ReportFactory from radis.reports.models import Language @@ -20,6 +20,7 @@ def _create_rag_task( num_rag_instances: int = 5, ) -> RagTask: job = RagJobFactory.create( + status=RagJob.Status.PENDING, owner_id=user_with_group.id, owner=user_with_group, language=LanguageFactory.create(code=language_code), diff --git a/radis/rag/tests/unit/test_task.py b/radis/rag/tests/unit/test_processors.py similarity index 81% rename from radis/rag/tests/unit/test_task.py rename to radis/rag/tests/unit/test_processors.py index dd2c3354..070c80ad 100644 --- a/radis/rag/tests/unit/test_task.py +++ b/radis/rag/tests/unit/test_processors.py @@ -1,16 +1,19 @@ from unittest.mock import patch import pytest +from channels.db import database_sync_to_async +from django.db import close_old_connections from radis.rag.models import Answer, RagInstance from radis.rag.processors import RagTaskProcessor +@pytest.mark.asyncio @pytest.mark.django_db(transaction=True) -def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker): +async def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker): num_rag_instances = 5 num_questions = 5 - rag_task = create_rag_task( + rag_task = await database_sync_to_async(create_rag_task)( language_code="en", num_questions=num_questions, accepted_answer="Y", @@ -18,12 +21,12 @@ def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocke ) openai_mock = openai_chat_completions_mock("Yes") - process_rag_task_spy = mocker.spy(RagTaskProcessor, "process_rag_task") + process_rag_task_spy = mocker.spy(RagTaskProcessor, "process_task") process_rag_instance_spy = mocker.spy(RagTaskProcessor, "process_rag_instance") process_yes_or_no_question_spy = mocker.spy(RagTaskProcessor, "process_yes_or_no_question") with patch("openai.AsyncOpenAI", return_value=openai_mock): - RagTaskProcessor().start(rag_task) + await RagTaskProcessor(rag_task).start() rag_instances = rag_task.rag_instances.all() for instance in rag_instances: @@ -40,3 +43,5 @@ def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocke assert ( openai_mock.chat.completions.create.call_count == num_rag_instances * num_questions ) + + close_old_connections()