From 5c23feb386aef91fb0118ac8c0be6ecc1fedccb1 Mon Sep 17 00:00:00 2001 From: Kai Schlamp Date: Tue, 19 Mar 2024 22:25:17 +0000 Subject: [PATCH] Support batch create, update, (delete) of reports --- TODO.md | 4 ++ radis/core/management/commands/populate_db.py | 17 +++-- radis/reports/api/viewsets.py | 23 ++++-- radis/reports/site.py | 43 ++++++++--- radis/reports/tasks.py | 20 +++--- radis/vespa/apps.py | 29 +++----- radis/vespa/utils/document_utils.py | 72 ++++++++++++++----- 7 files changed, 139 insertions(+), 69 deletions(-) diff --git a/TODO.md b/TODO.md index e094a22d..0b0dd631 100644 --- a/TODO.md +++ b/TODO.md @@ -2,6 +2,10 @@ ## High Priority +- Use own worker for indexing +- Thinks about a better delete strategy +- Delete document_id from Vespa schema and extract it from the returned id + - - Check if for RAG ranking should be turned off for performance improvements (and using some fixed sort order) - Some present provider.max_results to the user somehow, especially important if the query results (step 1) is larger - task control panel diff --git a/radis/core/management/commands/populate_db.py b/radis/core/management/commands/populate_db.py index 03f87995..963e9787 100644 --- a/radis/core/management/commands/populate_db.py +++ b/radis/core/management/commands/populate_db.py @@ -12,7 +12,7 @@ from radis.accounts.models import User from radis.reports.factories import ReportFactory from radis.reports.models import Report -from radis.reports.site import report_event_handlers +from radis.reports.site import reports_created_handlers from radis.token_authentication.factories import TokenFactory from radis.token_authentication.models import FRACTION_LENGTH from radis.token_authentication.utils.crypto import hash_token @@ -24,12 +24,11 @@ fake = Faker() -def feed_report(body: str, language: Literal["en", "de"]): +def create_report(body: str, language: Literal["en", "de"]): report = ReportFactory.create(language=language, body=body) groups = fake.random_elements(elements=list(Group.objects.all()), unique=True) report.groups.set(groups) - for handler in report_event_handlers: - handler("created", report.document_id) + return report def feed_reports(language: Literal["en", "de"]): @@ -42,10 +41,14 @@ def feed_reports(language: Literal["en", "de"]): samples_path = Path(settings.BASE_DIR / "samples" / sample_file) with open(samples_path, "r") as f: - reports = json.load(f) + report_bodies = json.load(f) - for report in reports: - feed_report(report, language) + reports: list[Report] = [] + for report_body in report_bodies: + reports.append(create_report(report_body, language)) + + for handler in reports_created_handlers: + handler([report.id for report in reports]) def create_admin() -> User: diff --git a/radis/reports/api/viewsets.py b/radis/reports/api/viewsets.py index 0ce462fb..384556ca 100644 --- a/radis/reports/api/viewsets.py +++ b/radis/reports/api/viewsets.py @@ -9,7 +9,7 @@ from rest_framework.response import Response from rest_framework.serializers import BaseSerializer -from radis.reports.tasks import report_created, report_deleted, report_updated +from radis.reports.tasks import reports_created, reports_deleted, reports_updated from ..models import Report from ..site import document_fetchers @@ -33,6 +33,11 @@ class ReportViewSet( lookup_field = "document_id" permission_classes = [IsAdminUser] + def get_serializer(self, *args: Any, **kwargs: Any) -> BaseSerializer: + if isinstance(kwargs.get("data", {}), list): + kwargs["many"] = True + return super().get_serializer(*args, **kwargs) + def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response: """Retrieve a single Report. @@ -57,8 +62,11 @@ def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response: def perform_create(self, serializer: BaseSerializer) -> None: super().perform_create(serializer) assert serializer.instance - report: Report = serializer.instance - transaction.on_commit(lambda: report_created.delay(report.document_id)) + reports: list[Report] | Report = serializer.instance + if not isinstance(reports, list): + reports = [reports] + + transaction.on_commit(lambda: reports_created.delay([report.id for report in reports])) def update(self, request: Request, *args: Any, **kwargs: Any) -> Response: # DRF itself does not support upsert. @@ -90,8 +98,11 @@ def get_object_or_none(self) -> Report | None: def perform_update(self, serializer: BaseSerializer) -> None: super().perform_update(serializer) assert serializer.instance - report: Report = serializer.instance - transaction.on_commit(lambda: report_updated.delay(report.document_id)) + reports: list[Report] | Report = serializer.instance + if not isinstance(reports, list): + reports = [reports] + + transaction.on_commit(lambda: reports_updated.delay([report.id for report in reports])) def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response: # Disallow partial updates @@ -100,4 +111,4 @@ def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Respons def perform_destroy(self, instance: Report) -> None: super().perform_destroy(instance) - transaction.on_commit(lambda: report_deleted.delay(instance.document_id)) + transaction.on_commit(lambda: reports_deleted.delay([instance.document_id])) diff --git a/radis/reports/site.py b/radis/reports/site.py index 4e168b01..5faa9673 100644 --- a/radis/reports/site.py +++ b/radis/reports/site.py @@ -1,23 +1,46 @@ -from typing import Any, Callable, Literal, NamedTuple +from typing import Any, Callable, NamedTuple from django.http import HttpRequest from .models import Report -ReportEventType = Literal["created", "updated", "deleted"] -ReportEventHandler = Callable[[ReportEventType, str], None] +ReportsCreatedHandler = Callable[[list[int]], None] -report_event_handlers: list[ReportEventHandler] = [] +reports_created_handlers: list[ReportsCreatedHandler] = [] -def register_report_handler(handler: ReportEventHandler) -> None: - """Register a report event handler. +def register_reports_created_handler(handler: ReportsCreatedHandler) -> None: + """Register a handler for when reports are created in the PostgreSQL database. - The report handler gets notified a report is created, updated, or deleted in - PostgreSQL database. It can be used to sync report documents in other - databases like Vespa. + The handler can be used to sync resp. index those reports in a search database like Vespa. """ - report_event_handlers.append(handler) + reports_created_handlers.append(handler) + + +ReportsUpdatedHandler = Callable[[list[int]], None] + +reports_updated_handlers: list[ReportsUpdatedHandler] = [] + + +def register_reports_updated_handler(handler: ReportsUpdatedHandler) -> None: + """Register a handler for when reports are updated in the PostgreSQL database. + + The handler can be used to sync resp. re-index those reports in a search database like Vespa. + """ + reports_updated_handlers.append(handler) + + +ReportsDeletedHandler = Callable[[list[str]], None] + +reports_deleted_handlers: list[ReportsDeletedHandler] = [] + + +def register_reports_deleted_handler(handler: ReportsDeletedHandler) -> None: + """Register a handler for when reports are deleted in the PostgreSQL database. + + The handler can be used to remove those reports from the index of search databases like Vespa. + """ + reports_deleted_handlers.append(handler) FetchDocument = Callable[[Report], dict[str, Any] | None] diff --git a/radis/reports/tasks.py b/radis/reports/tasks.py index 9f1eeb73..10fb1249 100644 --- a/radis/reports/tasks.py +++ b/radis/reports/tasks.py @@ -1,21 +1,21 @@ from celery import shared_task -from .site import report_event_handlers +from .site import reports_created_handlers, reports_deleted_handlers, reports_updated_handlers @shared_task -def report_created(document_id: str) -> None: - for handler in report_event_handlers: - handler("created", document_id) +def reports_created(report_ids: list[int]) -> None: + for handler in reports_created_handlers: + handler(report_ids) @shared_task -def report_updated(document_id: str) -> None: - for handler in report_event_handlers: - handler("updated", document_id) +def reports_updated(report_ids: list[int]) -> None: + for handler in reports_updated_handlers: + handler(report_ids) @shared_task -def report_deleted(document_id: str) -> None: - for handler in report_event_handlers: - handler("deleted", document_id) +def reports_deleted(document_ids: list[str]) -> None: + for handler in reports_deleted_handlers: + handler(document_ids) diff --git a/radis/vespa/apps.py b/radis/vespa/apps.py index 297e0a4d..524f6210 100644 --- a/radis/vespa/apps.py +++ b/radis/vespa/apps.py @@ -14,9 +14,10 @@ def register_app(): from radis.rag.site import RetrievalProvider, register_retrieval_provider from radis.reports.models import Report from radis.reports.site import ( - ReportEventType, register_document_fetcher, - register_report_handler, + register_reports_created_handler, + register_reports_deleted_handler, + register_reports_updated_handler, ) from radis.search.site import SearchProvider, register_search_provider from radis.vespa.providers import retrieve_bm25 @@ -24,25 +25,17 @@ def register_app(): from .providers import search_bm25, search_hybrid, search_semantic from .utils.document_utils import ( - create_document, - delete_document, + create_documents, + delete_documents, fetch_document, - update_document, + update_documents, ) - def handle_report(event_type: ReportEventType, document_id: str): - if event_type in ("created", "updated"): - report = Report.objects.get(document_id=document_id) - if event_type == "created": - create_document(document_id, report) - elif event_type == "updated": - update_document(document_id, report) - elif event_type == "deleted": - delete_document(document_id) - else: - raise ValueError(f"Invalid report event type: {event_type}") - - register_report_handler(handle_report) + register_reports_created_handler(lambda report_ids: create_documents(report_ids)) + + register_reports_updated_handler(lambda report_ids: update_documents(report_ids)) + + register_reports_deleted_handler(lambda document_ids: delete_documents(document_ids)) def fetch_vespa_document(report: Report) -> dict[str, Any]: return fetch_document(report.document_id) diff --git a/radis/vespa/utils/document_utils.py b/radis/vespa/utils/document_utils.py index 07e0ac8f..be97ca88 100644 --- a/radis/vespa/utils/document_utils.py +++ b/radis/vespa/utils/document_utils.py @@ -1,11 +1,17 @@ +import logging from datetime import date, datetime, time -from typing import Any +from typing import Any, Iterable + +from django.db.models import QuerySet +from vespa.io import VespaResponse from radis.reports.models import Report from radis.search.site import ReportDocument from ..vespa_app import REPORT_SCHEMA_NAME, vespa_app +logger = logging.getLogger(__name__) + def _dictify_report_for_vespa(report: Report) -> dict[str, Any]: """Dictify the report for Vespa. @@ -33,6 +39,12 @@ def _dictify_report_for_vespa(report: Report) -> dict[str, Any]: } +def _generate_feedable_documents(reports: QuerySet[Report]) -> Iterable[dict]: + for report in reports: + fields = _dictify_report_for_vespa(report) + yield {"id": report.document_id, "fields": fields} + + def fetch_document(document_id: str) -> dict[str, Any]: response = vespa_app.get_client().get_data(REPORT_SCHEMA_NAME, document_id) @@ -43,30 +55,54 @@ def fetch_document(document_id: str) -> dict[str, Any]: return response.get_json() -def create_document(document_id: str, report: Report) -> None: - fields = _dictify_report_for_vespa(report) - response = vespa_app.get_client().feed_data_point(REPORT_SCHEMA_NAME, document_id, fields) +def create_documents(report_ids: list[int]) -> None: + reports = Report.objects.filter(id__in=report_ids) - if response.get_status_code() != 200: - message = response.get_json() - raise Exception(f"Error while feeding document to Vespa: {message}") + def callback(response: VespaResponse, id: str): + if response.get_status_code() == 200: + logger.debug(f"Successfully fed document with id {id} to Vespa") + else: + message = response.get_json() + logger.error(f"Error while feeding document with id {id} to Vespa: {message}") + vespa_app.get_client().feed_iterable( + _generate_feedable_documents(reports), REPORT_SCHEMA_NAME, callback=callback + ) -def update_document(document_id: str, report: Report) -> None: - fields = _dictify_report_for_vespa(report) - response = vespa_app.get_client().update_data(REPORT_SCHEMA_NAME, document_id, fields) - if response.get_status_code() != 200: - message = response.get_json() - raise Exception(f"Error while updating document on Vespa: {message}") +def update_documents(report_ids: list[int]) -> None: + reports = Report.objects.filter(id__in=report_ids) + def callback(response: VespaResponse, id: str): + if response.get_status_code() == 200: + logger.debug(f"Successfully updated document with id {id} in Vespa") + pass + else: + message = response.get_json() + logger.error(f"Error while updating document with id {id} in Vespa: {message}") -def delete_document(document_id: str) -> None: - response = vespa_app.get_client().delete_data(REPORT_SCHEMA_NAME, document_id) + vespa_app.get_client().feed_iterable( + _generate_feedable_documents(reports), + REPORT_SCHEMA_NAME, + operation_type="update", + callback=callback, + ) - if response.get_status_code() != 200: - message = response.get_json() - raise Exception(f"Error while deleting document on Vespa: {message}") + +def delete_documents(document_ids: list[str]) -> None: + def callback(response: VespaResponse, id: str): + if response.get_status_code() == 200: + logger.debug(f"Successfully deleted document with id {id} in Vespa") + else: + message = response.get_json() + logger.error(f"Error while deleting document with id {id} in Vespa: {message}") + + vespa_app.get_client().feed_iterable( + [{"id": id} for id in document_ids], + REPORT_SCHEMA_NAME, + operation_type="delete", + callback=callback, + ) def document_from_vespa_response(record: dict[str, Any]) -> ReportDocument: