diff --git a/TODO.md b/TODO.md
index cc3e6461..dbbd8e0b 100644
--- a/TODO.md
+++ b/TODO.md
@@ -2,6 +2,7 @@
## High Priority
+- urls -> pacs_link
- Page titles
- Filter reports by active group of user
- Think about a better delete strategy
diff --git a/radis/core/templatetags/core_extras.py b/radis/core/templatetags/core_extras.py
index b6fee824..83a676da 100644
--- a/radis/core/templatetags/core_extras.py
+++ b/radis/core/templatetags/core_extras.py
@@ -46,9 +46,7 @@ def analysis_task_status_css_class(status: AnalysisTask.Status) -> str:
return css_classes[status]
-# TODO: Resolve reference names from another source in the context
-# Context must be set in the view
-@register.simple_tag(takes_context=True)
-def url_abbreviation(context: dict, url: str):
+@register.simple_tag
+def url_abbreviation(url: str):
abbr = re.sub(r"^(https?://)?(www.)?", "", url)
return abbr[:5]
diff --git a/radis/rag/migrations/0005_ragjob_group.py b/radis/rag/migrations/0005_ragjob_group.py
new file mode 100644
index 00000000..1adfcfad
--- /dev/null
+++ b/radis/rag/migrations/0005_ragjob_group.py
@@ -0,0 +1,21 @@
+# Generated by Django 5.0.4 on 2024-04-10 13:53
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('auth', '0012_alter_user_first_name_max_length'),
+ ('rag', '0004_ragjob_language'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='ragjob',
+ name='group',
+ field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, to='auth.group'),
+ preserve_default=False,
+ ),
+ ]
diff --git a/radis/rag/models.py b/radis/rag/models.py
index 12529873..e0dff4b0 100644
--- a/radis/rag/models.py
+++ b/radis/rag/models.py
@@ -2,6 +2,7 @@
from celery import current_app
from django.conf import settings
+from django.contrib.auth.models import Group
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.urls import reverse
@@ -47,8 +48,9 @@ class RagJob(AnalysisJob):
provider = models.CharField(
max_length=100, choices=lazy(get_retrieval_providers, tuple)(), default=get_default_provider
)
+ group = models.ForeignKey(Group, on_delete=models.CASCADE)
query = models.CharField(max_length=200)
- language = models.CharField(max_length=10, blank=True)
+ language = models.CharField(max_length=10, blank=True) # TODO: foreign key
modalities = ArrayField(models.CharField(max_length=16))
study_date_from = models.DateField(null=True, blank=True)
study_date_till = models.DateField(null=True, blank=True)
diff --git a/radis/rag/tasks.py b/radis/rag/tasks.py
index 6254249e..22077080 100644
--- a/radis/rag/tasks.py
+++ b/radis/rag/tasks.py
@@ -131,6 +131,7 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]:
retrieval_provider = retrieval_providers[provider]
search = Search(
+ group=job.group.pk,
query=job.query,
offset=0,
limit=retrieval_provider.max_results,
diff --git a/radis/rag/views.py b/radis/rag/views.py
index 3dd9d6d1..812a892b 100644
--- a/radis/rag/views.py
+++ b/radis/rag/views.py
@@ -1,7 +1,11 @@
from typing import cast
from django.conf import settings
-from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin
+from django.contrib.auth.mixins import (
+ LoginRequiredMixin,
+ PermissionRequiredMixin,
+ UserPassesTestMixin,
+)
from django.core.exceptions import SuspiciousOperation
from django.db import transaction
from django.forms import BaseInlineFormSet, ModelForm
@@ -55,14 +59,20 @@ class RagJobListView(RagLockedMixin, AnalysisJobListView):
template_name = "rag/rag_job_list.html"
-class RagJobWizardView(LoginRequiredMixin, PermissionRequiredMixin, SessionWizardView):
+class RagJobWizardView(
+ LoginRequiredMixin, PermissionRequiredMixin, UserPassesTestMixin, SessionWizardView
+):
SEARCH_STEP = "0"
QUESTIONS_STEP = "1"
form_list = [SearchForm, QuestionFormSet]
permission_required = "rag.add_ragjob"
+ permission_denied_message = "You must be logged in and have an active group"
request: AuthenticatedHttpRequest
+ def test_func(self) -> bool | None:
+ return self.request.user.active_group is not None
+
def get_form_kwargs(self, step=None):
kwargs = super().get_form_kwargs(step)
if step == RagJobWizardView.SEARCH_STEP:
@@ -109,7 +119,11 @@ def done(self, form_objs: list[ModelForm | BaseInlineFormSet], **kwargs):
return redirect(job)
def _estimate_retrieval_count(self, data: dict) -> int:
+ active_group = self.request.user.active_group
+ assert active_group
+
search = Search(
+ group=active_group.pk,
query=data["query"],
offset=0,
limit=0,
diff --git a/radis/reports/templates/reports/_report_buttons_panel.html b/radis/reports/templates/reports/_report_buttons_panel.html
index 9113c897..63c9503f 100644
--- a/radis/reports/templates/reports/_report_buttons_panel.html
+++ b/radis/reports/templates/reports/_report_buttons_panel.html
@@ -1,9 +1,12 @@
{% load access_item bootstrap_icon from common_extras %}
+{% load can_view_report from reports_extras %}
- {% if not no_view_button %}
+ {% if not hide_view_button %}
+ {% can_view_report report as viewable %}
+ class="btn btn-secondary btn-sm {% if not viewable %}disabled{% endif %}"
+ {% if not viewable %}aria-disabled="true"{% endif %}>
{% bootstrap_icon "box-arrow-in-down-right" %}
Details
diff --git a/radis/reports/templates/reports/report_detail.html b/radis/reports/templates/reports/report_detail.html
index 8b507bae..9ea0052d 100644
--- a/radis/reports/templates/reports/report_detail.html
+++ b/radis/reports/templates/reports/report_detail.html
@@ -31,7 +31,7 @@
Report Details
-
{% include "reports/_report_buttons_panel.html" with no_view_button=True %}
+
{% include "reports/_report_buttons_panel.html" with hide_view_button=True %}
Report Text
diff --git a/radis/reports/templatetags/__init__.py b/radis/reports/templatetags/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/radis/reports/templatetags/reports_extras.py b/radis/reports/templatetags/reports_extras.py
new file mode 100644
index 00000000..e74cef2d
--- /dev/null
+++ b/radis/reports/templatetags/reports_extras.py
@@ -0,0 +1,18 @@
+from typing import Any, cast
+
+from django.template import Library
+
+from adit_radis_shared.accounts.models import User
+
+from ..models import Report
+
+register = Library()
+
+
+@register.simple_tag(takes_context=True)
+def can_view_report(context: dict[str, Any], report: Report) -> bool:
+ user = cast(User, context["request"].user)
+ active_group = user.active_group
+ if not active_group:
+ return False
+ return report.groups.filter(pk=active_group.pk).exists()
diff --git a/radis/reports/views.py b/radis/reports/views.py
index f635217f..6dbc7cc0 100644
--- a/radis/reports/views.py
+++ b/radis/reports/views.py
@@ -1,13 +1,26 @@
-from django.contrib.auth.mixins import LoginRequiredMixin
+from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
+from django.db.models import QuerySet
from django.views.generic.detail import DetailView
+from adit_radis_shared.common.types import AuthenticatedHttpRequest
+
from .models import Report
-class ReportDetailView(LoginRequiredMixin, DetailView):
+class ReportDetailView(LoginRequiredMixin, UserPassesTestMixin, DetailView):
model = Report
template_name = "reports/report_detail.html"
context_object_name = "report"
+ permission_denied_message = "You must be logged in and have an active group"
+ request: AuthenticatedHttpRequest
+
+ def test_func(self) -> bool | None:
+ return self.request.user.active_group is not None
+
+ def get_queryset(self) -> QuerySet[Report]:
+ active_group = self.request.user.active_group
+ assert active_group
+ return super().get_queryset().filter(groups=active_group)
class ReportBodyView(ReportDetailView):
diff --git a/radis/search/site.py b/radis/search/site.py
index 24d2036f..4c62145a 100644
--- a/radis/search/site.py
+++ b/radis/search/site.py
@@ -49,12 +49,14 @@ class Search(NamedTuple):
should return the most accurate total count it can calculate.
Attributes:
+ - group (int): The ID of the group to search.
- query (str): The query to search.
- offset (int): The offset of the search results.
- limit (int): The limit of the search results.
- filters (SearchFilters): The filters to apply to the search.
"""
+ group: int
query: str
offset: int = 0
limit: int = 10
diff --git a/radis/search/views.py b/radis/search/views.py
index a513cd0a..b11f51ea 100644
--- a/radis/search/views.py
+++ b/radis/search/views.py
@@ -1,6 +1,6 @@
from typing import Any
-from django.contrib.auth.mixins import LoginRequiredMixin
+from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.core.exceptions import ValidationError
from django.core.paginator import Paginator
from django.http import Http404, HttpRequest
@@ -14,7 +14,13 @@
from .site import Search, SearchFilters, search_providers
-class SearchView(LoginRequiredMixin, View):
+class SearchView(LoginRequiredMixin, UserPassesTestMixin, View):
+ permission_denied_message = "You must be logged in and have an active group"
+ request: AuthenticatedHttpRequest
+
+ def test_func(self) -> bool | None:
+ return self.request.user.active_group is not None
+
def get(self, request: AuthenticatedHttpRequest, *args, **kwargs):
form = SearchForm(request.GET)
context: dict[str, Any] = {"form": form}
@@ -47,8 +53,12 @@ def get(self, request: AuthenticatedHttpRequest, *args, **kwargs):
offset = (page_number - 1) * page_size
context["offset"] = offset
+ active_group = self.request.user.active_group
+ assert active_group
+
if query:
search = Search(
+ group=active_group.pk,
query=query,
offset=offset,
limit=page_size,
diff --git a/radis/vespa/providers.py b/radis/vespa/providers.py
index d0373c63..63d938c4 100644
--- a/radis/vespa/providers.py
+++ b/radis/vespa/providers.py
@@ -34,6 +34,7 @@ def _execute_query(params: dict[str, Any]) -> VespaQueryResponse:
def search_bm25(search: Search) -> SearchResult:
yql = "select * from sources * where userQuery()"
+ yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
@@ -60,6 +61,7 @@ def search_bm25(search: Search) -> SearchResult:
def search_semantic(search: Search) -> SearchResult:
yql = "select * from sources * where userQuery()"
+ yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
@@ -88,6 +90,7 @@ def search_semantic(search: Search) -> SearchResult:
# https://pyvespa.readthedocs.io/en/latest/getting-started-pyvespa.html#Hybrid-search-with-the-OR-query-operator
def search_hybrid(search: Search) -> SearchResult:
yql = "select * from sources * where userQuery()"
+ yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
@@ -115,6 +118,7 @@ def search_hybrid(search: Search) -> SearchResult:
def retrieve_bm25(search: Search) -> RetrievalResult:
yql = "select * from sources * where userQuery()"
+ yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"