Skip to content

Commit

Permalink
Merge pull request #8 from openradx/rag-task-batching
Browse files Browse the repository at this point in the history
Rag task batching
  • Loading branch information
medihack committed Jul 20, 2024
2 parents 8389f10 + 52efbb6 commit b1fd5a2
Show file tree
Hide file tree
Showing 29 changed files with 700 additions and 150 deletions.
4 changes: 2 additions & 2 deletions compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ services:
llamacpp_cpu:
<<: *llamacpp
image: ghcr.io/ggerganov/llama.cpp:server
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -c 512 --host 0.0.0.0 --port 8080'"
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -c 4096 --host 0.0.0.0 --port 8080 --threads 8 --threads-http 8 --parallel 8 --cont-batching'"
profiles: ["cpu"]

llamacpp_gpu:
<<: *llamacpp
image: ghcr.io/ggerganov/llama.cpp:server-cuda
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -ngl 99 -c 4096 --host 0.0.0.0 --port 8080'"
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -ngl 99 -c 4096 --host 0.0.0.0 --port 8080 --threads 8 --threads-http 8 --parallel 8 --cont-batching'"
deploy:
resources:
reservations:
Expand Down
58 changes: 58 additions & 0 deletions radis/core/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Protocol

from adit_radis_shared.common.mixins import ViewProtocol
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
from django.db.models.query import QuerySet
from django.http import HttpRequest


# TODO: Move this to adit_radis_shared package. PR: https://github.com/openradx/adit-radis-shared/pull/5
class RelatedPaginationMixinProtocol(ViewProtocol, Protocol):
request: HttpRequest
object_list: QuerySet
paginate_by: int

def get_object(self) -> Any: ...

def get_context_data(self, **kwargs) -> dict[str, Any]: ...

def get_related_queryset(self) -> QuerySet: ...


class RelatedPaginationMixin:
"""This mixin provides pagination for a related queryset. This makes it possible to
paginate a related queryset in a DetailView. The related queryset is obtained by
the `get_related_queryset()` method that must be implemented by the subclass.
If used in combination with `RelatedFilterMixin`, the `RelatedPaginationMixin` must be
inherited first."""

def get_related_queryset(self: RelatedPaginationMixinProtocol) -> QuerySet:
raise NotImplementedError("You must implement this method")

def get_context_data(self: RelatedPaginationMixinProtocol, **kwargs):
context = super().get_context_data(**kwargs)

if "object_list" in context:
queryset = context["object_list"]
else:
queryset = self.get_related_queryset()

paginator = Paginator(queryset, self.paginate_by)
page = self.request.GET.get("page")

if page is None:
page = 1

try:
paginated_queryset = paginator.page(page)
except PageNotAnInteger:
paginated_queryset = paginator.page(1)
except EmptyPage:
paginated_queryset = paginator.page(paginator.num_pages)

context["object_list"] = paginated_queryset
context["paginator"] = paginator
context["is_paginated"] = paginated_queryset.has_other_pages()
context["page_obj"] = paginated_queryset

return context
Empty file.
32 changes: 32 additions & 0 deletions radis/core/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio
from typing import Callable, ContextManager
from unittest.mock import MagicMock

import pytest
from faker import Faker


@pytest.fixture
def report_body() -> str:
report_body = Faker().sentences(nb=40)
return " ".join(report_body)


@pytest.fixture
def question_body() -> str:
question_body = Faker().sentences(nb=1)
return " ".join(question_body)


@pytest.fixture
def openai_chat_completions_mock() -> Callable[[str], ContextManager]:
def _openai_chat_completions_mock(content: str) -> ContextManager:
mock_openai = MagicMock()
mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))])
future = asyncio.Future()
future.set_result(mock_response)
mock_openai.chat.completions.create.return_value = future

return mock_openai

return _openai_chat_completions_mock
34 changes: 34 additions & 0 deletions radis/core/tests/unit/test_chat_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest.mock import patch

import pytest

from radis.core.utils.chat_client import AsyncChatClient


@pytest.mark.asyncio
async def test_ask_question(report_body, question_body, openai_chat_completions_mock):
openai_mock = openai_chat_completions_mock("Fake Answer")

with patch("openai.AsyncOpenAI", return_value=openai_mock):
answer = await AsyncChatClient().ask_question(report_body, "en", question_body)

assert answer == "Fake Answer"
assert openai_mock.chat.completions.create.call_count == 1


@pytest.mark.asyncio
async def test_ask_yes_no_question(report_body, question_body, openai_chat_completions_mock):
openai_yes_mock = openai_chat_completions_mock("Yes")
openai_no_mock = openai_chat_completions_mock("No")

with patch("openai.AsyncOpenAI", return_value=openai_yes_mock):
answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body)

assert answer == "yes"
assert openai_yes_mock.chat.completions.create.call_count == 1

with patch("openai.AsyncOpenAI", return_value=openai_no_mock):
answer = await AsyncChatClient().ask_yes_no_question(report_body, "en", question_body)

assert answer == "no"
assert openai_no_mock.chat.completions.create.call_count == 1
20 changes: 10 additions & 10 deletions radis/core/utils/chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from string import Template
from typing import Literal

import openai
from django.conf import settings
from openai import OpenAI

logger = logging.getLogger(__name__)

