Skip to content

Commit

Permalink
🔒️ Add check that users own document
Browse files Browse the repository at this point in the history
  • Loading branch information
pajowu committed May 23, 2023
1 parent 9704d73 commit e3a64b8
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 52 deletions.
55 changes: 44 additions & 11 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool
Expand All @@ -25,27 +26,37 @@ def settings_with_tmpdir():


@pytest.fixture
def client(memory_session: Session):
def app_with_memory_session(memory_session: Session):
def get_session_override():
return memory_session

app.dependency_overrides[get_session] = get_session_override

client = TestClient(app)
yield client
app.dependency_overrides.clear()
yield app

del app.dependency_overrides[get_session]


@pytest.fixture
def logged_in_client(memory_session: Session, auth_token):
def get_session_override():
return memory_session
def client(app_with_memory_session: FastAPI):
client = TestClient(app_with_memory_session)
return client

app.dependency_overrides[get_session] = get_session_override

client = TestClient(app, headers={"Authorization": f"Token {auth_token}"})
yield client
app.dependency_overrides.clear()
@pytest.fixture
def logged_in_client(app_with_memory_session: FastAPI, auth_token: str):
client = TestClient(
app_with_memory_session, headers={"Authorization": f"Token {auth_token}"}
)
return client


@pytest.fixture
def logged_in_client_user_2(app_with_memory_session: FastAPI, auth_token_user_2: str):
client = TestClient(
app_with_memory_session, headers={"Authorization": f"Token {auth_token_user_2}"}
)
return client


@pytest.fixture
Expand Down Expand Up @@ -97,3 +108,25 @@ def auth_token(user: User, memory_session: Session):
memory_session.add(db_token)
memory_session.commit()
return user_token


@pytest.fixture
def user_2(memory_session: Session):
username = "test_user_2"
password = "test_user_2_pass"
try:
user = create_user(session=memory_session, username=username, password=password)
except UserAlreadyExists:
user = change_user_password(
session=memory_session, username=username, new_password=password
)

return user


@pytest.fixture
def auth_token_user_2(user_2: User, memory_session: Session):
user_token, db_token = generate_user_token(user_2)
memory_session.add(db_token)
memory_session.commit()
return user_token
58 changes: 57 additions & 1 deletion backend/tests/test_doc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session
from transcribee_backend.config import settings
Expand All @@ -11,7 +12,25 @@
)


def test_doc_delete(memory_session: Session, logged_in_client: TestClient):
@pytest.fixture
def document(memory_session: Session, logged_in_client: TestClient):
req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
)
assert req.status_code == 200
document_id = req.json()["id"]

memory_session.add(DocumentUpdate(document_id=document_id, change_bytes=b""))
memory_session.commit()

yield document_id

logged_in_client.delete(f"/api/v1/documents/{document_id}/")


