Skip to content

Commit

Permalink
Introduce rag instance text and other reports
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Aug 5, 2024
1 parent 66895de commit 13551a4
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 36 deletions.
19 changes: 19 additions & 0 deletions radis/rag/migrations/0013_raginstance_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Generated by Django 5.0.7 on 2024-08-03 21:34

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('rag', '0012_switch_to_procrastinate'),
]

operations = [
migrations.AddField(
model_name='raginstance',
name='text',
field=models.TextField(default=''),
preserve_default=False,
),
]
26 changes: 26 additions & 0 deletions radis/rag/migrations/0014_other_reports_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.0.7 on 2024-08-04 22:19

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('rag', '0013_raginstance_text'),
('reports', '0011_report_patient_age'),
]

operations = [
migrations.RenameField(
model_name='raginstance',
old_name='reports',
new_name='other_reports',
),
migrations.AddField(
model_name='raginstance',
name='report',
field=models.ForeignKey(default=None, on_delete=django.db.models.deletion.CASCADE, related_name='rag_instances', to='reports.report'),
preserve_default=False,
),
]
5 changes: 4 additions & 1 deletion radis/rag/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ class Result(models.TextChoices):
results = RelatedManager["QuestionResult"]()

id: int
reports = models.ManyToManyField(Report)
text = models.TextField()
report_id: int
report = models.ForeignKey(Report, on_delete=models.CASCADE, related_name="rag_instances")
other_reports = models.ManyToManyField(Report)
overall_result = models.CharField(max_length=1, choices=Result.choices, blank=True)
get_overall_result_display: Callable[[], str]
task = models.ForeignKey(RagTask, on_delete=models.CASCADE, related_name="rag_instances")
Expand Down
45 changes: 23 additions & 22 deletions radis/rag/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from channels.db import database_sync_to_async
from django import db
from django.conf import settings
from django.db.models.query import QuerySet
from django.db.models import Prefetch

from radis.core.processors import AnalysisTaskProcessor
from radis.core.utils.chat_client import AsyncChatClient
from radis.reports.models import Report

from .models import Answer, Question, QuestionResult, RagInstance, RagTask

Expand All @@ -18,32 +17,33 @@

class RagTaskProcessor(AnalysisTaskProcessor):
async def process_task(self, task: RagTask) -> None:
language_code = task.job.language.code
client = AsyncChatClient()
sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT)

await asyncio.gather(
*[
self.process_rag_instance(rag_instance, client, sem)
async for rag_instance in task.rag_instances.prefetch_related("reports")
self.process_rag_instance(rag_instance, language_code, client, sem)
async for rag_instance in task.rag_instances.prefetch_related(
Prefetch("report"), Prefetch("other_reports")
)
]
)
await database_sync_to_async(db.close_old_connections)()

async def process_rag_instance(
self, rag_instance: RagInstance, client: AsyncChatClient, sem: Semaphore
self, rag_instance: RagInstance, language_code: str, client: AsyncChatClient, sem: Semaphore
) -> None:
report = await self.combine_reports(rag_instance.reports.prefetch_related("language"))
language = report.language
rag_instance.text = await self.get_text_to_analyze(rag_instance)
await rag_instance.asave()

if language.code not in settings.SUPPORTED_LANGUAGES:
raise ValueError(f"Language '{language}' is not supported.")
if language_code not in settings.SUPPORTED_LANGUAGES:
raise ValueError(f"Language '{language_code}' is not supported.")

async with sem:
results = await asyncio.gather(
*[
self.process_yes_or_no_question(
rag_instance, report.body, language.code, question, client
)
self.process_yes_or_no_question(rag_instance, language_code, question, client)
async for question in rag_instance.task.job.questions.all()
]
)
Expand All @@ -62,26 +62,27 @@ async def process_rag_instance(
rag_instance.get_overall_result_display(),
)

async def combine_reports(self, reports: QuerySet[Report]) -> Report:
count = await reports.acount()
if count > 1:
raise ValueError("Multiple reports is not yet supported")
async def get_text_to_analyze(self, rag_instance: RagInstance) -> str:
text_to_analyze = ""
text_to_analyze += rag_instance.report.body

report = await reports.afirst()
if report is None:
raise ValueError("No reports to combine")
async for report in rag_instance.other_reports.order_by("study_datetime").all():
if text_to_analyze:
text_to_analyze += "\n\n"
text_to_analyze += report.body

return report
return text_to_analyze

async def process_yes_or_no_question(
self,
rag_instance: RagInstance,
body: str,
language: str,
question: Question,
client: AsyncChatClient,
) -> RagInstance.Result:
llm_answer = await client.ask_yes_no_question(body, language, question.question)
llm_answer = await client.ask_yes_no_question(
rag_instance.text, language, question.question
)

if llm_answer == "yes":
answer = Answer.YES
Expand Down
13 changes: 10 additions & 3 deletions radis/rag/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import batched

