Skip to content

Commit

Permalink
Allow to access reports only by users with active group
Browse files Browse the repository at this point in the history
  • Loading branch information
medihack committed Apr 10, 2024
1 parent 67943f1 commit e268776
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 14 deletions.
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## High Priority

- urls -> pacs_link
- Page titles
- Filter reports by active group of user
- Think about a better delete strategy
Expand Down
6 changes: 2 additions & 4 deletions radis/core/templatetags/core_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def analysis_task_status_css_class(status: AnalysisTask.Status) -> str:
return css_classes[status]


# TODO: Resolve reference names from another source in the context
# Context must be set in the view
@register.simple_tag(takes_context=True)
def url_abbreviation(context: dict, url: str):
@register.simple_tag
def url_abbreviation(url: str):
abbr = re.sub(r"^(https?://)?(www.)?", "", url)
return abbr[:5]
21 changes: 21 additions & 0 deletions radis/rag/migrations/0005_ragjob_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Generated by Django 5.0.4 on 2024-04-10 13:53

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('auth', '0012_alter_user_first_name_max_length'),
('rag', '0004_ragjob_language'),
]

operations = [
migrations.AddField(
model_name='ragjob',
name='group',
field=models.ForeignKey(default=1, on_delete=django.db.models.deletion.CASCADE, to='auth.group'),
preserve_default=False,
),
]
4 changes: 3 additions & 1 deletion radis/rag/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from celery import current_app
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.urls import reverse
Expand Down Expand Up @@ -47,8 +48,9 @@ class RagJob(AnalysisJob):
provider = models.CharField(
max_length=100, choices=lazy(get_retrieval_providers, tuple)(), default=get_default_provider
)
group = models.ForeignKey(Group, on_delete=models.CASCADE)
query = models.CharField(max_length=200)
language = models.CharField(max_length=10, blank=True)
language = models.CharField(max_length=10, blank=True) # TODO: foreign key
modalities = ArrayField(models.CharField(max_length=16))
study_date_from = models.DateField(null=True, blank=True)
study_date_till = models.DateField(null=True, blank=True)
Expand Down
1 change: 1 addition & 0 deletions radis/rag/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def collect_tasks(self, job: RagJob) -> Iterator[RagTask]:
retrieval_provider = retrieval_providers[provider]