def test_doc_delete(
memory_session: Session, client: TestClient, logged_in_client: TestClient
):
checked_tables = [
Task,
TaskDependency,
Expand Down Expand Up @@ -44,10 +63,47 @@ def test_doc_delete(memory_session: Session, logged_in_client: TestClient):

assert files < set(str(x) for x in settings.storage_path.glob("*"))

req = client.delete(f"/api/v1/documents/{document_id}/")
assert 400 <= req.status_code < 500

req = logged_in_client.delete(f"/api/v1/documents/{document_id}/")
assert req.status_code == 200

for table in checked_tables:
assert counts[table] == memory_session.query(table).count()

assert files == set(str(x) for x in settings.storage_path.glob("*"))


@pytest.mark.parametrize(
"method,url,need_specific_user",
[
["get", "/api/v1/documents/", False],
["get", "/api/v1/documents/{document_id}/", True],
["delete", "/api/v1/documents/{document_id}/", True],
["get", "/api/v1/documents/{document_id}/tasks/", True],
],
)
def test_user_auth(
logged_in_client: TestClient,
logged_in_client_user_2: TestClient,
client: TestClient,
document: str,
method: str,
url: str,
need_specific_user: bool,
):
# Try to access without auth
req = getattr(client, method)(url.format(document_id=document))
assert 400 <= req.status_code < 500

# Try to access with different user
req = getattr(logged_in_client_user_2, method)(url.format(document_id=document))
if need_specific_user:
assert 400 <= req.status_code < 500
else:
assert 200 <= req.status_code < 300

# Try to access with owning user
req = getattr(logged_in_client, method)(url.format(document_id=document))
assert 200 <= req.status_code < 300
4 changes: 3 additions & 1 deletion backend/transcribee_backend/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ class Task(TaskBase, table=True):

is_completed: bool = Field(default=False)
completed_at: Optional[datetime.datetime] = None
completion_data: Optional[Dict] = Field(sa_column=Column(JSON(), nullable=True))
completion_data: Optional[Dict] = Field(
sa_column=Column(JSON(), nullable=True), default=None
)

dependencies: List["Task"] = Relationship(
back_populates="dependants",
Expand Down
94 changes: 55 additions & 39 deletions backend/transcribee_backend/routers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,60 @@
document_router = APIRouter()


def get_document_from_url(
document_id: uuid.UUID,
session: Session = Depends(get_session),
token: UserToken = Depends(get_user_token),
) -> Document:
"""
Get the current document from the `document_id` url parameter, ensuring that the authorized user
is allowed to access the document.
"""
statement = select(Document).where(
Document.id == document_id, Document.user_id == token.user_id
)
doc = session.exec(statement).one_or_none()
if doc is not None:
return doc
else:
raise HTTPException(status_code=404)


def ws_get_document_from_url(
document_id: uuid.UUID,
authorization: str = Query(),
session: Session = Depends(get_session),
):
"""
Get the current document from a websocket url (using the `document_id` url parameter), ensuring
that an authorization query parameter is set and the user / worker can acccess the document.
"""
statement = select(Document).where(Document.id == document_id)
document = session.exec(statement).one_or_none()
if document is None:
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

try:
user_token = validate_user_authorization(session, authorization)
except HTTPException:
user_token = None

try:
worker = validate_worker_authorization(session, authorization)
except HTTPException:
worker = None

if user_token is not None and user_token.user_id == document.user_id:
return document
if worker is not None:
statement = select(Task).where(
Task.assigned_worker_id == worker.id, Task.document_id == document.id
)
if session.exec(statement.limit(1)).one_or_none() is not None:
return document
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)


def create_default_tasks_for_document(session: Session, document: Document):
reencode_task = Task(
task_type=TaskType.REENCODE,
Expand Down Expand Up @@ -130,18 +184,6 @@ def list_documents(
return [doc.as_api_document() for doc in results]


def get_document_from_url(
document_id: uuid.UUID,
session: Session = Depends(get_session),
) -> Document:
statement = select(Document).where(Document.id == document_id)
doc = session.exec(statement).one_or_none()
if doc is not None:
return doc
else:
raise HTTPException(status_code=404)


@document_router.get("/{document_id}/")
def get_document(
token: UserToken = Depends(get_user_token),
Expand Down Expand Up @@ -186,36 +228,10 @@ def get_document_tasks(
return [TaskResponse.from_orm(x) for x in session.exec(statement)]


def can_access_document(
authorization: str = Query(),
document: Document = Depends(get_document_from_url),
session: Session = Depends(get_session),
):
try:
user_token = validate_user_authorization(session, authorization)
except HTTPException:
user_token = None

try:
worker = validate_worker_authorization(session, authorization)
except HTTPException:
worker = None

if user_token is not None and user_token.user_id == document.user_id:
return document
if worker is not None:
statement = select(Task).where(
Task.assigned_worker_id == worker.id, Task.document_id == document.id
)
if session.exec(statement.limit(1)).one_or_none() is not None:
return document
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)


@document_router.websocket("/sync/{document_id}/")
async def websocket_endpoint(
websocket: WebSocket,
document: Document = Depends(can_access_document),
document: Document = Depends(ws_get_document_from_url),
session: Session = Depends(get_session),
):
connection = DocumentSyncConsumer(
Expand Down

0 comments on commit e3a64b8

Please sign in to comment.