from django.conf import settings
from django.db.models import Prefetch
from procrastinate.contrib.django import app

from radis.reports.models import Report
Expand All @@ -17,7 +18,12 @@

@app.task(queue="llm")
async def process_rag_task(task_id: int) -> None:
task = await RagTask.objects.prefetch_related("job").aget(id=task_id)
task = await RagTask.objects.prefetch_related(
Prefetch(
"job",
queryset=RagJob.objects.prefetch_related("language"),
)
).aget(id=task_id)
processor = RagTaskProcessor(task)
await processor.start()

Expand Down Expand Up @@ -64,9 +70,10 @@ def process_rag_job(job_id: int) -> None:
for document_ids in batched(retrieval_provider.retrieve(search), settings.RAG_TASK_BATCH_SIZE):
logger.debug("Creating RAG task for document IDs: %s", document_ids)
task = RagTask.objects.create(job=job, status=RagTask.Status.PENDING)

for document_id in document_ids:
rag_instance = RagInstance.objects.create(task=task)
rag_instance.reports.add(Report.objects.get(document_id=document_id))
report = Report.objects.get(document_id=document_id)
RagInstance.objects.create(task=task, report_id=report.id)

task.delay()

Expand Down
35 changes: 29 additions & 6 deletions radis/rag/templates/rag/_rag_result_summary.html
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
<div class="card mb-3">
<div class="card-body" x-data="{full: false, questions: false}">
<div class="card-body"
x-data="{full: false, questions: false, reports: false}">
<div class="d-flex flex-column gap-1">
<div class="d-flex justify-content-between">
{% include "reports/_report_header.html" with report=rag_instance.reports.first %}
{% include "reports/_report_header.html" with report=rag_instance.report %}
<div>{% include "rag/_overall_result_badge.html" %}</div>
</div>
<div class="d-flex flex-column gap-1">
<div class="clamp-3" :class="{'clamp-3': !full, 'pre-line': full}">{{ rag_instance.reports.first.body }}</div>
<div class="clamp-3" :class="{'clamp-3': !full, 'pre-line': full}">{{ rag_instance.text }}</div>
<div class="d-flex gap-2">
<button type="button"
class="btn btn-sm btn-link p-0 border-0"
@click.prevent="full=true"
x-show="!full">[Show full report]</button>
x-show="!full">[Show full text]</button>
<button type="button"
class="btn btn-sm btn-link p-0 border-0"
@click.prevent="full=false"
x-cloak
x-show="full">[Show summary]</button>
x-show="full">[Collapse text]</button>
<button type="button"
class="btn btn-sm btn-link p-0 border-0"
@click.prevent="questions=true"
Expand All @@ -25,6 +26,16 @@
class="btn btn-sm btn-link p-0 border-0"
@click.prevent="questions=false"
x-show="questions">[Hide questions]</button>
{% if rag_instance.other_reports.exists %}
<button type="button"
class="btn btn-sm btn-link p-0 border-0"
@click.prevent="reports=true"
x-show="!reports">[Show other reports]</button>
<button type="button"
class="btn btn-sm btn-link p-0 border-0"
@click.prevent="reports=false"
x-show="reports">[Hide other reports]</button>
{% endif %}
</div>
</div>
<div x-cloak x-show="questions">
Expand All @@ -38,7 +49,19 @@
</div>
{% endfor %}
</div>
{% include "reports/_report_buttons_panel.html" with report=rag_instance.reports.first %}
{% if rag_instance.other_reports.exists %}
<div x-cloak x-show="reports">
Other reports:
<ul>
{% for report in rag_instance.other_reports.all %}
<li>
<a href="{% url 'report_detail' report.id %}">{{ report.document_id }}</a>
</li>
{% endfor %}
</ul>
</div>
{% endif %}
{% include "reports/_report_buttons_panel.html" with report=rag_instance.report %}
</div>
</div>
</div>
7 changes: 3 additions & 4 deletions radis/rag/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from django.core.exceptions import SuspiciousOperation
from django.db import transaction
from django.db.models import QuerySet
from django.db.models import Prefetch, QuerySet
from django.forms import BaseInlineFormSet
from django.shortcuts import redirect, render
from django.urls import reverse_lazy
Expand Down Expand Up @@ -254,14 +254,13 @@ def get_queryset(self) -> QuerySet[RagJob]:

def get_related_queryset(self) -> QuerySet[RagInstance]:
job = cast(RagJob, self.get_object())
rag_instances = RagInstance.objects.filter(
return RagInstance.objects.filter(
task__job=job,
overall_result__in=[
RagInstance.Result.ACCEPTED,
RagInstance.Result.REJECTED,
],
).prefetch_related("reports")
return rag_instances
).prefetch_related(Prefetch("report"), Prefetch("other_reports"))

def get_filter_queryset(self) -> QuerySet[RagInstance]:
return self.get_related_queryset()
Expand Down

0 comments on commit 13551a4

Please sign in to comment.