search = Search(
group=job.group.pk,
query=job.query,
offset=0,
limit=retrieval_provider.max_results,
Expand Down
18 changes: 16 additions & 2 deletions radis/rag/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import cast

from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin
from django.contrib.auth.mixins import (
LoginRequiredMixin,
PermissionRequiredMixin,
UserPassesTestMixin,
)
from django.core.exceptions import SuspiciousOperation
from django.db import transaction
from django.forms import BaseInlineFormSet, ModelForm
Expand Down Expand Up @@ -55,14 +59,20 @@ class RagJobListView(RagLockedMixin, AnalysisJobListView):
template_name = "rag/rag_job_list.html"


class RagJobWizardView(LoginRequiredMixin, PermissionRequiredMixin, SessionWizardView):
class RagJobWizardView(
LoginRequiredMixin, PermissionRequiredMixin, UserPassesTestMixin, SessionWizardView
):
SEARCH_STEP = "0"
QUESTIONS_STEP = "1"

form_list = [SearchForm, QuestionFormSet]
permission_required = "rag.add_ragjob"
permission_denied_message = "You must be logged in and have an active group"
request: AuthenticatedHttpRequest

def test_func(self) -> bool | None:
return self.request.user.active_group is not None

def get_form_kwargs(self, step=None):
kwargs = super().get_form_kwargs(step)
if step == RagJobWizardView.SEARCH_STEP:
Expand Down Expand Up @@ -109,7 +119,11 @@ def done(self, form_objs: list[ModelForm | BaseInlineFormSet], **kwargs):
return redirect(job)

def _estimate_retrieval_count(self, data: dict) -> int:
active_group = self.request.user.active_group
assert active_group

search = Search(
group=active_group.pk,
query=data["query"],
offset=0,
limit=0,
Expand Down
7 changes: 5 additions & 2 deletions radis/reports/templates/reports/_report_buttons_panel.html
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
{% load access_item bootstrap_icon from common_extras %}
{% load can_view_report from reports_extras %}
<div class="mt-2 d-flex justify-content-between">
<div class="d-flex gap-2">
{% if not no_view_button %}
{% if not hide_view_button %}
{% can_view_report report as viewable %}
<a href="{% url 'report_detail' report.id %}"
class="btn btn-secondary btn-sm">
class="btn btn-secondary btn-sm {% if not viewable %}disabled{% endif %}"
{% if not viewable %}aria-disabled="true"{% endif %}>
{% bootstrap_icon "box-arrow-in-down-right" %}
Details
</a>
Expand Down
2 changes: 1 addition & 1 deletion radis/reports/templates/reports/report_detail.html
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ <h4 class="mb-3">Report Details</h4>
</tr>
</tbody>
</table>
<div class="mb-3 mx-1">{% include "reports/_report_buttons_panel.html" with no_view_button=True %}</div>
<div class="mb-3 mx-1">{% include "reports/_report_buttons_panel.html" with hide_view_button=True %}</div>
<div class="card">
<div class="card-body">
<h5 class="card-title">Report Text</h5>
Expand Down
Empty file.
18 changes: 18 additions & 0 deletions radis/reports/templatetags/reports_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any, cast

from django.template import Library

from adit_radis_shared.accounts.models import User

from ..models import Report

register = Library()


@register.simple_tag(takes_context=True)
def can_view_report(context: dict[str, Any], report: Report) -> bool:
user = cast(User, context["request"].user)
active_group = user.active_group
if not active_group:
return False
return report.groups.filter(pk=active_group.pk).exists()
17 changes: 15 additions & 2 deletions radis/reports/views.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.db.models import QuerySet
from django.views.generic.detail import DetailView

from adit_radis_shared.common.types import AuthenticatedHttpRequest

from .models import Report


class ReportDetailView(LoginRequiredMixin, DetailView):
class ReportDetailView(LoginRequiredMixin, UserPassesTestMixin, DetailView):
model = Report
template_name = "reports/report_detail.html"
context_object_name = "report"
permission_denied_message = "You must be logged in and have an active group"
request: AuthenticatedHttpRequest

def test_func(self) -> bool | None:
return self.request.user.active_group is not None

def get_queryset(self) -> QuerySet[Report]:
active_group = self.request.user.active_group
assert active_group
return super().get_queryset().filter(groups=active_group)


class ReportBodyView(ReportDetailView):
Expand Down
2 changes: 2 additions & 0 deletions radis/search/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ class Search(NamedTuple):
should return the most accurate total count it can calculate.
Attributes:
- group (int): The ID of the group to search.
- query (str): The query to search.
- offset (int): The offset of the search results.
- limit (int): The limit of the search results.
- filters (SearchFilters): The filters to apply to the search.
"""

group: int
query: str
offset: int = 0
limit: int = 10
Expand Down
14 changes: 12 additions & 2 deletions radis/search/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.core.exceptions import ValidationError
from django.core.paginator import Paginator
from django.http import Http404, HttpRequest
Expand All @@ -14,7 +14,13 @@
from .site import Search, SearchFilters, search_providers


class SearchView(LoginRequiredMixin, View):
class SearchView(LoginRequiredMixin, UserPassesTestMixin, View):
permission_denied_message = "You must be logged in and have an active group"
request: AuthenticatedHttpRequest

def test_func(self) -> bool | None:
return self.request.user.active_group is not None

def get(self, request: AuthenticatedHttpRequest, *args, **kwargs):
form = SearchForm(request.GET)
context: dict[str, Any] = {"form": form}
Expand Down Expand Up @@ -47,8 +53,12 @@ def get(self, request: AuthenticatedHttpRequest, *args, **kwargs):
offset = (page_number - 1) * page_size
context["offset"] = offset

active_group = self.request.user.active_group
assert active_group

if query:
search = Search(
group=active_group.pk,
query=query,
offset=offset,
limit=page_size,
Expand Down
4 changes: 4 additions & 0 deletions radis/vespa/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _execute_query(params: dict[str, Any]) -> VespaQueryResponse:

def search_bm25(search: Search) -> SearchResult:
yql = "select * from sources * where userQuery()"
yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
Expand All @@ -60,6 +61,7 @@ def search_bm25(search: Search) -> SearchResult:

def search_semantic(search: Search) -> SearchResult:
yql = "select * from sources * where userQuery()"
yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
Expand Down Expand Up @@ -88,6 +90,7 @@ def search_semantic(search: Search) -> SearchResult:
# https://pyvespa.readthedocs.io/en/latest/getting-started-pyvespa.html#Hybrid-search-with-the-OR-query-operator
def search_hybrid(search: Search) -> SearchResult:
yql = "select * from sources * where userQuery()"
yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
Expand Down Expand Up @@ -115,6 +118,7 @@ def search_hybrid(search: Search) -> SearchResult:

def retrieve_bm25(search: Search) -> RetrievalResult:
yql = "select * from sources * where userQuery()"
yql += f" and groups = {search.group}"
filters = build_yql_filter(search.filters)
if filters:
yql += f" {filters}"
Expand Down

0 comments on commit e268776

Please sign in to comment.