Expand All @@ -13,27 +13,27 @@
"""


class ChatClient:
class AsyncChatClient:
def __init__(self):
self._client = OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")
self._client = openai.AsyncOpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

def ask_question(
async def ask_question(
self, report_body: str, language: str, question: str, grammar: str | None = None
) -> str:
system_prompt = settings.CHAT_SYSTEM_PROMPT[language]
system = settings.CHAT_SYSTEM_PROMPT[language]
user_prompt = Template(settings.CHAT_USER_PROMPT[language]).substitute(
{"report": report_body, "question": question}
)

log_msg = f"Sending to LLM:\n[System] {system_prompt}\n[User] {user_prompt}"
log_msg = f"Sending to LLM:\n[System] {system}\n[User] {user_prompt}"
if grammar:
log_msg += f"\n[Grammar] {grammar}"
logger.debug(log_msg)

completion = self._client.chat.completions.create(
completion = await self._client.chat.completions.create(
model="none",
messages=[
{"role": "system", "content": system_prompt},
{"role": "system", "content": system},
{"role": "user", "content": user_prompt},
],
extra_body={"grammar": grammar},
Expand All @@ -45,7 +45,7 @@ def ask_question(

return answer

def ask_yes_no_question(
async def ask_yes_no_question(
self, report_body: str, language: str, question: str
) -> Literal["yes", "no"]:
grammar = Template(YES_NO_GRAMMAR).substitute(
Expand All @@ -55,7 +55,7 @@ def ask_yes_no_question(
}
)

llm_answer = self.ask_question(report_body, language, question, grammar)
llm_answer = await self.ask_question(report_body, language, question, grammar)

if llm_answer == settings.CHAT_ANSWER_YES[language]:
return "yes"
Expand Down
23 changes: 18 additions & 5 deletions radis/rag/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django import forms
from django.contrib import admin

from .models import Question, QuestionResult, RagJob, RagTask
from .models import Question, QuestionResult, RagInstance, RagJob, RagTask


class QuestionInline(admin.StackedInline):
Expand All @@ -17,11 +17,24 @@ class RagJobAdmin(admin.ModelAdmin):
admin.site.register(RagJob, RagJobAdmin)


class RagInstanceInline(admin.StackedInline):
model = RagInstance
extra = 1
ordering = ("id",)


class RagTaskAdmin(admin.ModelAdmin):
inlines = [RagInstanceInline]


admin.site.register(RagTask, RagTaskAdmin)


class QuestionResultInlineFormset(forms.BaseInlineFormSet):
def add_fields(self, form: forms.Form, index: int) -> None:
super().add_fields(form, index)
task: RagTask = self.instance
form.fields["question"].queryset = task.job.questions.all()
rag_instance = self.instance
form.fields["question"].queryset = rag_instance.task.job.questions.all()


class QuestionResultInline(admin.StackedInline):
Expand All @@ -31,8 +44,8 @@ class QuestionResultInline(admin.StackedInline):
formset = QuestionResultInlineFormset


class RagTaskAdmin(admin.ModelAdmin):
class RagInstanceAdmin(admin.ModelAdmin):
inlines = [QuestionResultInline]


admin.site.register(RagTask, RagTaskAdmin)
admin.site.register(RagInstance, RagInstanceAdmin)
91 changes: 91 additions & 0 deletions radis/rag/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Generic, TypeVar, cast

import factory
from faker import Faker

from radis.reports.factories import ModalityFactory

from .models import Answer, Question, RagInstance, RagJob, RagTask

T = TypeVar("T")

fake = Faker()


class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory):
@classmethod
def create(cls, *args, **kwargs) -> T:
return super().create(*args, **kwargs)


SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch")
PatientSexes = ["", "M", "F"]


class RagJobFactory(BaseDjangoModelFactory):
class Meta:
model = RagJob

title = factory.Faker("sentence", nb_words=3)
provider = factory.Faker("random_element", elements=SearchProviders)
group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory")
query = factory.Faker("word")
language = factory.SubFactory("radis.reports.factories.LanguageFactory")
study_date_from = factory.Faker("date")
study_date_till = factory.Faker("date")
study_description = factory.Faker("sentence", nb_words=5)
patient_sex = factory.Faker("random_element", elements=PatientSexes)
age_from = factory.Faker("random_int", min=0, max=100)
age_till = factory.Faker("random_int", min=0, max=100)

@factory.post_generation
def modalities(self, create, extracted, **kwargs):
if not create:
return

self = cast(RagJob, self)

if extracted:
for modality in extracted:
self.modalities.add(modality)
else:
modality = ModalityFactory()
self.modalities.add(modality)


class QuestionFactory(BaseDjangoModelFactory[Question]):
class Meta:
model = Question

job = factory.SubFactory("radis.rag.factories.RagJobFactory")
question = factory.Faker("sentence", nb_words=10)
accepted_answer = factory.Faker("random_element", elements=[a[0] for a in Answer.choices])


class RagTaskFactory(BaseDjangoModelFactory[RagTask]):
class Meta:
model = RagTask

job = factory.SubFactory("radis.rag.factories.RagJobFactory")


class RagInstanceFactory(BaseDjangoModelFactory[RagInstance]):
class Meta:
model = RagInstance

task = factory.SubFactory("radis.rag.factories.RagTaskFactory")

@factory.post_generation
def reports(self, create, extracted, **kwargs):
if not create:
return

self = cast(RagInstance, self)

if extracted:
for report in extracted:
self.reports.add(report)
else:
from radis.reports.factories import ReportFactory

self.reports.add(*[ReportFactory() for _ in range(3)])
4 changes: 2 additions & 2 deletions radis/rag/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from radis.core.filters import AnalysisJobFilter, AnalysisTaskFilter

from .models import RagJob, RagTask
from .models import RagInstance, RagJob, RagTask


class RagJobFilter(AnalysisJobFilter):
Expand All @@ -21,7 +21,7 @@ class RagResultFilter(django_filters.FilterSet):
request: HttpRequest

class Meta:
model = RagTask
model = RagInstance
fields = ("overall_result",)

def __init__(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit b1fd5a2

Please sign in to comment.