Skip to content

Commit

Permalink
Merge pull request #7 from openradx/chat
Browse files Browse the repository at this point in the history
Add report chat
  • Loading branch information
medihack committed Jun 4, 2024
2 parents 6b7d7e9 + 4856276 commit cfc0840
Show file tree
Hide file tree
Showing 13 changed files with 315 additions and 137 deletions.
2 changes: 1 addition & 1 deletion .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ tasks:
init: |
poetry install
poetry run invoke init-workspace
poetry run invoke download-llm -m mistral-7b-q1
poetry run invoke download-llm -m tinyllama-1b-q2
ports:
- port: 8000
Expand Down
162 changes: 81 additions & 81 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = ["medihack <[email protected]>"]
license = "GPL-3.0-or-later"

[tool.poetry.dependencies]
adit-radis-shared = { git = "https://github.com/openradx/adit-radis-shared.git", tag = "v0.3.5" }
adit-radis-shared = { git = "https://github.com/openradx/adit-radis-shared.git", tag = "v0.3.8" }
adrf = "^0.1.4"
aiofiles = "^23.1.0"
asyncinotify = "^4.0.1"
Expand Down Expand Up @@ -43,7 +43,7 @@ redis = "^5.0.3"
toml = "^0.10.2"
Twisted = { extras = ["tls", "http2"], version = "^24.3.0" }
wait-for-it = "^2.2.2"
watchfiles = "^0.21.0"
watchfiles = "^0.22.0"
whitenoise = "^6.0.0"

[tool.poetry.group.dev.dependencies]
Expand Down
19 changes: 19 additions & 0 deletions radis/core/static/core/core.css
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
/*
* HTMX overrides
*/
.my-indicator {
display: none;
}
.htmx-request .my-indicator {
display: inline;
opacity: 1;
}
.htmx-request.my-indicator {
display: inline;
opacity: 1;
}

/*
* Misc
*/

.search-summary {
white-space: normal;
}
Expand Down
65 changes: 65 additions & 0 deletions radis/core/utils/chat_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging
from string import Template
from typing import Literal

from django.conf import settings
from openai import OpenAI

logger = logging.getLogger(__name__)

# GBNF grammar template that constrains the LLM output to exactly one of two
# tokens. The $yes and $no placeholders are filled in per language via
# string.Template.substitute() (see ChatClient.ask_yes_no_question).
YES_NO_GRAMMAR = """
root ::= Answer
Answer ::= "$yes" | "$no"
"""


class ChatClient:
    """Thin wrapper around an OpenAI-compatible llama.cpp chat endpoint."""

    def __init__(self):
        # llama.cpp serves an OpenAI-compatible API; no real API key is needed.
        self._client = OpenAI(base_url=f"{settings.LLAMACPP_URL}/v1", api_key="none")

    def ask_question(
        self, report_body: str, language: str, question: str, grammar: str | None = None
    ) -> str:
        """Ask the LLM a free-form question about a radiological report.

        Args:
            report_body: Full text of the report.
            language: Key used to look up the localized prompt templates
                in the settings (e.g. "de" or "en").
            question: The question to ask about the report.
            grammar: Optional GBNF grammar constraining the model output.

        Returns:
            The raw answer text produced by the LLM.

        Raises:
            ValueError: If the LLM returns no answer content.
        """
        system_prompt = settings.CHAT_SYSTEM_PROMPT[language]
        user_prompt = Template(settings.CHAT_USER_PROMPT[language]).substitute(
            {"report": report_body, "question": question}
        )

        log_msg = f"Sending to LLM:\n[System] {system_prompt}\n[User] {user_prompt}"
        if grammar:
            log_msg += f"\n[Grammar] {grammar}"
        logger.debug(log_msg)

        completion = self._client.chat.completions.create(
            model="none",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            extra_body={"grammar": grammar},
        )

        answer = completion.choices[0].message.content
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # and a missing answer is a runtime failure, not a programming error.
        if answer is None:
            raise ValueError("LLM returned an empty answer")
        logger.debug("Received from LLM: %s", answer)

        return answer

    def ask_yes_no_question(
        self, report_body: str, language: str, question: str
    ) -> Literal["yes", "no"]:
        """Ask a question that must be answered with yes or no.

        A GBNF grammar restricts the model output to the localized yes/no
        tokens, which are then normalized to the literals "yes" / "no".

        Raises:
            ValueError: If the LLM answer matches neither localized token.
        """
        grammar = Template(YES_NO_GRAMMAR).substitute(
            {
                "yes": settings.CHAT_ANSWER_YES[language],
                "no": settings.CHAT_ANSWER_NO[language],
            }
        )

        llm_answer = self.ask_question(report_body, language, question, grammar)

        if llm_answer == settings.CHAT_ANSWER_YES[language]:
            return "yes"
        elif llm_answer == settings.CHAT_ANSWER_NO[language]:
            return "no"
        else:
            raise ValueError(f"Unexpected answer: {llm_answer}")
40 changes: 5 additions & 35 deletions radis/rag/tasks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from string import Template
from typing import Iterator, override

