Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rag task batching #8

Merged
merged 18 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
66f873b
Enable batching in RagTasks and concurrent requests to llm
hummerichsander Jun 19, 2024
79b5adb
async db queries in tasks and some minor frontend fixes
hummerichsander Jun 27, 2024
8f32795
handle not terminated result badges
hummerichsander Jun 27, 2024
3c2baf2
remove triple register task
hummerichsander Jun 27, 2024
dc1f30a
settings descriptions
hummerichsander Jun 27, 2024
92797f9
add comments for batching settings
hummerichsander Jun 30, 2024
7fafa07
added return type hints @ RagResultsView
hummerichsander Jun 30, 2024
6ff6d48
remove pagination attributes from RagResultsListView
hummerichsander Jun 30, 2024
6fc0c91
Fix pagination in RagResultListView via `RelatedPaginationMixin`
hummerichsander Jul 1, 2024
722b3c0
Add: basic test for ProcessRagTask and AsyncChatClient
hummerichsander Jul 8, 2024
1fbcc8f
`close_old_connections` in `process_rag_task`
hummerichsander Jul 9, 2024
001fa95
Set defaults for RAG_TASK_BATCH_SIZE and RAG_LLM_CONCURRENCY_LIMIT to 1
hummerichsander Jul 9, 2024
2435e44
remove `ChatClient` and use `AsyncChatClient` in `report_chat_view`
hummerichsander Jul 16, 2024
653ff1f
remove `self._client` from `ProcessRagTask`
hummerichsander Jul 16, 2024
1854658
add `ModalitiesFactory` and apply formatting and linting
hummerichsander Jul 16, 2024
91460c2
set default values: `RAG_TASK_BATCH_SIZE = 64` and `RAG_LLM_CONCURRENC…
hummerichsander Jul 16, 2024
915c132
Rename `RagReportInstance` to `RagInstance` and replace ForeignKey f…
hummerichsander Jul 18, 2024
52efbb6
Merge remote-tracking branch 'origin/main' into rag-task-batching
hummerichsander Jul 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions radis/core/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Protocol

from adit_radis_shared.common.mixins import ViewProtocol
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
from django.db.models.query import QuerySet
from django.http import HttpRequest


# TODO: Move this to adit_radis_shared package. PR: https://github.com/openradx/adit-radis-shared/pull/5
class RelatedPaginationMixinProtocol(ViewProtocol, Protocol):
    """Structural type of the view that `RelatedPaginationMixin` expects as `self`.

    It declares the attributes and methods the mixin relies on so that a type
    checker can verify them on the concrete view class.
    """

    # Current request; the mixin reads the "page" GET parameter from it.
    request: HttpRequest
    object_list: QuerySet
    # Page size that the mixin hands to the paginator.
    paginate_by: int

    def get_object(self) -> Any: ...

    def get_context_data(self, **kwargs) -> dict[str, Any]: ...

    def get_related_queryset(self) -> QuerySet: ...


class RelatedPaginationMixin:
    """This mixin provides pagination for a related queryset. This makes it possible to
    paginate a related queryset in a DetailView. The related queryset is obtained by
    the `get_related_queryset()` method that must be implemented by the subclass.
    If used in combination with `RelatedFilterMixin`, the `RelatedPaginationMixin` must be
    inherited first."""

    def get_related_queryset(self: RelatedPaginationMixinProtocol) -> QuerySet:
        """Return the related queryset to paginate.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("You must implement this method")

    def get_context_data(self: RelatedPaginationMixinProtocol, **kwargs):
        """Extend the view context with one page of the related queryset.

        The page number is read from the `page` GET parameter. A missing or
        non-integer value selects the first page; an out-of-range number selects
        the last page.

        Returns:
            The context of the parent class with `object_list`, `paginator`,
            `is_paginated` and `page_obj` set for the selected page.
        """
        context = super().get_context_data(**kwargs)

        # A class further down the MRO may already have stored an `object_list`
        # (presumably the filtered queryset of `RelatedFilterMixin` - confirm);
        # paginate that one instead of the raw related queryset.
        if "object_list" in context:
            queryset = context["object_list"]
        else:
            queryset = self.get_related_queryset()

        paginator = Paginator(queryset, self.paginate_by)
        # `get_page` never raises for bad input: it falls back to page 1 for a
        # missing/non-integer number and to the last page for an out-of-range one.
        page_obj = paginator.get_page(self.request.GET.get("page"))

        context["object_list"] = page_obj
        context["paginator"] = paginator
        context["is_paginated"] = page_obj.has_other_pages()
        context["page_obj"] = page_obj

        return context
Empty file.
32 changes: 32 additions & 0 deletions radis/core/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio
from typing import Callable, ContextManager
from unittest.mock import AsyncMock, MagicMock

import pytest
from faker import Faker


@pytest.fixture
def report_body() -> str:
    """Provide a fake report text: 40 random sentences joined by single spaces."""
    return " ".join(Faker().sentences(nb=40))


@pytest.fixture
def question_body() -> str:
    """Provide a fake question text consisting of one random sentence."""
    return " ".join(Faker().sentences(nb=1))


@pytest.fixture
def openai_chat_completions_mock() -> Callable[[str], MagicMock]:
    """Provide a factory that builds a fake OpenAI client.

    The returned callable takes the text the fake LLM should answer with and
    returns a `MagicMock` whose `chat.completions.create(...)` can be awaited
    and resolves to a response carrying that text in
    `choices[0].message.content`.
    """

    def _openai_chat_completions_mock(content: str) -> MagicMock:
        mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))])

        mock_openai = MagicMock()
        # AsyncMock hands out a fresh awaitable on every call and, unlike a
        # pre-resolved `asyncio.Future`, is not bound to an event loop when the
        # mock is built. Call tracking (`call_count`) works as for any mock.
        mock_openai.chat.completions.create = AsyncMock(return_value=mock_response)

        return mock_openai

    return _openai_chat_completions_mock
34 changes: 34 additions & 0 deletions radis/core/tests/unit/test_chat_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest.mock import patch

import pytest

from radis.core.utils.chat_client import AsyncChatClient


@pytest.mark.asyncio
async def test_ask_question(report_body, question_body, openai_chat_completions_mock):
    """`ask_question` returns the mocked LLM reply and hits the API exactly once."""
    fake_client = openai_chat_completions_mock("Fake Answer")

    # AsyncChatClient builds its client via `openai.AsyncOpenAI`, so patching
    # that constructor injects the fake client.
    with patch("openai.AsyncOpenAI", return_value=fake_client):
        chat_client = AsyncChatClient()
        answer = await chat_client.ask_question(report_body, "en", question_body)

    assert answer == "Fake Answer"
    assert fake_client.chat.completions.create.call_count == 1


@pytest.mark.asyncio
async def test_ask_yes_no_question(report_body, question_body, openai_chat_completions_mock):
    """`ask_yes_no_question` maps the LLM replies "Yes"/"No" to "yes"/"no".

    Each case uses its own fake client and must trigger exactly one API call.
    """
    for llm_reply, expected in (("Yes", "yes"), ("No", "no")):
        fake_client = openai_chat_completions_mock(llm_reply)

        with patch("openai.AsyncOpenAI", return_value=fake_client):
            chat_client = AsyncChatClient()
            answer = await chat_client.ask_yes_no_question(report_body, "en", question_body)

        assert answer == expected
        assert fake_client.chat.completions.create.call_count == 1
56 changes: 54 additions & 2 deletions radis/core/utils/chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from string import Template
from typing import Literal

import openai
from django.conf import settings
from openai import OpenAI

logger = logging.getLogger(__name__)

Expand All @@ -15,7 +15,7 @@

class ChatClient:
hummerichsander marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self):
self._client = OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")
self._client = openai.OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

def ask_question(
self, report_body: str, language: str, question: str, grammar: str | None = None
Expand Down Expand Up @@ -63,3 +63,55 @@ def ask_yes_no_question(
return "no"
else:
raise ValueError(f"Unexpected answer: {llm_answer}")


class AsyncChatClient:
    """Asynchronous client for asking an LLM questions about a report.

    Requests are sent with `openai.AsyncOpenAI` pointed at
    `{settings.LLAMACPP_URL}/v1`.
    """

    def __init__(self):
        self._client = openai.AsyncOpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

    async def ask_question(
        self, report_body: str, language: str, question: str, grammar: str | None = None
    ) -> str:
        """Ask the LLM a question about a report and return its answer.

        Args:
            report_body: Report text inserted into the user prompt template.
            language: Key used to look up the system and user prompts in
                `settings.CHAT_SYSTEM_PROMPT` / `settings.CHAT_USER_PROMPT`.
            question: Question inserted into the user prompt template.
            grammar: Optional grammar forwarded to the server under the
                `grammar` key of the request's extra body.

        Returns:
            The content of the first completion choice.

        Raises:
            ValueError: If the completion carries no content.
        """
        system = settings.CHAT_SYSTEM_PROMPT[language]
        user_prompt = Template(settings.CHAT_USER_PROMPT[language]).substitute(
            {"report": report_body, "question": question}
        )

        log_msg = f"Sending to LLM:\n[System] {system}\n[User] {user_prompt}"
        if grammar:
            log_msg += f"\n[Grammar] {grammar}"
        logger.debug(log_msg)

        completion = await self._client.chat.completions.create(
            model="none",
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user_prompt},
            ],
            extra_body={"grammar": grammar},
        )

        answer = completion.choices[0].message.content
        if answer is None:
            # Explicit check instead of `assert`: assertions are stripped when
            # Python runs with -O, which would let `None` escape as the answer.
            raise ValueError("LLM returned no answer content")
        logger.debug("Received from LLM: %s", answer)

        return answer

    async def ask_yes_no_question(
        self, report_body: str, language: str, question: str
    ) -> Literal["yes", "no"]:
        """Ask the LLM a yes/no question about a report.

        The answer is constrained by a grammar built from `YES_NO_GRAMMAR` and
        the yes/no words configured for `language` in
        `settings.CHAT_ANSWER_YES` / `settings.CHAT_ANSWER_NO`.

        Returns:
            "yes" or "no", depending on which configured word the LLM answered.

        Raises:
            ValueError: If the LLM answer matches neither configured word.
        """
        answer_yes = settings.CHAT_ANSWER_YES[language]
        answer_no = settings.CHAT_ANSWER_NO[language]
        grammar = Template(YES_NO_GRAMMAR).substitute({"yes": answer_yes, "no": answer_no})

        llm_answer = await self.ask_question(report_body, language, question, grammar)

        if llm_answer == answer_yes:
            return "yes"
        if llm_answer == answer_no:
            return "no"
        raise ValueError(f"Unexpected answer: {llm_answer}")
23 changes: 18 additions & 5 deletions radis/rag/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django import forms
from django.contrib import admin

from .models import Question, QuestionResult, RagJob, RagTask
from .models import Question, QuestionResult, RagJob, RagReportInstance, RagTask


class QuestionInline(admin.StackedInline):
Expand All @@ -17,11 +17,24 @@ class RagJobAdmin(admin.ModelAdmin):
admin.site.register(RagJob, RagJobAdmin)


class RagReportInstanceInline(admin.StackedInline):
    """Stacked inline for editing `RagReportInstance` rows on a parent admin page."""

    model = RagReportInstance
    # Show one additional empty form.
    extra = 1
    # List the inline rows by ascending id.
    ordering = ("id",)


class RagTaskAdmin(admin.ModelAdmin):
    """Admin for `RagTask` that embeds report instances via `RagReportInstanceInline`."""

    inlines = [RagReportInstanceInline]


admin.site.register(RagTask, RagTaskAdmin)


class QuestionResultInlineFormset(forms.BaseInlineFormSet):
def add_fields(self, form: forms.Form, index: int) -> None:
super().add_fields(form, index)
task: RagTask = self.instance
form.fields["question"].queryset = task.job.questions.all()
report_instance = self.instance
form.fields["question"].queryset = report_instance.task.job.questions.all()


class QuestionResultInline(admin.StackedInline):
Expand All @@ -31,8 +44,8 @@ class QuestionResultInline(admin.StackedInline):
formset = QuestionResultInlineFormset


class RagTaskAdmin(admin.ModelAdmin):
class RagReportInstanceAdmin(admin.ModelAdmin):
inlines = [QuestionResultInline]


admin.site.register(RagTask, RagTaskAdmin)
admin.site.register(RagReportInstance, RagReportInstanceAdmin)
62 changes: 62 additions & 0 deletions radis/rag/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Generic, TypeVar

import factory
from faker import Faker

from .models import Answer, Question, RagJob, RagReportInstance, RagTask

# Type variable for the return type of `BaseDjangoModelFactory.create()`.
T = TypeVar("T")

# Module-level Faker instance (not referenced by the code visible in this view).
fake = Faker()


class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory):
    """Generic `DjangoModelFactory` base whose `create()` is annotated to return `T`.

    Subclasses parametrize it with their model class, e.g.
    `BaseDjangoModelFactory[Question]`.
    """

    @classmethod
    def create(cls, *args, **kwargs) -> T:
        # Pure pass-through to the parent; it only adds the `-> T` annotation.
        return super().create(*args, **kwargs)


# Values that `RagJobFactory` picks from at random for `provider`.
SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch")
# Values that `RagJobFactory` picks from at random for `patient_sex`.
# NOTE(review): "" presumably stands for "not specified" - confirm against the model.
PatientSexes = ["", "M", "F"]


class RagJobFactory(BaseDjangoModelFactory[RagJob]):
    """Factory for `RagJob` instances filled with random fake data.

    Parametrized with `RagJob`, like the other factories in this module, so
    that `create()` is annotated to return a `RagJob`.
    """

    class Meta:
        model = RagJob

    title = factory.Faker("sentence", nb_words=3)
    provider = factory.Faker("random_element", elements=SearchProviders)
    group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory")
    query = factory.Faker("word")
    language = factory.SubFactory("radis.reports.factories.LanguageFactory")
    # TODO: handle modalities
    # NOTE(review): each "from"/"till" pair below is drawn independently, so a
    # generated job may have from > till - confirm that no test relies on
    # ordered ranges.
    study_date_from = factory.Faker("date")
    study_date_till = factory.Faker("date")
    study_description = factory.Faker("sentence", nb_words=5)
    patient_sex = factory.Faker("random_element", elements=PatientSexes)
    age_from = factory.Faker("random_int", min=0, max=100)
    age_till = factory.Faker("random_int", min=0, max=100)


class QuestionFactory(BaseDjangoModelFactory[Question]):
    """Factory for `Question` instances.

    Unless a job is passed in, one is created through `RagJobFactory`. The
    accepted answer is picked at random from the values of `Answer.choices`.
    """

    class Meta:
        model = Question

    job = factory.SubFactory("radis.rag.factories.RagJobFactory")
    question = factory.Faker("sentence", nb_words=10)
    accepted_answer = factory.Faker(
        "random_element",
        elements=[value for value, _label in Answer.choices],
    )


class RagTaskFactory(BaseDjangoModelFactory[RagTask]):
    """Factory for `RagTask` instances."""

    class Meta:
        model = RagTask

    # Unless a job is passed in, one is created through `RagJobFactory`.
    job = factory.SubFactory("radis.rag.factories.RagJobFactory")


class RagReportInstanceFactory(BaseDjangoModelFactory[RagReportInstance]):
    """Factory for `RagReportInstance` instances."""

    class Meta:
        model = RagReportInstance

    # Unless passed in, the task and the report are created through their
    # own factories.
    task = factory.SubFactory("radis.rag.factories.RagTaskFactory")
    report = factory.SubFactory("radis.reports.factories.ReportFactory")
4 changes: 2 additions & 2 deletions radis/rag/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from radis.core.filters import AnalysisJobFilter, AnalysisTaskFilter

from .models import RagJob, RagTask
from .models import RagJob, RagReportInstance, RagTask


class RagJobFilter(AnalysisJobFilter):
Expand All @@ -21,7 +21,7 @@ class RagResultFilter(django_filters.FilterSet):
request: HttpRequest

class Meta:
model = RagTask
model = RagReportInstance
fields = ("overall_result",)

def __init__(self, *args, **kwargs):
Expand Down
Loading
Loading