-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add: basic test for ProcessRagTask and AsyncChatClient
- Loading branch information
1 parent
6fc0c91
commit 722b3c0
Showing
10 changed files
with
237 additions
and
3 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import asyncio | ||
from typing import Callable, ContextManager | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
from faker import Faker | ||
|
||
|
||
@pytest.fixture | ||
def report_body() -> str: | ||
report_body = Faker().sentences(nb=40) | ||
return " ".join(report_body) | ||
|
||
|
||
@pytest.fixture | ||
def question_body() -> str: | ||
question_body = Faker().sentences(nb=1) | ||
return " ".join(question_body) | ||
|
||
|
||
@pytest.fixture | ||
def openai_chat_completions_mock() -> Callable[[str], ContextManager]: | ||
def _openai_chat_completions_mock(content: str) -> ContextManager: | ||
mock_openai = MagicMock() | ||
mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))]) | ||
future = asyncio.Future() | ||
future.set_result(mock_response) | ||
mock_openai.chat.completions.create.return_value = future | ||
|
||
return mock_openai | ||
|
||
return _openai_chat_completions_mock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from radis.core.utils.chat_client import AsyncChatClient | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_ask_question(report_body, question_body, openai_chat_completions_mock): | ||
openai_mock = openai_chat_completions_mock("Fake Answer") | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_mock): | ||
answer = await AsyncChatClient().ask_question(report_body, "en", question_body) | ||
|
||
assert answer == "Fake Answer" | ||
assert openai_mock.chat.completions.create.call_count == 1 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_ask_yes_no_question(report_body, question_body, openai_chat_completions_mock): | ||
openai_yes_mock = openai_chat_completions_mock("Yes") | ||
openai_no_mock = openai_chat_completions_mock("No") | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_yes_mock): | ||
answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body) | ||
|
||
assert answer == "yes" | ||
assert openai_yes_mock.chat.completions.create.call_count == 1 | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_no_mock): | ||
answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body) | ||
|
||
assert answer == "no" | ||
assert openai_no_mock.chat.completions.create.call_count == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from typing import Generic, TypeVar | ||
|
||
import factory | ||
from faker import Faker | ||
|
||
from .models import Answer, Question, RagJob, RagReportInstance, RagTask | ||
|
||
T = TypeVar("T") | ||
|
||
fake = Faker() | ||
|
||
|
||
class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory): | ||
@classmethod | ||
def create(cls, *args, **kwargs) -> T: | ||
return super().create(*args, **kwargs) | ||
|
||
|
||
SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch") | ||
PatientSexes = ["", "M", "F"] | ||
|
||
|
||
class RagJobFactory(BaseDjangoModelFactory): | ||
class Meta: | ||
model = RagJob | ||
|
||
title = factory.Faker("sentence", nb_words=3) | ||
provider = factory.Faker("random_element", elements=SearchProviders) | ||
group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory") | ||
query = factory.Faker("word") | ||
language = factory.SubFactory("radis.reports.factories.LanguageFactory") | ||
# TODO: handle modalities | ||
study_date_from = factory.Faker("date") | ||
study_date_till = factory.Faker("date") | ||
study_description = factory.Faker("sentence", nb_words=5) | ||
patient_sex = factory.Faker("random_element", elements=PatientSexes) | ||
age_from = factory.Faker("random_int", min=0, max=100) | ||
age_till = factory.Faker("random_int", min=0, max=100) | ||
|
||
|
||
class QuestionFactory(BaseDjangoModelFactory[Question]): | ||
class Meta: | ||
model = Question | ||
|
||
job = factory.SubFactory("radis.rag.factories.RagJobFactory") | ||
question = factory.Faker("sentence", nb_words=10) | ||
accepted_answer = factory.Faker("random_element", elements=[a[0] for a in Answer.choices]) | ||
|
||
|
||
class RagTaskFactory(BaseDjangoModelFactory[RagTask]): | ||
class Meta: | ||
model = RagTask | ||
|
||
job = factory.SubFactory("radis.rag.factories.RagJobFactory") | ||
|
||
|
||
class RagReportInstanceFactory(BaseDjangoModelFactory[RagReportInstance]): | ||
class Meta: | ||
model = RagReportInstance | ||
|
||
task = factory.SubFactory("radis.rag.factories.RagTaskFactory") | ||
report = factory.SubFactory("radis.reports.factories.ReportFactory") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import pytest | ||
from adit_radis_shared.accounts.factories import GroupFactory, UserFactory | ||
from adit_radis_shared.accounts.models import User | ||
from django.contrib.auth.models import Group | ||
|
||
|
||
@pytest.fixture | ||
def rag_group() -> Group: | ||
group = GroupFactory() | ||
# TODO: Add permissions to the group | ||
return group | ||
|
||
|
||
@pytest.fixture | ||
def user_with_group(rag_group) -> User: | ||
user = UserFactory() | ||
user.groups.add(rag_group) | ||
return user |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import Callable, Literal, Optional | ||
|
||
import pytest | ||
|
||
from radis.core.tests.unit.conftest import openai_chat_completions_mock # noqa | ||
from radis.rag.factories import ( | ||
QuestionFactory, | ||
RagJobFactory, | ||
RagReportInstanceFactory, | ||
RagTaskFactory, | ||
) | ||
from radis.rag.models import RagTask | ||
from radis.reports.factories import LanguageFactory, ReportFactory | ||
from radis.reports.models import Language | ||
|
||
|
||
@pytest.fixture | ||
def create_rag_task( | ||
user_with_group, | ||
) -> Callable[[Literal["en", "de"], int, Optional[Literal["Y", "N"]], int], RagTask]: | ||
def _create_rag_task( | ||
language_code: Literal["en", "de"] = "en", | ||
num_questions: int = 5, | ||
accepted_answer: Optional[Literal["Y", "N"]] = None, | ||
num_report_instances: int = 5, | ||
) -> RagTask: | ||
job = RagJobFactory.create( | ||
owner_id=user_with_group.id, | ||
owner=user_with_group, | ||
language=LanguageFactory.create(code=language_code), | ||
) | ||
|
||
if accepted_answer is not None: | ||
QuestionFactory.create_batch(num_questions, job=job, accepted_answer=accepted_answer) | ||
else: | ||
QuestionFactory.create_batch(num_questions, job=job) | ||
|
||
task = RagTaskFactory.create(job=job) | ||
for _ in range(num_report_instances): | ||
report = ReportFactory.create(language=Language.objects.get(code=language_code)) | ||
RagReportInstanceFactory.create(report=report, task=task) | ||
|
||
return task | ||
|
||
return _create_rag_task |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from radis.rag.models import Answer, RagReportInstance | ||
from radis.rag.tasks import ProcessRagTask | ||
|
||
|
||
@pytest.mark.django_db(transaction=True) | ||
def test_process_rag_task(create_rag_task, openai_chat_completions_mock, mocker): | ||
num_report_instances = 5 | ||
num_questions = 5 | ||
rag_task = create_rag_task( | ||
language_code="en", | ||
num_questions=num_questions, | ||
accepted_answer="Y", | ||
num_report_instances=num_report_instances, | ||
) | ||
|
||
openai_mock = openai_chat_completions_mock("Yes") | ||
process_rag_task_spy = mocker.spy(ProcessRagTask, "process_rag_task") | ||
process_report_instance_spy = mocker.spy(ProcessRagTask, "process_report_instance") | ||
process_yes_or_no_question_spy = mocker.spy(ProcessRagTask, "process_yes_or_no_question") | ||
|
||
with patch("openai.AsyncOpenAI", return_value=openai_mock): | ||
ProcessRagTask().process_task(rag_task) | ||
report_instances = rag_task.report_instances.all() | ||
|
||
for instance in report_instances: | ||
assert instance.overall_result == RagReportInstance.Result.ACCEPTED | ||
question_results = instance.results.all() | ||
assert all( | ||
[result.result == RagReportInstance.Result.ACCEPTED for result in question_results] | ||
) | ||
assert all([result.original_answer == Answer.YES for result in question_results]) | ||
|
||
assert process_rag_task_spy.call_count == 1 | ||
assert process_report_instance_spy.call_count == num_report_instances | ||
assert process_yes_or_no_question_spy.call_count == num_report_instances * num_questions | ||
assert ( | ||
openai_mock.chat.completions.create.call_count | ||
== num_report_instances * num_questions | ||
) |