from django.conf import settings
from openai import OpenAI

from radis.celery import app as celery_app
from radis.core.tasks import ProcessAnalysisJob, ProcessAnalysisTask
from radis.core.utils.chat_client import ChatClient
from radis.reports.models import Report
from radis.search.site import Search, SearchFilters
from radis.search.utils.query_parser import QueryParser
Expand All @@ -16,11 +16,6 @@

logger = logging.getLogger(__name__)

GRAMMAR = """
root ::= Answer
Answer ::= "$yes" | "$no"
"""


class ProcessRagTask(ProcessAnalysisTask):
analysis_task_class = RagTask
Expand All @@ -40,39 +35,14 @@ def process_task(self, task: RagTask) -> None:

all_results: list[RagTask.Result] = []

system_prompt = settings.RAG_SYSTEM_PROMPT[language]
logger.debug("Using system prompt:\n%s", system_prompt)

grammar = Template(GRAMMAR).substitute(
{
"yes": settings.RAG_ANSWER_YES[language],
"no": settings.RAG_ANSWER_NO[language],
}
)
logger.debug("Using grammar:\n%s", grammar)
chat_client = ChatClient()

for question in task.job.questions.all():
user_prompt = Template(settings.RAG_USER_PROMPT[language]).substitute(
{"report": report_body, "question": question.question}
)

logger.debug("Sending user prompt:\n%s", user_prompt)

completion = self._client.chat.completions.create(
model="none",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
extra_body={"grammar": grammar},
)

llm_answer = completion.choices[0].message.content
logger.debug("Received answer by LLM: %s", llm_answer)
llm_answer = chat_client.ask_yes_no_question(report_body, language, question.question)

if llm_answer == settings.RAG_ANSWER_YES[language]:
if llm_answer == "yes":
answer = Answer.YES
elif llm_answer == settings.RAG_ANSWER_NO[language]:
elif llm_answer == "no":
answer = Answer.NO
else:
raise ValueError(f"Unexpected answer: {llm_answer}")
Expand Down
27 changes: 27 additions & 0 deletions radis/reports/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from crispy_forms.helper import FormHelper, Layout
from crispy_forms.layout import Field, Submit
from django import forms


class PromptForm(forms.Form):
    """Single-field form for asking the LLM a question about a report."""

    prompt = forms.CharField(max_length=500)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Crispy-forms helper: labels hidden, surrounding <form> tag rendered
        # by the template itself (hx-post attributes live there).
        helper = FormHelper()
        helper.form_show_labels = False
        helper.form_tag = False

        prompt_field = Field(
            "prompt", placeholder="Ask the LLM a question about this report"
        )
        yes_no_button = Submit("yes_no_answer", "Yes/No answer", css_class="btn-primary")
        full_answer_button = Submit("full_answer", "Full answer", css_class="btn-primary")

        helper.layout = Layout(prompt_field, yes_no_button, full_answer_button)
        self.helper = helper
