Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Aug 10, 2024
1 parent 921a80a commit f71dd79
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 39 deletions.
15 changes: 7 additions & 8 deletions radis/core/processors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import traceback

from channels.db import database_sync_to_async
from django.utils import timezone

from .models import AnalysisJob, AnalysisTask
Expand All @@ -13,7 +12,7 @@ class AnalysisTaskProcessor:
def __init__(self, task: AnalysisTask) -> None:
self.task = task

async def start(self) -> None:
def start(self) -> None:
task = self.task
job = task.job

Expand Down Expand Up @@ -42,17 +41,17 @@ async def start(self) -> None:
if job.status == job.Status.PENDING:
job.status = job.Status.IN_PROGRESS
job.started_at = timezone.now()
await job.asave()
job.save()

assert job.status == job.Status.IN_PROGRESS

# Prepare the task itself
task.status = AnalysisTask.Status.IN_PROGRESS
task.started_at = timezone.now()
await task.asave()
task.save()

try:
await self.process_task(task)
self.process_task(task)

# If the overwritten process_task method changes the status of the
# task itself then we leave it as it is. Otherwise if the status is
Expand All @@ -70,9 +69,9 @@ async def start(self) -> None:
finally:
logger.info("Task %s ended", task)
task.ended_at = timezone.now()
await task.asave()
await database_sync_to_async(job.update_job_state)()
task.save()
job.update_job_state()

async def process_task(self, task: AnalysisTask) -> None:
def process_task(self, task: AnalysisTask) -> None:
"""The derived class should process the task here."""
...
8 changes: 4 additions & 4 deletions radis/rag/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def other_reports(self, create, extracted, **kwargs):

self = cast(RagInstance, self)

if extracted:
for report in extracted:
self.other_reports.add(report)
else:
if extracted is None:
from radis.reports.factories import ReportFactory

self.other_reports.add(*[ReportFactory() for _ in range(3)])
else:
for report in extracted:
self.other_reports.add(report)
5 changes: 4 additions & 1 deletion radis/rag/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@


class RagTaskProcessor(AnalysisTaskProcessor):
async def process_task(self, task: RagTask) -> None:
def process_task(self, task: RagTask) -> None:
asyncio.run(self.process_task_async(task))

async def process_task_async(self, task: RagTask) -> None:
language_code = task.job.language.code
client = AsyncChatClient()
sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT)
Expand Down
19 changes: 10 additions & 9 deletions radis/rag/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from itertools import batched

from django.conf import settings
from django.db.models import Prefetch
from pebble import asynchronous
from procrastinate.contrib.django import app

from radis.reports.models import Report
Expand All @@ -18,14 +18,15 @@

@app.task(queue="llm")
async def process_rag_task(task_id: int) -> None:
task = await RagTask.objects.prefetch_related(
Prefetch(
"job",
queryset=RagJob.objects.prefetch_related("language"),
)
).aget(id=task_id)
processor = RagTaskProcessor(task)
await processor.start()
# We have to run RagTaskProcessor in a separate thread because it is
# creating an async loop itself.
@asynchronous.thread
def _process_tag_task(task_id: int) -> None:
task = RagTask.objects.get(id=task_id)
processor = RagTaskProcessor(task)
processor.start()

await _process_tag_task(task_id)


@app.task
Expand Down
10 changes: 6 additions & 4 deletions radis/rag/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from radis.rag.factories import QuestionFactory, RagInstanceFactory, RagJobFactory, RagTaskFactory
from radis.rag.models import RagJob, RagTask
from radis.reports.factories import LanguageFactory, ReportFactory
from radis.reports.models import Language


@pytest.fixture
Expand All @@ -19,11 +18,13 @@ def _create_rag_task(
accepted_answer: Optional[Literal["Y", "N"]] = None,
num_rag_instances: int = 5,
) -> RagTask:
language = LanguageFactory.create(code=language_code)

job = RagJobFactory.create(
status=RagJob.Status.PENDING,
owner_id=user_with_group.id,
owner=user_with_group,
language=LanguageFactory.create(code=language_code),
language=language,
)

if accepted_answer is not None:
Expand All @@ -32,9 +33,10 @@ def _create_rag_task(
QuestionFactory.create_batch(num_questions, job=job)

task = RagTaskFactory.create(job=job)

for _ in range(num_rag_instances):
report = ReportFactory.create(language=Language.objects.get(code=language_code))
RagInstanceFactory.create(task=task, reports=[report])
report = ReportFactory.create(language=language)
RagInstanceFactory.create(task=task, report=report, other_reports=[])

return task

Expand Down
21 changes: 8 additions & 13 deletions radis/rag/tests/unit/test_processors.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from unittest.mock import patch

import pytest
from channels.db import database_sync_to_async
from django.db import close_old_connections

from radis.rag.models import Answer, RagInstance
from radis.rag.processors import RagTaskProcessor


@pytest.mark.asyncio
@pytest.mark.django_db(transaction=True)
async def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker):
def test_rag_task_processor(create_rag_task, openai_chat_completions_mock, mocker):
num_rag_instances = 5
num_questions = 5
rag_task = await database_sync_to_async(create_rag_task)(
rag_task = create_rag_task(
language_code="en",
num_questions=num_questions,
accepted_answer="Y",
Expand All @@ -26,22 +24,19 @@ async def test_rag_task_processor(create_rag_task, openai_chat_completions_mock,
process_yes_or_no_question_spy = mocker.spy(RagTaskProcessor, "process_yes_or_no_question")

with patch("openai.AsyncOpenAI", return_value=openai_mock):
await RagTaskProcessor(rag_task).start()
rag_instances = rag_task.rag_instances.all()
RagTaskProcessor(rag_task).start()

for instance in rag_instances:
for instance in rag_task.rag_instances.all():
assert instance.overall_result == RagInstance.Result.ACCEPTED
question_results = instance.results.all()
assert all(
[result.result == RagInstance.Result.ACCEPTED for result in question_results]
)
assert all([result.original_answer == Answer.YES for result in question_results])

assert process_rag_task_spy.call_count == 1
assert process_rag_instance_spy.call_count == num_rag_instances
assert process_yes_or_no_question_spy.call_count == num_rag_instances * num_questions
assert (
openai_mock.chat.completions.create.call_count == num_rag_instances * num_questions
)
assert process_rag_task_spy.call_count == 1
assert process_rag_instance_spy.call_count == num_rag_instances
assert process_yes_or_no_question_spy.call_count == num_rag_instances * num_questions
assert openai_mock.chat.completions.create.call_count == num_rag_instances * num_questions

close_old_connections()

0 comments on commit f71dd79

Please sign in to comment.