diff --git a/TODO.md b/TODO.md index cc3e6461..dbbd8e0b 100644 --- a/TODO.md +++ b/TODO.md @@ -2,6 +2,7 @@ ## High Priority +- urls -> pacs_link - Page titles - Filter reports by active group of user - Think about a better delete strategy diff --git a/radis/core/templatetags/core_extras.py b/radis/core/templatetags/core_extras.py index b6fee824..83a676da 100644 --- a/radis/core/templatetags/core_extras.py +++ b/radis/core/templatetags/core_extras.py @@ -46,9 +46,7 @@ def analysis_task_status_css_class(status: AnalysisTask.Status) -> str: return css_classes[status] -# TODO: Resolve reference names from another source in the context -# Context must be set in the view -@register.simple_tag(takes_context=True) -def url_abbreviation(context: dict, url: str): +@register.simple_tag +def url_abbreviation(url: str): abbr = re.sub(r"^(https?://)?(www.)?", "", url) return abbr[:5] diff --git a/radis/rag/migrations/0005_ragjob_group.py b/radis/rag/migrations/0005_ragjob_group.py new file mode 100644 index 00000000..1adfcfad --- /dev/null +++ b/radis/rag/migrations/0005_ragjob_group.py @@ -0,0 +1,21 @@ +# Generated by Django 5.0.4 on 2024-04-10 13:53 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('auth', '0012_alter_user_first_name_max_length'), + ('rag', '0004_ragjob_language'), + ] + + operations = [ + migrations.AddField( + model_name='ragjob', + name='group', + field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, to='auth.group'), + preserve_default=False, + ), + ] diff --git a/radis/rag/models.py b/radis/rag/models.py index 12529873..e0dff4b0 100644 --- a/radis/rag/models.py +++ b/radis/rag/models.py @@ -2,6 +2,7 @@ from celery import current_app from django.conf import settings +from django.contrib.auth.models import Group from django.contrib.postgres.fields import ArrayField from django.db import models from django.urls import reverse @@ -47,8 +48,9 @@ class RagJob(AnalysisJob): provider = models.CharField( max_length=100, choices=lazy(get_retrieval_providers, tuple)(), default=get_default_provider ) + group = models.ForeignKey(Group, on_delete=models.CASCADE) query = models.CharField(max_length=200) - language = models.CharField(max_length=10, blank=True) + language = models.CharField(max_length=10, blank=True) # TODO: foreign key modalities = ArrayField(models.CharField(max_length=16)) study_date_from = models.DateField(null=True, blank=True) study_date_till = models.DateField(null=True, blank=True) diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py index 6254249e..22077080 100644 --- a/radis/rag/tasks.py +++ b/radis/rag/tasks.py @@ -131,6 +131,7 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]: retrieval_provider = retrieval_providers[provider] search = Search( + group=job.group.pk, query=job.query, offset=0, limit=retrieval_provider.max_results, diff --git a/radis/rag/views.py b/radis/rag/views.py index 3dd9d6d1..812a892b 100644 --- a/radis/rag/views.py +++ b/radis/rag/views.py @@ -1,7 +1,11 @@ from typing import cast from django.conf import settings -from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin +from django.contrib.auth.mixins import ( + LoginRequiredMixin, + PermissionRequiredMixin, + UserPassesTestMixin, +) from django.core.exceptions import SuspiciousOperation from django.db import transaction from django.forms import BaseInlineFormSet, ModelForm @@ -55,14 +59,20 @@ class RagJobListView(RagLockedMixin, AnalysisJobListView): template_name = "rag/rag_job_list.html" -class RagJobWizardView(LoginRequiredMixin, PermissionRequiredMixin, SessionWizardView): +class RagJobWizardView( + LoginRequiredMixin, PermissionRequiredMixin, UserPassesTestMixin, SessionWizardView +): SEARCH_STEP = "0" QUESTIONS_STEP = "1" form_list = [SearchForm, QuestionFormSet] permission_required = "rag.add_ragjob" + permission_denied_message = "You must be logged in and have an active group" request: AuthenticatedHttpRequest + def test_func(self) -> bool | None: + return self.request.user.active_group is not None + def get_form_kwargs(self, step=None): kwargs = super().get_form_kwargs(step) if step == RagJobWizardView.SEARCH_STEP: @@ -109,7 +119,11 @@ def done(self, form_objs: list[ModelForm | BaseInlineFormSet], **kwargs): return redirect(job) def _estimate_retrieval_count(self, data: dict) -> int: + active_group = self.request.user.active_group + assert active_group + search = Search( + group=active_group.pk, query=data["query"], offset=0, limit=0, diff --git a/radis/reports/templates/reports/_report_buttons_panel.html b/radis/reports/templates/reports/_report_buttons_panel.html index 9113c897..63c9503f 100644 --- a/radis/reports/templates/reports/_report_buttons_panel.html +++ b/radis/reports/templates/reports/_report_buttons_panel.html @@ -1,9 +1,12 @@ {% load access_item bootstrap_icon from common_extras %} +{% load can_view_report from reports_extras %}
- {% if not no_view_button %} + {% if not hide_view_button %} + {% can_view_report report as viewable %} + class="btn btn-secondary btn-sm {% if not viewable %}disabled{% endif %}" + {% if not viewable %}aria-disabled="true"{% endif %}> {% bootstrap_icon "box-arrow-in-down-right" %} Details diff --git a/radis/reports/templates/reports/report_detail.html b/radis/reports/templates/reports/report_detail.html index 8b507bae..9ea0052d 100644 --- a/radis/reports/templates/reports/report_detail.html +++ b/radis/reports/templates/reports/report_detail.html @@ -31,7 +31,7 @@

