diff --git a/backend/scripts/create_user_token.py b/backend/scripts/create_user_token.py new file mode 100644 index 00000000..5a13c49f --- /dev/null +++ b/backend/scripts/create_user_token.py @@ -0,0 +1,34 @@ +import argparse +import datetime + +from sqlmodel import select +from transcribee_backend.auth import generate_user_token +from transcribee_backend.db import SessionContextManager +from transcribee_backend.helpers.time import now_tz_aware +from transcribee_backend.models.user import User + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--username", required=True) + parser.add_argument("--valid-days", required=True) + args = parser.parse_args() + with SessionContextManager(path="management_command:create_user_token") as session: + valid_days = int(args.valid_days) + if valid_days < 0: + print("Valid days must be positive") + exit(1) + + valid_until = now_tz_aware() + datetime.timedelta(days=valid_days) + + user = session.exec( + select(User).where(User.username == args.username) + ).one_or_none() + + if user is None: + print(f"User {args.user} not found") + exit(1) + + key, user_token = generate_user_token(user, valid_until=valid_until) + session.add(user_token) + print(f"User token created and valid until {valid_until}") + print(f"Secret: {key}") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index ccc1277d..51fe7150 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,3 +1,4 @@ +import datetime import tempfile from pathlib import Path @@ -14,6 +15,7 @@ from transcribee_backend.config import settings from transcribee_backend.db import get_session from transcribee_backend.exceptions import UserAlreadyExists +from transcribee_backend.helpers.time import now_tz_aware from transcribee_backend.main import app from transcribee_backend.models import User @@ -104,7 +106,9 @@ def user(memory_session: Session): @pytest.fixture def auth_token(user: User, memory_session: Session): - user_token, db_token = generate_user_token(user) + user_token, db_token = generate_user_token( + user, valid_until=now_tz_aware() + datetime.timedelta(days=1) + ) memory_session.add(db_token) memory_session.commit() return user_token @@ -126,7 +130,9 @@ def user_2(memory_session: Session): @pytest.fixture def auth_token_user_2(user_2: User, memory_session: Session): - user_token, db_token = generate_user_token(user_2) + user_token, db_token = generate_user_token( + user_2, valid_until=now_tz_aware() + datetime.timedelta(days=1) + ) memory_session.add(db_token) memory_session.commit() return user_token diff --git a/backend/transcribee_backend/auth.py b/backend/transcribee_backend/auth.py index c5331edd..0177405b 100644 --- a/backend/transcribee_backend/auth.py +++ b/backend/transcribee_backend/auth.py @@ -40,17 +40,18 @@ def pw_cmp(salt, hash, pw, N=14) -> bool: ) -def generate_user_token(user: User): +def generate_user_token(user: User, valid_until: datetime.datetime): raw_token = b64encode(os.urandom(32)).decode() salt, hash = pw_hash( raw_token, N=5 ) # We can use a much lower N here since we do not need to protect against weak passwords token = b64encode(f"{user.id}:{raw_token}".encode()).decode() + return token, UserToken( user_id=user.id, token_hash=hash, token_salt=salt, - valid_until=now_tz_aware() + datetime.timedelta(days=7), + valid_until=valid_until, ) diff --git a/backend/transcribee_backend/routers/user.py b/backend/transcribee_backend/routers/user.py index 4d89d140..a2518ada 100644 --- a/backend/transcribee_backend/routers/user.py +++ b/backend/transcribee_backend/routers/user.py @@ -1,3 +1,5 @@ +import datetime + from fastapi import APIRouter, Depends, HTTPException from sqlmodel import Session, delete from transcribee_proto.api import LoginResponse @@ -12,6 +14,7 @@ ) from transcribee_backend.db import get_session from transcribee_backend.exceptions import UserAlreadyExists +from transcribee_backend.helpers.time import now_tz_aware from transcribee_backend.models import CreateUser, UserBase, UserToken from transcribee_backend.models.user import ChangePasswordRequest @@ -40,7 +43,9 @@ def login(user: CreateUser, session: Session = Depends(get_session)) -> LoginRes except NotAuthorized: raise HTTPException(403) - user_token, db_token = generate_user_token(authorized_user) + user_token, db_token = generate_user_token( + authorized_user, valid_until=now_tz_aware() + datetime.timedelta(days=7) + ) session.add(db_token) session.commit() return LoginResponse(token=user_token)