Skip to content

Commit

Permalink
Support batch create, update, (delete) of reports
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Mar 19, 2024
1 parent afba48e commit 5c23feb
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 69 deletions.
4 changes: 4 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## High Priority

- Use own worker for indexing
- Think about a better delete strategy
- Delete document_id from Vespa schema and extract it from the returned id
- <https://docs.vespa.ai/en/documents.html#document-ids>
- Check if ranking should be turned off for RAG for performance improvements (and some fixed sort order used instead)
- Somehow present provider.max_results to the user; this is especially important if the number of query results (step 1) is larger
- task control panel
Expand Down
17 changes: 10 additions & 7 deletions radis/core/management/commands/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from radis.accounts.models import User
from radis.reports.factories import ReportFactory
from radis.reports.models import Report
from radis.reports.site import report_event_handlers
from radis.reports.site import reports_created_handlers
from radis.token_authentication.factories import TokenFactory
from radis.token_authentication.models import FRACTION_LENGTH
from radis.token_authentication.utils.crypto import hash_token
Expand All @@ -24,12 +24,11 @@
fake = Faker()


def feed_report(body: str, language: Literal["en", "de"]):
def create_report(body: str, language: Literal["en", "de"]):
report = ReportFactory.create(language=language, body=body)
groups = fake.random_elements(elements=list(Group.objects.all()), unique=True)
report.groups.set(groups)
for handler in report_event_handlers:
handler("created", report.document_id)
return report


def feed_reports(language: Literal["en", "de"]):
Expand All @@ -42,10 +41,14 @@ def feed_reports(language: Literal["en", "de"]):

samples_path = Path(settings.BASE_DIR / "samples" / sample_file)
with open(samples_path, "r") as f:
reports = json.load(f)
report_bodies = json.load(f)

for report in reports:
feed_report(report, language)
reports: list[Report] = []
for report_body in report_bodies:
reports.append(create_report(report_body, language))

for handler in reports_created_handlers:
handler([report.id for report in reports])


def create_admin() -> User:
Expand Down
23 changes: 17 additions & 6 deletions radis/reports/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer

from radis.reports.tasks import report_created, report_deleted, report_updated
from radis.reports.tasks import reports_created, reports_deleted, reports_updated

