diff --git a/radis/core/processors.py b/radis/core/processors.py index 99331c17..1a39802b 100644 --- a/radis/core/processors.py +++ b/radis/core/processors.py @@ -1,7 +1,6 @@ import logging import traceback -from channels.db import database_sync_to_async from django.utils import timezone from .models import AnalysisJob, AnalysisTask @@ -13,7 +12,7 @@ class AnalysisTaskProcessor: def __init__(self, task: AnalysisTask) -> None: self.task = task - async def start(self) -> None: + def start(self) -> None: task = self.task job = task.job @@ -42,17 +41,17 @@ async def start(self) -> None: if job.status == job.Status.PENDING: job.status = job.Status.IN_PROGRESS job.started_at = timezone.now() - await job.asave() + job.save() assert job.status == job.Status.IN_PROGRESS # Prepare the task itself task.status = AnalysisTask.Status.IN_PROGRESS task.started_at = timezone.now() - await task.asave() + task.save() try: - await self.process_task(task) + self.process_task(task) # If the overwritten process_task method changes the status of the # task itself then we leave it as it is. Otherwise if the status is @@ -70,9 +69,9 @@ async def start(self) -> None: finally: logger.info("Task %s ended", task) task.ended_at = timezone.now() - await task.asave() - await database_sync_to_async(job.update_job_state)() + task.save() + job.update_job_state() - async def process_task(self, task: AnalysisTask) -> None: + def process_task(self, task: AnalysisTask) -> None: """The derived class should process the task here.""" ... diff --git a/radis/rag/factories.py b/radis/rag/factories.py index 9b06c48a..c591d73c 100644 --- a/radis/rag/factories.py +++ b/radis/rag/factories.py @@ -87,10 +87,10 @@ def other_reports(self, create, extracted, **kwargs): self = cast(RagInstance, self) - if extracted: - for report in extracted: - self.other_reports.add(report) - else: + if extracted is None: from radis.reports.factories import ReportFactory self.other_reports.add(*[ReportFactory() for _ in range(3)]) + else: + for report in extracted: + self.other_reports.add(report) diff --git a/radis/rag/processors.py b/radis/rag/processors.py index 8e8570b4..06bfe855 100644 --- a/radis/rag/processors.py +++ b/radis/rag/processors.py @@ -16,7 +16,10 @@ class RagTaskProcessor(AnalysisTaskProcessor): - async def process_task(self, task: RagTask) -> None: + def process_task(self, task: RagTask) -> None: + asyncio.run(self.process_task_async(task)) + + async def process_task_async(self, task: RagTask) -> None: language_code = task.job.language.code client = AsyncChatClient() sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT) diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py index 3b9da83c..68ab077d 100644 --- a/radis/rag/tasks.py +++ b/radis/rag/tasks.py @@ -2,7 +2,7 @@ from itertools import batched from django.conf import settings -from django.db.models import Prefetch +from pebble import asynchronous from procrastinate.contrib.django import app from radis.reports.models import Report @@ -18,14 +18,15 @@ @app.task(queue="llm") async def process_rag_task(task_id: int) -> None: - task = await RagTask.objects.prefetch_related( - Prefetch( - "job", - queryset=RagJob.objects.prefetch_related("language"), - ) - ).aget(id=task_id) - processor = RagTaskProcessor(task) - await processor.start() + # We have to run RagTaskProcessor in a separate thread because it is + # creating an async loop itself. + @asynchronous.thread + def _process_tag_task(task_id: int) -> None: + task = RagTask.objects.get(id=task_id) + processor = RagTaskProcessor(task) + processor.start() + + await _process_tag_task(task_id) @app.task diff --git a/radis/rag/tests/unit/conftest.py b/radis/rag/tests/unit/conftest.py index 81ad5f52..1305d972 100644 --- a/radis/rag/tests/unit/conftest.py +++ b/radis/rag/tests/unit/conftest.py @@ -6,7 +6,6 @@ from radis.rag.factories import QuestionFactory, RagInstanceFactory, RagJobFactory, RagTaskFactory from radis.rag.models import RagJob, RagTask from radis.reports.factories import LanguageFactory, ReportFactory -from radis.reports.models import Language @pytest.fixture @@ -19,11 +18,13 @@ def _create_rag_task( accepted_answer: Optional[Literal["Y", "N"]] = None, num_rag_instances: int = 5, ) -> RagTask: + language = LanguageFactory.create(code=language_code) + job = RagJobFactory.create( status=RagJob.Status.PENDING, owner_id=user_with_group.id, owner=user_with_group, - language=LanguageFactory.create(code=language_code), + language=language, ) if accepted_answer is not None: @@ -32,9 +33,10 @@ def _create_rag_task( QuestionFactory.create_batch(num_questions, job=job) task = RagTaskFactory.create(job=job) + for _ in range(num_rag_instances): - report = ReportFactory.create(language=Language.objects.get(code=language_code)) - RagInstanceFactory.create(task=task, reports=[report]) + report = ReportFactory.create(language=language) + RagInstanceFactory.create(task=task, report=report, other_reports=[]) return task diff --git a/radis/rag/tests/unit/test_processors.py b/radis/rag/tests/unit/test_processors.py index 070c80ad..a74b0eb4 100644 --- a/radis/rag/tests/unit/test_processors.py +++ b/radis/rag/tests/unit/test_processors.py @@ -1,19 +1,17 @@ from unittest.mock import patch import pytest -from channels.db import database_sync_to_async from django.db import close_old_connections from radis.rag.models import Answer, RagInstance from radis.rag.processors import RagTaskProcessor -@pytest.mark.asyncio @pytest.mark.django_db(transaction=True) -async def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker): +def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker): num_rag_instances = 5 num_questions = 5 - rag_task = await database_sync_to_async(create_rag_task)( + rag_task = create_rag_task( language_code="en", num_questions=num_questions, accepted_answer="Y", @@ -26,10 +24,9 @@ async def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, process_yes_or_no_question_spy = mocker.spy(RagTaskProcessor, "process_yes_or_no_question") with patch("openai.AsyncOpenAI", return_value=openai_mock): - await RagTaskProcessor(rag_task).start() - rag_instances = rag_task.rag_instances.all() + RagTaskProcessor(rag_task).start() - for instance in rag_instances: + for instance in rag_task.rag_instances.all(): assert instance.overall_result == RagInstance.Result.ACCEPTED question_results = instance.results.all() assert all( @@ -37,11 +34,9 @@ async def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, ) assert all([result.original_answer == Answer.YES for result in question_results]) - assert process_rag_task_spy.call_count == 1 - assert process_rag_instance_spy.call_count == num_rag_instances - assert process_yes_or_no_question_spy.call_count == num_rag_instances * num_questions - assert ( - openai_mock.chat.completions.create.call_count == num_rag_instances * num_questions - ) + assert process_rag_task_spy.call_count == 1 + assert process_rag_instance_spy.call_count == num_rag_instances + assert process_yes_or_no_question_spy.call_count == num_rag_instances * num_questions + assert openai_mock.chat.completions.create.call_count == num_rag_instances * num_questions close_old_connections()