37 changes: 37 additions & 0 deletions radis/reports/templates/reports/_report_chat.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{% load crispy from crispy_forms_tags %}
{# Chat panel for a single report: message history, loading indicator, prompt form. #}
{# Swapped in-place via HTMX (hx-target below points at this wrapper div). #}
<div id="report_chat" class="mt-3 mx-2">
<div id="messages-container">
{% if messages %}
<table class="table table-borderless">
<tbody>
{# Each message is a dict with "role" (e.g. User/Assistant) and "content". #}
{% for message in messages %}
<tr>
<th scope="row">{{ message.role }}</th>
<td class="w-100">{{ message.content }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endif %}
</div>
{# Spinner shown only while an HTMX request is in flight (htmx-indicator). #}
<div id="loader" class="my-indicator htmx-indicator">
<div class="mb-3 d-flex justify-content-center gap-1">
<div class="spinner-grow spinner-grow-sm" role="status">
<span class="visually-hidden">Loading...</span>
</div>
<div class="spinner-grow spinner-grow-sm" role="status">
<span class="visually-hidden">Loading...</span>
</div>
<div class="spinner-grow spinner-grow-sm" role="status">
<span class="visually-hidden">Loading...</span>
</div>
</div>
</div>
{# On submit: hide old messages, disable inputs, show loader, replace panel with response. #}
<form hx-post="{% url 'report_chat' report.id %}"
hx-on:submit="htmx.addClass(htmx.find('#messages-container'), 'd-none')"
hx-target="#report_chat"
hx-indicator="#loader"
hx-disabled-elt="input">
{% crispy prompt_form %}
</form>
</div>
1 change: 1 addition & 0 deletions radis/reports/templates/reports/report_detail.html
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ <h5 class="card-title">Report Text</h5>
<div class="pre-line">{{ report.body }}</div>
</div>
</div>
{% include "reports/_report_chat.html" %}
{% endblock content %}
3 changes: 2 additions & 1 deletion radis/reports/urls.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns

from .views import ReportBodyView, ReportDetailView
from .views import ReportBodyView, ReportDetailView, report_chat_view

urlpatterns = [
path("<int:pk>/body/", ReportBodyView.as_view(), name="report_body"),
path("<int:pk>/chat/", report_chat_view, name="report_chat"),
path("<int:pk>/", ReportDetailView.as_view(), name="report_detail"),
]

Expand Down
54 changes: 54 additions & 0 deletions radis/reports/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import Any

from adit_radis_shared.common.decorators import login_required_async, user_passes_test_async
from adit_radis_shared.common.types import AuthenticatedHttpRequest
from asgiref.sync import sync_to_async
from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.db.models import QuerySet
from django.http import HttpResponse
from django.shortcuts import aget_object_or_404, render # type: ignore
from django.views.generic.detail import DetailView

from radis.core.utils.chat_client import ChatClient
from radis.reports.forms import PromptForm

from .models import Report


Expand All @@ -21,6 +30,51 @@ def get_queryset(self) -> QuerySet[Report]:
assert active_group
return super().get_queryset().filter(groups=active_group)

def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
    """Extend the detail context with an empty chat form and message list."""
    context = super().get_context_data(**kwargs)
    context.update({"prompt_form": PromptForm(), "messages": []})
    return context


class ReportBodyView(ReportDetailView):
    # Same queryset/permission logic as ReportDetailView, but renders only
    # the report body template (used as a partial view of the report).
    template_name = "reports/report_body.html"


@login_required_async
@user_passes_test_async(lambda user: user.active_group is not None)
async def report_chat_view(request: AuthenticatedHttpRequest, pk: int) -> HttpResponse:
    """Handle a chat prompt about a single report and render the chat partial.

    The submit button pressed selects the mode: "yes_no_answer" asks a
    grammar-constrained yes/no question, "full_answer" asks a free-form one.

    Raises:
        ValueError: If the POST contains neither known submit button.
    """
    report = await aget_object_or_404(
        Report.objects.filter(groups=request.user.active_group), pk=pk
    )

    # Accessing the related language object hits the DB, which must not
    # happen directly in async context.
    language = await sync_to_async(lambda: report.language)()
    form = PromptForm(request.POST)

    context: dict[str, Any] = {
        "messages": [],
        "report": report,
        "prompt_form": form,
    }

    if form.is_valid():
        chat_client = ChatClient()
        prompt = form.cleaned_data["prompt"]
        if request.POST.get("yes_no_answer"):
            # The chat client is synchronous and the completion call can be
            # slow; run it in a thread so it doesn't block the event loop.
            yes_no = await sync_to_async(chat_client.ask_yes_no_question)(
                report.body, language.code, prompt
            )
            answer = "Yes" if yes_no == "yes" else "No"
        elif request.POST.get("full_answer"):
            answer = await sync_to_async(chat_client.ask_question)(
                report.body, language.code, prompt
            )
        else:
            raise ValueError("Invalid form")

        context["messages"] = [
            {"role": "User", "content": prompt},
            {"role": "Assistant", "content": answer},
        ]
        # Present a fresh, empty form for the next question.
        context["prompt_form"] = PromptForm()

    return render(request, "reports/_report_chat.html", context)
24 changes: 13 additions & 11 deletions radis/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,36 +374,38 @@
VESPA_CONFIG_PORT = env.int("VESPA_CONFIG_PORT", default=19071) # type: ignore
VESPA_DATA_PORT = env.int("VESPA_DATA_PORT", default=8080) # type: ignore

# RAG
RAG_DEFAULT_PRIORITY = 2
RAG_URGENT_PRIORITY = 3

START_RAG_JOB_UNVERIFIED = False
RAG_SYSTEM_PROMPT = {
# Chat settings
CHAT_SYSTEM_PROMPT = {
"de": "Du bist ein radiologischer Facharzt",
"en": "You are a radiologist",
}
RAG_ANSWER_YES = {
CHAT_ANSWER_YES = {
"de": "Ja",
"en": "Yes",
}
RAG_ANSWER_NO = {
CHAT_ANSWER_NO = {
"de": "Nein",
"en": "No",
}
RAG_USER_PROMPT = {
CHAT_USER_PROMPT = {
"de": f"""
Im folgenden erhälst Du einen radiologischen Befund und eine Frage zu diesem Befund.
Beantworte die Frage zu dem Befund mit {RAG_ANSWER_YES['de']} oder {RAG_ANSWER_NO['de']}.
Beantworte die Frage zu dem Befund mit {CHAT_ANSWER_YES['de']} oder {CHAT_ANSWER_NO['de']}.
Befund: $report
Frage: $question
Antwort:
""",
"en": f"""
In the following you will find a radiological report and a question about this report.
Answer the question about the report with {RAG_ANSWER_YES['en']} or {RAG_ANSWER_NO['en']}.
Answer the question about the report with {CHAT_ANSWER_YES['en']} or {CHAT_ANSWER_NO['en']}.
Report: $report
Question: $question
Answer:
""",
}

# RAG
RAG_DEFAULT_PRIORITY = 2
RAG_URGENT_PRIORITY = 3

START_RAG_JOB_UNVERIFIED = False
Loading

0 comments on commit cfc0840

Please sign in to comment.