from ..models import Report
from ..site import document_fetchers
Expand All @@ -33,6 +33,11 @@ class ReportViewSet(
lookup_field = "document_id"
permission_classes = [IsAdminUser]

def get_serializer(self, *args: Any, **kwargs: Any) -> BaseSerializer:
    """Build the serializer, enabling list mode for batch payloads.

    If the ``data`` keyword argument is a list, ``many=True`` is set on the
    keyword arguments before they are handed to the parent implementation.
    Any other payload (or no payload at all) is passed through unchanged.
    """
    payload = kwargs.get("data")
    if isinstance(payload, list):
        kwargs["many"] = True
    return super().get_serializer(*args, **kwargs)

def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response:
"""Retrieve a single Report.
Expand All @@ -57,8 +62,11 @@ def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response:
def perform_create(self, serializer: BaseSerializer) -> None:
super().perform_create(serializer)
assert serializer.instance
report: Report = serializer.instance
transaction.on_commit(lambda: report_created.delay(report.document_id))
reports: list[Report] | Report = serializer.instance
if not isinstance(reports, list):
reports = [reports]

transaction.on_commit(lambda: reports_created.delay([report.id for report in reports]))

def update(self, request: Request, *args: Any, **kwargs: Any) -> Response:
# DRF itself does not support upsert.
Expand Down Expand Up @@ -90,8 +98,11 @@ def get_object_or_none(self) -> Report | None:
def perform_update(self, serializer: BaseSerializer) -> None:
super().perform_update(serializer)
assert serializer.instance
report: Report = serializer.instance
transaction.on_commit(lambda: report_updated.delay(report.document_id))
reports: list[Report] | Report = serializer.instance
if not isinstance(reports, list):
reports = [reports]

transaction.on_commit(lambda: reports_updated.delay([report.id for report in reports]))

def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response:
# Disallow partial updates
Expand All @@ -100,4 +111,4 @@ def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Respons

def perform_destroy(self, instance: Report) -> None:
super().perform_destroy(instance)
transaction.on_commit(lambda: report_deleted.delay(instance.document_id))
transaction.on_commit(lambda: reports_deleted.delay([instance.document_id]))
43 changes: 33 additions & 10 deletions radis/reports/site.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,46 @@
from typing import Any, Callable, Literal, NamedTuple
from typing import Any, Callable, NamedTuple

from django.http import HttpRequest

from .models import Report

ReportEventType = Literal["created", "updated", "deleted"]
ReportEventHandler = Callable[[ReportEventType, str], None]
ReportsCreatedHandler = Callable[[list[int]], None]

report_event_handlers: list[ReportEventHandler] = []
reports_created_handlers: list[ReportsCreatedHandler] = []


def register_report_handler(handler: ReportEventHandler) -> None:
"""Register a report event handler.
def register_reports_created_handler(handler: ReportsCreatedHandler) -> None:
"""Register a handler for when reports are created in the PostgreSQL database.
The report handler gets notified a report is created, updated, or deleted in
PostgreSQL database. It can be used to sync report documents in other
databases like Vespa.
The handler can be used to sync resp. index those reports in a search database like Vespa.
"""
report_event_handlers.append(handler)
reports_created_handlers.append(handler)


ReportsUpdatedHandler = Callable[[list[int]], None]

reports_updated_handlers: list[ReportsUpdatedHandler] = []


def register_reports_updated_handler(handler: ReportsUpdatedHandler) -> None:
    """Subscribe a callback to report-update notifications.

    The callback is appended to ``reports_updated_handlers``, so handlers are
    kept in registration order. As the handler type states, each callback
    takes a list of integer report IDs. A typical use is re-indexing the
    updated reports in a search database like Vespa.
    """
    reports_updated_handlers.append(handler)


ReportsDeletedHandler = Callable[[list[str]], None]

reports_deleted_handlers: list[ReportsDeletedHandler] = []


def register_reports_deleted_handler(handler: ReportsDeletedHandler) -> None:
    """Subscribe a callback to report-deletion notifications.

    The callback is appended to ``reports_deleted_handlers``, so handlers are
    kept in registration order. As the handler type states, each callback
    takes a list of strings identifying the deleted reports. A typical use is
    removing those reports from the index of a search database like Vespa.
    """
    reports_deleted_handlers.append(handler)


FetchDocument = Callable[[Report], dict[str, Any] | None]
Expand Down
20 changes: 10 additions & 10 deletions radis/reports/tasks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from celery import shared_task

from .site import report_event_handlers
from .site import reports_created_handlers, reports_deleted_handlers, reports_updated_handlers


@shared_task
def report_created(document_id: str) -> None:
for handler in report_event_handlers:
handler("created", document_id)
def reports_created(report_ids: list[int]) -> None:
for handler in reports_created_handlers:
handler(report_ids)


@shared_task
def report_updated(document_id: str) -> None:
for handler in report_event_handlers:
handler("updated", document_id)
def reports_updated(report_ids: list[int]) -> None:
for handler in reports_updated_handlers:
handler(report_ids)


@shared_task
def report_deleted(document_id: str) -> None:
for handler in report_event_handlers:
handler("deleted", document_id)
def reports_deleted(document_ids: list[str]) -> None:
for handler in reports_deleted_handlers:
handler(document_ids)
29 changes: 11 additions & 18 deletions radis/vespa/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,28 @@ def register_app():
from radis.rag.site import RetrievalProvider, register_retrieval_provider
from radis.reports.models import Report
from radis.reports.site import (
ReportEventType,
register_document_fetcher,
register_report_handler,
register_reports_created_handler,
register_reports_deleted_handler,
register_reports_updated_handler,
)
from radis.search.site import SearchProvider, register_search_provider
from radis.vespa.providers import retrieve_bm25
from radis.vespa.vespa_app import MAX_RETRIEVAL_HITS, MAX_SEARCH_HITS

from .providers import search_bm25, search_hybrid, search_semantic
from .utils.document_utils import (
create_document,
delete_document,
create_documents,
delete_documents,
fetch_document,
update_document,
update_documents,
)

def handle_report(event_type: ReportEventType, document_id: str):
if event_type in ("created", "updated"):
report = Report.objects.get(document_id=document_id)
if event_type == "created":
create_document(document_id, report)
elif event_type == "updated":
update_document(document_id, report)
elif event_type == "deleted":
delete_document(document_id)
else:
raise ValueError(f"Invalid report event type: {event_type}")

register_report_handler(handle_report)
register_reports_created_handler(lambda report_ids: create_documents(report_ids))

register_reports_updated_handler(lambda report_ids: update_documents(report_ids))

register_reports_deleted_handler(lambda document_ids: delete_documents(document_ids))

def fetch_vespa_document(report: Report) -> dict[str, Any]:
return fetch_document(report.document_id)
Expand Down
72 changes: 54 additions & 18 deletions radis/vespa/utils/document_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
from datetime import date, datetime, time
from typing import Any
from typing import Any, Iterable

from django.db.models import QuerySet
from vespa.io import VespaResponse

from radis.reports.models import Report
from radis.search.site import ReportDocument

from ..vespa_app import REPORT_SCHEMA_NAME, vespa_app

logger = logging.getLogger(__name__)


def _dictify_report_for_vespa(report: Report) -> dict[str, Any]:
"""Dictify the report for Vespa.
Expand Down Expand Up @@ -33,6 +39,12 @@ def _dictify_report_for_vespa(report: Report) -> dict[str, Any]:
}


def _generate_feedable_documents(reports: QuerySet[Report]) -> Iterable[dict]:
    """Lazily turn reports into the dicts handed to Vespa's ``feed_iterable``.

    Yields one dict per report: ``"id"`` holds the report's ``document_id``
    and ``"fields"`` holds the result of ``_dictify_report_for_vespa``.
    Nothing is evaluated until the generator is consumed.
    """
    for entry in reports:
        vespa_fields = _dictify_report_for_vespa(entry)
        yield dict(id=entry.document_id, fields=vespa_fields)


def fetch_document(document_id: str) -> dict[str, Any]:
response = vespa_app.get_client().get_data(REPORT_SCHEMA_NAME, document_id)

Expand All @@ -43,30 +55,54 @@ def fetch_document(document_id: str) -> dict[str, Any]:
return response.get_json()


def create_document(document_id: str, report: Report) -> None:
fields = _dictify_report_for_vespa(report)
response = vespa_app.get_client().feed_data_point(REPORT_SCHEMA_NAME, document_id, fields)
def create_documents(report_ids: list[int]) -> None:
reports = Report.objects.filter(id__in=report_ids)

if response.get_status_code() != 200:
message = response.get_json()
raise Exception(f"Error while feeding document to Vespa: {message}")
def callback(response: VespaResponse, id: str):
if response.get_status_code() == 200:
logger.debug(f"Successfully fed document with id {id} to Vespa")
else:
message = response.get_json()
logger.error(f"Error while feeding document with id {id} to Vespa: {message}")

vespa_app.get_client().feed_iterable(
_generate_feedable_documents(reports), REPORT_SCHEMA_NAME, callback=callback
)

def update_document(document_id: str, report: Report) -> None:
fields = _dictify_report_for_vespa(report)
response = vespa_app.get_client().update_data(REPORT_SCHEMA_NAME, document_id, fields)

if response.get_status_code() != 200:
message = response.get_json()
raise Exception(f"Error while updating document on Vespa: {message}")
def update_documents(report_ids: list[int]) -> None:
reports = Report.objects.filter(id__in=report_ids)

def callback(response: VespaResponse, id: str):
if response.get_status_code() == 200:
logger.debug(f"Successfully updated document with id {id} in Vespa")
pass
else:
message = response.get_json()
logger.error(f"Error while updating document with id {id} in Vespa: {message}")

def delete_document(document_id: str) -> None:
response = vespa_app.get_client().delete_data(REPORT_SCHEMA_NAME, document_id)
vespa_app.get_client().feed_iterable(
_generate_feedable_documents(reports),
REPORT_SCHEMA_NAME,
operation_type="update",
callback=callback,
)

if response.get_status_code() != 200:
message = response.get_json()
raise Exception(f"Error while deleting document on Vespa: {message}")

def delete_documents(document_ids: list[str]) -> None:
    """Delete a batch of report documents from Vespa.

    The IDs are fed to Vespa as ``delete`` operations against
    ``REPORT_SCHEMA_NAME``. Per-document failures are not raised; the callback
    logs them at error level, and successful deletions at debug level.

    Args:
        document_ids: IDs of the Vespa documents to delete. An empty list is
            a no-op.
    """
    if not document_ids:
        # Nothing to delete, so avoid setting up a Vespa feed for no work.
        return

    def callback(response: VespaResponse, id: str) -> None:
        if response.get_status_code() == 200:
            logger.debug(f"Successfully deleted document with id {id} in Vespa")
        else:
            message = response.get_json()
            logger.error(f"Error while deleting document with id {id} in Vespa: {message}")

    vespa_app.get_client().feed_iterable(
        # A delete operation only needs the document id, no fields.
        [{"id": document_id} for document_id in document_ids],
        REPORT_SCHEMA_NAME,
        operation_type="delete",
        callback=callback,
    )


def document_from_vespa_response(record: dict[str, Any]) -> ReportDocument:
Expand Down

0 comments on commit 5c23feb

Please sign in to comment.