Add: basic test for ProcessRagTask and AsyncChatClient
hummerichsander authored Jul 8, 2024
1 parent 6fc0c91 commit 722b3c0
Showing 10 changed files with 237 additions and 3 deletions.
Empty file.
32 changes: 32 additions & 0 deletions radis/core/tests/unit/conftest.py
@@ -0,0 +1,32 @@
import asyncio
from typing import Callable
from unittest.mock import MagicMock

import pytest
from faker import Faker


@pytest.fixture
def report_body() -> str:
    report_body = Faker().sentences(nb=40)
    return " ".join(report_body)


@pytest.fixture
def question_body() -> str:
    question_body = Faker().sentences(nb=1)
    return " ".join(question_body)


@pytest.fixture
def openai_chat_completions_mock() -> Callable[[str], MagicMock]:
    def _openai_chat_completions_mock(content: str) -> MagicMock:
        # Fake OpenAI client whose chat.completions.create() resolves to a
        # canned completion carrying the given content.
        mock_openai = MagicMock()
        mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))])
        future = asyncio.Future()
        future.set_result(mock_response)
        mock_openai.chat.completions.create.return_value = future

        return mock_openai

    return _openai_chat_completions_mock
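For illustration, a minimal standalone sketch of what this fixture builds (the helper name fake_openai_client is hypothetical and not part of the commit): create() returns an already-resolved asyncio.Future, so awaiting the mocked call completes immediately with the canned content.

import asyncio
from unittest.mock import MagicMock


def fake_openai_client(content: str) -> MagicMock:
    # Mirrors the fixture above: the first choice's message carries `content`.
    mock_openai = MagicMock()
    mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))])
    future: asyncio.Future = asyncio.Future()
    future.set_result(mock_response)  # already resolved, so `await` returns at once
    mock_openai.chat.completions.create.return_value = future
    return mock_openai


async def main() -> None:
    client = fake_openai_client("Fake Answer")
    response = await client.chat.completions.create(model="none", messages=[])
    assert response.choices[0].message.content == "Fake Answer"


asyncio.run(main())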
34 changes: 34 additions & 0 deletions radis/core/tests/unit/test_chat_client.py
@@ -0,0 +1,34 @@
from unittest.mock import patch

import pytest

from radis.core.utils.chat_client import AsyncChatClient


@pytest.mark.asyncio
async def test_ask_question(report_body, question_body, openai_chat_completions_mock):
    openai_mock = openai_chat_completions_mock("Fake Answer")

    with patch("openai.AsyncOpenAI", return_value=openai_mock):
        answer = await AsyncChatClient().ask_question(report_body, "en", question_body)

    assert answer == "Fake Answer"
    assert openai_mock.chat.completions.create.call_count == 1


@pytest.mark.asyncio
async def test_ask_yes_no_question(report_body, question_body, openai_chat_completions_mock):
    openai_yes_mock = openai_chat_completions_mock("Yes")
    openai_no_mock = openai_chat_completions_mock("No")

    with patch("openai.AsyncOpenAI", return_value=openai_yes_mock):
        answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body)

    assert answer == "yes"
    assert openai_yes_mock.chat.completions.create.call_count == 1

    with patch("openai.AsyncOpenAI", return_value=openai_no_mock):
        answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body)

    assert answer == "no"
    assert openai_no_mock.chat.completions.create.call_count == 1
6 changes: 3 additions & 3 deletions radis/core/utils/chat_client.py
@@ -2,8 +2,8 @@
from string import Template
from typing import Literal

+import openai
from django.conf import settings
-from openai import AsyncOpenAI, OpenAI

logger = logging.getLogger(__name__)

@@ -15,7 +15,7 @@

class ChatClient:
    def __init__(self):
-        self._client = OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")
+        self._client = openai.OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

    def ask_question(
        self, report_body: str, language: str, question: str, grammar: str | None = None
@@ -67,7 +67,7 @@ def ask_yes_no_question(

class AsyncChatClient:
    def __init__(self):
-        self._client = AsyncOpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")
+        self._client = openai.AsyncOpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

    async def ask_question(
        self, report_body: str, language: str, question: str, grammar: str | None = None
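The import switch in this file appears to be what lets the new tests' patch("openai.AsyncOpenAI", ...) take effect (my reading of the intent, not stated in the commit): mock.patch replaces the attribute on the openai module, so only call sites that resolve the class through the module at call time, as openai.AsyncOpenAI(...) now does, see the mock; a name bound via from openai import AsyncOpenAI would keep pointing at the real class. A short illustrative sketch, with a made-up base_url:

from unittest.mock import MagicMock, patch

import openai


def build_client() -> object:
    # The attribute lookup happens on the openai module at call time,
    # so a patched openai.AsyncOpenAI is picked up here.
    return openai.AsyncOpenAI(base_url="http://localhost:8080/v1", api_key="none")


with patch("openai.AsyncOpenAI", return_value=MagicMock()) as patched_cls:
    build_client()

assert patched_cls.call_count == 1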
62 changes: 62 additions & 0 deletions radis/rag/factories.py
@@ -0,0 +1,62 @@
from typing import Generic, TypeVar

import factory
from faker import Faker

from .models import Answer, Question, RagJob, RagReportInstance, RagTask

T = TypeVar("T")

fake = Faker()


class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory):
    @classmethod
    def create(cls, *args, **kwargs) -> T:
        return super().create(*args, **kwargs)


SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch")
PatientSexes = ["", "M", "F"]


class RagJobFactory(BaseDjangoModelFactory[RagJob]):
    class Meta:
        model = RagJob

    title = factory.Faker("sentence", nb_words=3)
    provider = factory.Faker("random_element", elements=SearchProviders)
    group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory")
    query = factory.Faker("word")
    language = factory.SubFactory("radis.reports.factories.LanguageFactory")
    # TODO: handle modalities
    study_date_from = factory.Faker("date")
    study_date_till = factory.Faker("date")
    study_description = factory.Faker("sentence", nb_words=5)
    patient_sex = factory.Faker("random_element", elements=PatientSexes)
    age_from = factory.Faker("random_int", min=0, max=100)
    age_till = factory.Faker("random_int", min=0, max=100)


class QuestionFactory(BaseDjangoModelFactory[Question]):
    class Meta:
        model = Question

    job = factory.SubFactory("radis.rag.factories.RagJobFactory")
    question = factory.Faker("sentence", nb_words=10)
    accepted_answer = factory.Faker("random_element", elements=[a[0] for a in Answer.choices])


class RagTaskFactory(BaseDjangoModelFactory[RagTask]):
    class Meta:
        model = RagTask

    job = factory.SubFactory("radis.rag.factories.RagJobFactory")


class RagReportInstanceFactory(BaseDjangoModelFactory[RagReportInstance]):
    class Meta:
        model = RagReportInstance

    task = factory.SubFactory("radis.rag.factories.RagTaskFactory")
    report = factory.SubFactory("radis.reports.factories.ReportFactory")
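For orientation, a hypothetical usage sketch (not part of the commit): the SubFactory declarations mean that creating a leaf object builds the whole chain of related rows. It assumes the project's pytest-django setup and that RagJob needs an owner passed in, as the unit-test conftest below suggests; the test name is made up.

import pytest
from adit_radis_shared.accounts.factories import UserFactory

from radis.rag.factories import RagReportInstanceFactory


@pytest.mark.django_db
def test_factory_chain():
    user = UserFactory()
    # One call creates the report instance plus its task, job, group,
    # language and report via the SubFactory chain.
    instance = RagReportInstanceFactory.create(task__job__owner=user)
    assert instance.task.job.language is not None
    assert instance.report.pk is not None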
Empty file added radis/rag/tests/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions radis/rag/tests/conftest.py
@@ -0,0 +1,18 @@
import pytest
from adit_radis_shared.accounts.factories import GroupFactory, UserFactory
from adit_radis_shared.accounts.models import User
from django.contrib.auth.models import Group


@pytest.fixture
def rag_group() -> Group:
    group = GroupFactory()
    # TODO: Add permissions to the group
    return group


@pytest.fixture
def user_with_group(rag_group) -> User:
    user = UserFactory()
    user.groups.add(rag_group)
    return user
Empty file.
45 changes: 45 additions & 0 deletions radis/rag/tests/unit/conftest.py
@@ -0,0 +1,45 @@
from typing import Callable, Literal, Optional

import pytest

from radis.core.tests.unit.conftest import openai_chat_completions_mock # noqa
from radis.rag.factories import (
    QuestionFactory,
    RagJobFactory,
    RagReportInstanceFactory,
    RagTaskFactory,
)
from radis.rag.models import RagTask
from radis.reports.factories import LanguageFactory, ReportFactory
from radis.reports.models import Language


@pytest.fixture
def create_rag_task(
    user_with_group,
) -> Callable[[Literal["en", "de"], int, Optional[Literal["Y", "N"]], int], RagTask]:
    def _create_rag_task(
        language_code: Literal["en", "de"] = "en",
        num_questions: int = 5,
        accepted_answer: Optional[Literal["Y", "N"]] = None,
        num_report_instances: int = 5,
    ) -> RagTask:
        job = RagJobFactory.create(
            owner_id=user_with_group.id,
            owner=user_with_group,
            language=LanguageFactory.create(code=language_code),
        )

        if accepted_answer is not None:
            QuestionFactory.create_batch(num_questions, job=job, accepted_answer=accepted_answer)
        else:
            QuestionFactory.create_batch(num_questions, job=job)

        task = RagTaskFactory.create(job=job)
        for _ in range(num_report_instances):
            report = ReportFactory.create(language=Language.objects.get(code=language_code))
            RagReportInstanceFactory.create(report=report, task=task)

        return task

    return _create_rag_task
43 changes: 43 additions & 0 deletions radis/rag/tests/unit/test_task.py
@@ -0,0 +1,43 @@
from unittest.mock import patch

import pytest

from radis.rag.models import Answer, RagReportInstance
from radis.rag.tasks import ProcessRagTask


@pytest.mark.django_db(transaction=True)
def test_process_rag_task(create_rag_task, openai_chat_completions_mock, mocker):
    num_report_instances = 5
    num_questions = 5
    rag_task = create_rag_task(
        language_code="en",
        num_questions=num_questions,
        accepted_answer="Y",
        num_report_instances=num_report_instances,
    )

    openai_mock = openai_chat_completions_mock("Yes")
    process_rag_task_spy = mocker.spy(ProcessRagTask, "process_rag_task")
    process_report_instance_spy = mocker.spy(ProcessRagTask, "process_report_instance")
    process_yes_or_no_question_spy = mocker.spy(ProcessRagTask, "process_yes_or_no_question")

    with patch("openai.AsyncOpenAI", return_value=openai_mock):
        ProcessRagTask().process_task(rag_task)
        report_instances = rag_task.report_instances.all()

    for instance in report_instances:
        assert instance.overall_result == RagReportInstance.Result.ACCEPTED
        question_results = instance.results.all()
        assert all(
            [result.result == RagReportInstance.Result.ACCEPTED for result in question_results]
        )
        assert all([result.original_answer == Answer.YES for result in question_results])

    assert process_rag_task_spy.call_count == 1
    assert process_report_instance_spy.call_count == num_report_instances
    assert process_yes_or_no_question_spy.call_count == num_report_instances * num_questions
    assert (
        openai_mock.chat.completions.create.call_count
        == num_report_instances * num_questions
    )
