Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rag task batching #8

Merged
merged 18 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
66f873b
Enable batching in RagTasks and concurrent requests to llm
hummerichsander Jun 19, 2024
79b5adb
async db queries in tasks and some minor frontend fixes
hummerichsander Jun 27, 2024
8f32795
handle not terminated result badges
hummerichsander Jun 27, 2024
3c2baf2
remove triple register task
hummerichsander Jun 27, 2024
dc1f30a
settings descriptions
hummerichsander Jun 27, 2024
92797f9
add comments for batching settings
hummerichsander Jun 30, 2024
7fafa07
added return type hints @ RagResultsView
hummerichsander Jun 30, 2024
6ff6d48
remove pagination attributes from RagResultsListView
hummerichsander Jun 30, 2024
6fc0c91
Fix pagination in RagResultListView via `RelatedPaginationMixin`
hummerichsander Jul 1, 2024
722b3c0
Add: basic test for ProcessRagTask and AsyncChatClient
hummerichsander Jul 8, 2024
1fbcc8f
`close_old_connections` in `process_rag_task`
hummerichsander Jul 9, 2024
001fa95
Set defaults for RAG_TASK_BATCH_SIZE and RAG_LLM_CONCURRENCY_LIMIT to 1
hummerichsander Jul 9, 2024
2435e44
remove `ChatClient` and use `AsyncChatClient` in `report_chat_view`
hummerichsander Jul 16, 2024
653ff1f
remove `self._client` from `ProcessRagTask`
hummerichsander Jul 16, 2024
1854658
add `ModalitiesFactory` and apply formatting and linting
hummerichsander Jul 16, 2024
91460c2
set default values: `RAG_TASK_BATCH_SIZE = 64` and `RAG_LLM_CONCURRENC…
hummerichsander Jul 16, 2024
915c132
Rename `RagReportInstance` to `RagInstance` and replace ForeignKey f…
hummerichsander Jul 18, 2024
52efbb6
Merge remote-tracking branch 'origin/main' into rag-task-batching
hummerichsander Jul 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ services:
llamacpp_cpu:
<<: *llamacpp
image: ghcr.io/ggerganov/llama.cpp:server
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -c 512 --host 0.0.0.0 --port 8080'"
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -c 4096 --host 0.0.0.0 --port 8080 --threads 8 --threads-http 8 --parallel 8 --cont-batching'"
profiles: ["cpu"]

llamacpp_gpu:
<<: *llamacpp
image: ghcr.io/ggerganov/llama.cpp:server-cuda
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -ngl 99 -c 4096 --host 0.0.0.0 --port 8080'"
entrypoint: "/bin/bash -c '/llama-server -mu $${LLM_MODEL_URL} -ngl 99 -c 4096 --host 0.0.0.0 --port 8080 --threads 8 --threads-http 8 --parallel 8 --cont-batching'"
deploy:
resources:
reservations:
Expand Down
58 changes: 58 additions & 0 deletions radis/core/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Protocol

from adit_radis_shared.common.mixins import ViewProtocol
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
from django.db.models.query import QuerySet
from django.http import HttpRequest


# TODO: Move this to adit_radis_shared package. PR: https://github.com/openradx/adit-radis-shared/pull/5
class RelatedPaginationMixinProtocol(ViewProtocol, Protocol):
    """Structural type for views that use ``RelatedPaginationMixin``.

    Declares exactly what the mixin reads from the host view: the current
    request (for the ``page`` GET parameter), the page size, the standard
    DetailView hooks, and the subclass-provided related queryset.
    """

    request: HttpRequest
    object_list: QuerySet
    # Page size handed to the Paginator.
    paginate_by: int

    def get_object(self) -> Any: ...

    def get_context_data(self, **kwargs) -> dict[str, Any]: ...

    def get_related_queryset(self) -> QuerySet: ...


class RelatedPaginationMixin:
    """Provide pagination for a related queryset inside a DetailView.

    The related queryset is obtained from :meth:`get_related_queryset`, which
    must be implemented by the subclass. If used in combination with
    ``RelatedFilterMixin``, ``RelatedPaginationMixin`` must be inherited first.
    """

    def get_related_queryset(self: RelatedPaginationMixinProtocol) -> QuerySet:
        """Return the related queryset to paginate; subclasses must override."""
        raise NotImplementedError("You must implement this method")

    def get_context_data(self: RelatedPaginationMixinProtocol, **kwargs) -> dict[str, Any]:
        """Paginate the related queryset and add the standard pagination context keys."""
        context = super().get_context_data(**kwargs)

        # A cooperating mixin (e.g. RelatedFilterMixin) may already have put a
        # filtered queryset into the context; prefer it over the raw one.
        if "object_list" in context:
            queryset = context["object_list"]
        else:
            queryset = self.get_related_queryset()

        paginator = Paginator(queryset, self.paginate_by)
        # Paginator.get_page() replaces the manual try/except handling: a
        # missing or non-integer ``page`` falls back to page 1, an
        # out-of-range page to the last page -- the exact previous behavior.
        paginated_queryset = paginator.get_page(self.request.GET.get("page"))

        context["object_list"] = paginated_queryset
        context["paginator"] = paginator
        context["is_paginated"] = paginated_queryset.has_other_pages()
        context["page_obj"] = paginated_queryset

        return context
Empty file.
32 changes: 32 additions & 0 deletions radis/core/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio
from typing import Callable, ContextManager
from unittest.mock import AsyncMock, MagicMock

import pytest
from faker import Faker


@pytest.fixture
def report_body() -> str:
    """Fake report text made of 40 sentences joined into one string."""
    sentences = Faker().sentences(nb=40)
    return " ".join(sentences)


@pytest.fixture
def question_body() -> str:
    """A single fake sentence to use as a question."""
    return " ".join(Faker().sentences(nb=1))


@pytest.fixture
def openai_chat_completions_mock() -> Callable[[str], ContextManager]:
def _openai_chat_completions_mock(content: str) -> ContextManager:
mock_openai = MagicMock()
mock_response = MagicMock(choices=[MagicMock(message=MagicMock(content=content))])
future = asyncio.Future()
future.set_result(mock_response)
mock_openai.chat.completions.create.return_value = future

return mock_openai

return _openai_chat_completions_mock
34 changes: 34 additions & 0 deletions radis/core/tests/unit/test_chat_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest.mock import patch

import pytest

from radis.core.utils.chat_client import AsyncChatClient


@pytest.mark.asyncio
async def test_ask_question(report_body, question_body, openai_chat_completions_mock):
    """ask_question returns the LLM answer verbatim after exactly one API call."""
    client_mock = openai_chat_completions_mock("Fake Answer")

    with patch("openai.AsyncOpenAI", return_value=client_mock):
        chat_client = AsyncChatClient()
        answer = await chat_client.ask_question(report_body, "en", question_body)

    assert answer == "Fake Answer"
    assert client_mock.chat.completions.create.call_count == 1


@pytest.mark.asyncio
async def test_ask_yes_no_question(report_body, question_body, openai_chat_completions_mock):
    """ask_yes_no_question normalizes the LLM's 'Yes'/'No' to 'yes'/'no'."""
    for llm_reply, expected in (("Yes", "yes"), ("No", "no")):
        client_mock = openai_chat_completions_mock(llm_reply)

        with patch("openai.AsyncOpenAI", return_value=client_mock):
            answer = await AsyncChatClient().ask_yes_no_question(
                report_body, "en", question_body
            )

        assert answer == expected
        assert client_mock.chat.completions.create.call_count == 1
20 changes: 10 additions & 10 deletions radis/core/utils/chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from string import Template
from typing import Literal

import openai
from django.conf import settings
from openai import OpenAI

logger = logging.getLogger(__name__)

Expand All @@ -13,27 +13,27 @@
"""


class ChatClient:
class AsyncChatClient:
def __init__(self):
self._client = OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")
self._client = openai.AsyncOpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

def ask_question(
async def ask_question(
self, report_body: str, language: str, question: str, grammar: str | None = None
) -> str:
system_prompt = settings.CHAT_SYSTEM_PROMPT[language]
system = settings.CHAT_SYSTEM_PROMPT[language]
user_prompt = Template(settings.CHAT_USER_PROMPT[language]).substitute(
{"report": report_body, "question": question}
)

log_msg = f"Sending to LLM:\n[System] {system_prompt}\n[User] {user_prompt}"
log_msg = f"Sending to LLM:\n[System] {system}\n[User] {user_prompt}"
if grammar:
log_msg += f"\n[Grammar] {grammar}"
logger.debug(log_msg)

completion = self._client.chat.completions.create(
completion = await self._client.chat.completions.create(
model="none",
messages=[
{"role": "system", "content": system_prompt},
{"role": "system", "content": system},
{"role": "user", "content": user_prompt},
],
extra_body={"grammar": grammar},
Expand All @@ -45,7 +45,7 @@ def ask_question(

return answer

def ask_yes_no_question(
async def ask_yes_no_question(
self, report_body: str, language: str, question: str
) -> Literal["yes", "no"]:
grammar = Template(YES_NO_GRAMMAR).substitute(
Expand All @@ -55,7 +55,7 @@ def ask_yes_no_question(
}
)

llm_answer = self.ask_question(report_body, language, question, grammar)
llm_answer = await self.ask_question(report_body, language, question, grammar)

if llm_answer == settings.CHAT_ANSWER_YES[language]:
return "yes"
Expand Down
23 changes: 18 additions & 5 deletions radis/rag/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django import forms
from django.contrib import admin

from .models import Question, QuestionResult, RagJob, RagTask
from .models import Question, QuestionResult, RagInstance, RagJob, RagTask


class QuestionInline(admin.StackedInline):
Expand All @@ -17,11 +17,24 @@ class RagJobAdmin(admin.ModelAdmin):
admin.site.register(RagJob, RagJobAdmin)


class RagInstanceInline(admin.StackedInline):
model = RagInstance
extra = 1
ordering = ("id",)


class RagTaskAdmin(admin.ModelAdmin):
inlines = [RagInstanceInline]


admin.site.register(RagTask, RagTaskAdmin)


class QuestionResultInlineFormset(forms.BaseInlineFormSet):
def add_fields(self, form: forms.Form, index: int) -> None:
super().add_fields(form, index)
task: RagTask = self.instance
form.fields["question"].queryset = task.job.questions.all()
rag_instance = self.instance
form.fields["question"].queryset = rag_instance.task.job.questions.all()


class QuestionResultInline(admin.StackedInline):
Expand All @@ -31,8 +44,8 @@ class QuestionResultInline(admin.StackedInline):
formset = QuestionResultInlineFormset


class RagTaskAdmin(admin.ModelAdmin):
class RagInstanceAdmin(admin.ModelAdmin):
inlines = [QuestionResultInline]


admin.site.register(RagTask, RagTaskAdmin)
admin.site.register(RagInstance, RagInstanceAdmin)
91 changes: 91 additions & 0 deletions radis/rag/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Generic, TypeVar, cast

import factory
from faker import Faker

from radis.reports.factories import ModalityFactory

from .models import Answer, Question, RagInstance, RagJob, RagTask

T = TypeVar("T")

fake = Faker()


class BaseDjangoModelFactory(Generic[T], factory.django.DjangoModelFactory):
    """Typed DjangoModelFactory: subclasses parametrized as
    ``BaseDjangoModelFactory[Model]`` get ``create()`` returning ``Model``
    instead of ``Any``."""

    @classmethod
    def create(cls, *args, **kwargs) -> T:
        return super().create(*args, **kwargs)


SearchProviders = ("OpenSearch", "Vespa", "Elasticsearch")
PatientSexes = ["", "M", "F"]


class RagJobFactory(BaseDjangoModelFactory[RagJob]):
    """Factory producing RagJob instances with randomized search criteria.

    Parametrized with ``[RagJob]`` for a typed ``create()``, consistent with
    the other factories in this module (QuestionFactory, RagTaskFactory, ...).
    """

    class Meta:
        model = RagJob

    title = factory.Faker("sentence", nb_words=3)
    provider = factory.Faker("random_element", elements=SearchProviders)
    group = factory.SubFactory("adit_radis_shared.accounts.factories.GroupFactory")
    query = factory.Faker("word")
    language = factory.SubFactory("radis.reports.factories.LanguageFactory")
    study_date_from = factory.Faker("date")
    study_date_till = factory.Faker("date")
    study_description = factory.Faker("sentence", nb_words=5)
    patient_sex = factory.Faker("random_element", elements=PatientSexes)
    age_from = factory.Faker("random_int", min=0, max=100)
    age_till = factory.Faker("random_int", min=0, max=100)

    @factory.post_generation
    def modalities(self, create, extracted, **kwargs):
        """Attach modalities passed as ``RagJobFactory(modalities=[...])``,
        or a single fresh modality when none are given."""
        if not create:
            # build() strategy: no database, so M2M links cannot be created.
            return

        self = cast(RagJob, self)

        if extracted:
            for modality in extracted:
                self.modalities.add(modality)
        else:
            modality = ModalityFactory()
            self.modalities.add(modality)


class QuestionFactory(BaseDjangoModelFactory[Question]):
    """Factory for a Question attached to a (new by default) RagJob."""

    class Meta:
        model = Question

    job = factory.SubFactory("radis.rag.factories.RagJobFactory")
    question = factory.Faker("sentence", nb_words=10)
    # Pick one of the model's declared answer choices (first tuple element).
    accepted_answer = factory.Faker("random_element", elements=[a[0] for a in Answer.choices])


class RagTaskFactory(BaseDjangoModelFactory[RagTask]):
    """Factory for a RagTask attached to a (new by default) RagJob."""

    class Meta:
        model = RagTask

    job = factory.SubFactory("radis.rag.factories.RagJobFactory")


class RagInstanceFactory(BaseDjangoModelFactory[RagInstance]):
    """Factory for a RagInstance attached to a (new by default) RagTask."""

    class Meta:
        model = RagInstance

    task = factory.SubFactory("radis.rag.factories.RagTaskFactory")

    @factory.post_generation
    def reports(self, create, extracted, **kwargs):
        """Attach reports passed as ``RagInstanceFactory(reports=[...])``,
        or three freshly created reports when none are given."""
        if not create:
            # build() strategy: no database, so M2M links cannot be created.
            return

        self = cast(RagInstance, self)

        if extracted:
            for report in extracted:
                self.reports.add(report)
        else:
            # NOTE(review): imported locally rather than at module top --
            # presumably to avoid an import cycle with radis.reports.factories;
            # confirm before hoisting.
            from radis.reports.factories import ReportFactory

            self.reports.add(*[ReportFactory() for _ in range(3)])
4 changes: 2 additions & 2 deletions radis/rag/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from radis.core.filters import AnalysisJobFilter, AnalysisTaskFilter

from .models import RagJob, RagTask
from .models import RagInstance, RagJob, RagTask


class RagJobFilter(AnalysisJobFilter):
Expand All @@ -21,7 +21,7 @@ class RagResultFilter(django_filters.FilterSet):
request: HttpRequest

class Meta:
model = RagTask
model = RagInstance
fields = ("overall_result",)

def __init__(self, *args, **kwargs):
Expand Down
Loading
Loading