Report Details

-
{% include "reports/_report_buttons_panel.html" with no_view_button=True %}
+
{% include "reports/_report_buttons_panel.html" with hide_view_button=True %}
Report Text
diff --git a/radis/reports/templatetags/__init__.py b/radis/reports/templatetags/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/radis/reports/templatetags/reports_extras.py b/radis/reports/templatetags/reports_extras.py new file mode 100644 index 00000000..e74cef2d --- /dev/null +++ b/radis/reports/templatetags/reports_extras.py @@ -0,0 +1,18 @@ +from typing import Any, cast + +from django.template import Library + +from adit_radis_shared.accounts.models import User + +from ..models import Report + +register = Library() + + +@register.simple_tag(takes_context=True) +def can_view_report(context: dict[str, Any], report: Report) -> bool: + user = cast(User, context["request"].user) + active_group = user.active_group + if not active_group: + return False + return report.groups.filter(pk=active_group.pk).exists() diff --git a/radis/reports/views.py b/radis/reports/views.py index f635217f..6dbc7cc0 100644 --- a/radis/reports/views.py +++ b/radis/reports/views.py @@ -1,13 +1,26 @@ -from django.contrib.auth.mixins import LoginRequiredMixin +from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin +from django.db.models import QuerySet from django.views.generic.detail import DetailView +from adit_radis_shared.common.types import AuthenticatedHttpRequest + from .models import Report -class ReportDetailView(LoginRequiredMixin, DetailView): +class ReportDetailView(LoginRequiredMixin, UserPassesTestMixin, DetailView): model = Report template_name = "reports/report_detail.html" context_object_name = "report" + permission_denied_message = "You must be logged in and have an active group" + request: AuthenticatedHttpRequest + + def test_func(self) -> bool | None: + return self.request.user.active_group is not None + + def get_queryset(self) -> QuerySet[Report]: + active_group = self.request.user.active_group + assert active_group + return super().get_queryset().filter(groups=active_group) class ReportBodyView(ReportDetailView): diff --git a/radis/search/site.py b/radis/search/site.py index 24d2036f..4c62145a 100644 --- a/radis/search/site.py +++ b/radis/search/site.py @@ -49,12 +49,14 @@ class Search(NamedTuple): should return the most accurate total count it can calculate. Attributes: + - group (int): The ID of the group to search. - query (str): The query to search. - offset (int): The offset of the search results. - limit (int): The limit of the search results. - filters (SearchFilters): The filters to apply to the search. """ + group: int query: str offset: int = 0 limit: int = 10 diff --git a/radis/search/views.py b/radis/search/views.py index a513cd0a..b11f51ea 100644 --- a/radis/search/views.py +++ b/radis/search/views.py @@ -1,6 +1,6 @@ from typing import Any -from django.contrib.auth.mixins import LoginRequiredMixin +from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin from django.core.exceptions import ValidationError from django.core.paginator import Paginator from django.http import Http404, HttpRequest @@ -14,7 +14,13 @@ from .site import Search, SearchFilters, search_providers -class SearchView(LoginRequiredMixin, View): +class SearchView(LoginRequiredMixin, UserPassesTestMixin, View): + permission_denied_message = "You must be logged in and have an active group" + request: AuthenticatedHttpRequest + + def test_func(self) -> bool | None: + return self.request.user.active_group is not None + def get(self, request: AuthenticatedHttpRequest, *args, **kwargs): form = SearchForm(request.GET) context: dict[str, Any] = {"form": form} @@ -47,8 +53,12 @@ def get(self, request: AuthenticatedHttpRequest, *args, **kwargs): offset = (page_number - 1) * page_size context["offset"] = offset + active_group = self.request.user.active_group + assert active_group + if query: search = Search( + group=active_group.pk, query=query, offset=offset, limit=page_size, diff --git a/radis/vespa/providers.py b/radis/vespa/providers.py index d0373c63..63d938c4 100644 --- a/radis/vespa/providers.py +++ b/radis/vespa/providers.py @@ -34,6 +34,7 @@ def _execute_query(params: dict[str, Any]) -> VespaQueryResponse: def search_bm25(search: Search) -> SearchResult: yql = "select * from sources * where userQuery()" + yql += f" and groups = {search.group}" filters = build_yql_filter(search.filters) if filters: yql += f" {filters}" @@ -60,6 +61,7 @@ def search_bm25(search: Search) -> SearchResult: def search_semantic(search: Search) -> SearchResult: yql = "select * from sources * where userQuery()" + yql += f" and groups = {search.group}" filters = build_yql_filter(search.filters) if filters: yql += f" {filters}" @@ -88,6 +90,7 @@ def search_semantic(search: Search) -> SearchResult: # https://pyvespa.readthedocs.io/en/latest/getting-started-pyvespa.html#Hybrid-search-with-the-OR-query-operator def search_hybrid(search: Search) -> SearchResult: yql = "select * from sources * where userQuery()" + yql += f" and groups = {search.group}" filters = build_yql_filter(search.filters) if filters: yql += f" {filters}" @@ -115,6 +118,7 @@ def search_hybrid(search: Search) -> SearchResult: def retrieve_bm25(search: Search) -> RetrievalResult: yql = "select * from sources * where userQuery()" + yql += f" and groups = {search.group}" filters = build_yql_filter(search.filters) if filters: yql += f" {filters}"