# Patch overview (preamble text; everything from the first "diff --git" header on is unchanged).
# NOTE: the original line breaks of this patch are collapsed into single spaces.
#
# Files touched, as far as the visible hunks show:
# - radis/rag/migrations/0013_raginstance_text.py (new): adds a `text` TextField to `raginstance`;
#   the default '' is used for the migration only (preserve_default=False).
# - radis/rag/migrations/0014_other_reports_and_more.py (new): renames `raginstance.reports` to
#   `other_reports` and adds a `report` ForeignKey to `reports.report` (CASCADE,
#   related_name='rag_instances', default=None, preserve_default=False).
#   NOTE(review): a non-null ForeignKey backfilled with default=None presumably fails when
#   `raginstance` rows already exist - confirm the table is empty at migration time.
# - radis/rag/models.py: adds the `text`, `report_id` and `report` attributes and renames the
#   many-to-many field `reports` to `other_reports`.
# - radis/rag/processors.py:
#   * process_task reads `task.job.language.code` once and passes it to process_rag_instance as
#     `language_code`; the instances are loaded with Prefetch("report") and Prefetch("other_reports").
#   * process_rag_instance stores the result of get_text_to_analyze in `rag_instance.text` and calls
#     `asave()` BEFORE the SUPPORTED_LANGUAGES check, so the save happens even if the check raises.
#     NOTE(review): the check depends only on `language_code`; consider whether it should run
#     before the save (or once per task) - confirm intent.
#   * get_text_to_analyze replaces combine_reports: it starts with `rag_instance.report.body` and
#     appends the body of every `other_reports` entry ordered by `study_datetime`, inserting "\n\n"
#     only when the text collected so far is non-empty.
#     NOTE(review): `other_reports.order_by(...)` builds a new queryset, so the prefetched
#     `other_reports` are presumably not reused - verify against the prefetch_related docs.
#     NOTE(review): an empty primary report body yields no leading separator - confirm that is wanted.
#   * process_yes_or_no_question drops its `body` parameter and sends `rag_instance.text` to
#     `client.ask_yes_no_question`.
# - radis/rag/tasks.py: process_rag_task prefetches `job` through a Prefetch whose queryset also
#   prefetches `language`; process_rag_job now creates each RagInstance with `report_id` set
#   instead of adding the report to a many-to-many relation afterwards.
# - radis/rag/templates/rag/_rag_result_summary.html: only the file header and hunk header are
#   visible in this chunk; the hunk body is not shown.
diff --git a/radis/rag/migrations/0013_raginstance_text.py b/radis/rag/migrations/0013_raginstance_text.py new file mode 100644 index 00000000..2074c9dc --- /dev/null +++ b/radis/rag/migrations/0013_raginstance_text.py @@ -0,0 +1,19 @@ +# Generated by Django 5.0.7 on 2024-08-03 21:34 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('rag', '0012_switch_to_procrastinate'), + ] + + operations = [ + migrations.AddField( + model_name='raginstance', + name='text', + field=models.TextField(default=''), + preserve_default=False, + ), + ] diff --git a/radis/rag/migrations/0014_other_reports_and_more.py b/radis/rag/migrations/0014_other_reports_and_more.py new file mode 100644 index 00000000..a2b1489a --- /dev/null +++ b/radis/rag/migrations/0014_other_reports_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 5.0.7 on 2024-08-04 22:19 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('rag', '0013_raginstance_text'), + ('reports', '0011_report_patient_age'), + ] + + operations = [ + migrations.RenameField( + model_name='raginstance', + old_name='reports', + new_name='other_reports', + ), + migrations.AddField( + model_name='raginstance', + name='report', + field=models.ForeignKey(default=None, on_delete=django.db.models.deletion.CASCADE, related_name='rag_instances', to='reports.report'), + preserve_default=False, + ), + ] diff --git a/radis/rag/models.py b/radis/rag/models.py index e8dbdd74..fcb43b05 100644 --- a/radis/rag/models.py +++ b/radis/rag/models.py @@ -109,7 +109,10 @@ class Result(models.TextChoices): results = RelatedManager["QuestionResult"]() id: int - reports = models.ManyToManyField(Report) + text = models.TextField() + report_id: int + report = models.ForeignKey(Report, on_delete=models.CASCADE, related_name="rag_instances") + other_reports = models.ManyToManyField(Report) overall_result = 
models.CharField(max_length=1, choices=Result.choices, blank=True) get_overall_result_display: Callable[[], str] task = models.ForeignKey(RagTask, on_delete=models.CASCADE, related_name="rag_instances") diff --git a/radis/rag/processors.py b/radis/rag/processors.py index 14fff2c6..8e8570b4 100644 --- a/radis/rag/processors.py +++ b/radis/rag/processors.py @@ -5,11 +5,10 @@ from channels.db import database_sync_to_async from django import db from django.conf import settings -from django.db.models.query import QuerySet +from django.db.models import Prefetch from radis.core.processors import AnalysisTaskProcessor from radis.core.utils.chat_client import AsyncChatClient -from radis.reports.models import Report from .models import Answer, Question, QuestionResult, RagInstance, RagTask @@ -18,32 +17,33 @@ class RagTaskProcessor(AnalysisTaskProcessor): async def process_task(self, task: RagTask) -> None: + language_code = task.job.language.code client = AsyncChatClient() sem = Semaphore(settings.RAG_LLM_CONCURRENCY_LIMIT) await asyncio.gather( *[ - self.process_rag_instance(rag_instance, client, sem) - async for rag_instance in task.rag_instances.prefetch_related("reports") + self.process_rag_instance(rag_instance, language_code, client, sem) + async for rag_instance in task.rag_instances.prefetch_related( + Prefetch("report"), Prefetch("other_reports") + ) ] ) await database_sync_to_async(db.close_old_connections)() async def process_rag_instance( - self, rag_instance: RagInstance, client: AsyncChatClient, sem: Semaphore + self, rag_instance: RagInstance, language_code: str, client: AsyncChatClient, sem: Semaphore ) -> None: - report = await self.combine_reports(rag_instance.reports.prefetch_related("language")) - language = report.language + rag_instance.text = await self.get_text_to_analyze(rag_instance) + await rag_instance.asave() - if language.code not in settings.SUPPORTED_LANGUAGES: - raise ValueError(f"Language '{language}' is not supported.") + if language_code 
not in settings.SUPPORTED_LANGUAGES: + raise ValueError(f"Language '{language_code}' is not supported.") async with sem: results = await asyncio.gather( *[ - self.process_yes_or_no_question( - rag_instance, report.body, language.code, question, client - ) + self.process_yes_or_no_question(rag_instance, language_code, question, client) async for question in rag_instance.task.job.questions.all() ] ) @@ -62,26 +62,27 @@ async def process_rag_instance( rag_instance.get_overall_result_display(), ) - async def combine_reports(self, reports: QuerySet[Report]) -> Report: - count = await reports.acount() - if count > 1: - raise ValueError("Multiple reports is not yet supported") + async def get_text_to_analyze(self, rag_instance: RagInstance) -> str: + text_to_analyze = "" + text_to_analyze += rag_instance.report.body - report = await reports.afirst() - if report is None: - raise ValueError("No reports to combine") + async for report in rag_instance.other_reports.order_by("study_datetime").all(): + if text_to_analyze: + text_to_analyze += "\n\n" + text_to_analyze += report.body - return report + return text_to_analyze async def process_yes_or_no_question( self, rag_instance: RagInstance, - body: str, language: str, question: Question, client: AsyncChatClient, ) -> RagInstance.Result: - llm_answer = await client.ask_yes_no_question(body, language, question.question) + llm_answer = await client.ask_yes_no_question( + rag_instance.text, language, question.question + ) if llm_answer == "yes": answer = Answer.YES diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py index be8e98e1..3b9da83c 100644 --- a/radis/rag/tasks.py +++ b/radis/rag/tasks.py @@ -2,6 +2,7 @@ from itertools import batched from django.conf import settings +from django.db.models import Prefetch from procrastinate.contrib.django import app from radis.reports.models import Report @@ -17,7 +18,12 @@ @app.task(queue="llm") async def process_rag_task(task_id: int) -> None: - task = await 
RagTask.objects.prefetch_related("job").aget(id=task_id) + task = await RagTask.objects.prefetch_related( + Prefetch( + "job", + queryset=RagJob.objects.prefetch_related("language"), + ) + ).aget(id=task_id) processor = RagTaskProcessor(task) await processor.start() @@ -64,9 +70,10 @@ def process_rag_job(job_id: int) -> None: for document_ids in batched(retrieval_provider.retrieve(search), settings.RAG_TASK_BATCH_SIZE): logger.debug("Creating RAG task for document IDs: %s", document_ids) task = RagTask.objects.create(job=job, status=RagTask.Status.PENDING) + for document_id in document_ids: - rag_instance = RagInstance.objects.create(task=task) - rag_instance.reports.add(Report.objects.get(document_id=document_id)) + report = Report.objects.get(document_id=document_id) + RagInstance.objects.create(task=task, report_id=report.id) task.delay() diff --git a/radis/rag/templates/rag/_rag_result_summary.html b/radis/rag/templates/rag/_rag_result_summary.html index f74dc040..28ee238d 100644 --- a/radis/rag/templates/rag/_rag_result_summary.html +++ b/radis/rag/templates/rag/_rag_result_summary.html @@ -1,22 +1,23 @@