diff --git a/CHANGELOG.md b/CHANGELOG.md index 567347c2..58dbc993 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to - Implement `OperationalUnit` schema - Link `OperationalUnit` to `Station` using AFIREV prefixes +- Implement `User` and `Group` schemas ## [0.5.0] - 2024-05-15 diff --git a/src/api/qualicharge/auth/factories.py b/src/api/qualicharge/auth/factories.py index d2047444..067ed859 100644 --- a/src/api/qualicharge/auth/factories.py +++ b/src/api/qualicharge/auth/factories.py @@ -4,13 +4,15 @@ from datetime import datetime from typing import Any, Dict -from polyfactory import PostGenerated +from polyfactory import PostGenerated, Use from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.pytest_plugin import register_fixture from qualicharge.conf import settings +from qualicharge.factories import FrenchDataclassFactory, TimestampedSQLModelFactory from .models import IDToken +from .schemas import Group, User def set_token_exp(name: str, values: Dict[str, int], *args: Any, **kwargs: Any) -> int: @@ -29,3 +31,20 @@ class IDTokenFactory(ModelFactory[IDToken]): iat = int(datetime.now().timestamp()) scope = "email profile" email = "john@doe.com" + + +class UserFactory(TimestampedSQLModelFactory[User]): + """User schema factory.""" + + username = Use( + lambda: FrenchDataclassFactory.__faker__.simple_profile().get("username") + ) + email = Use(FrenchDataclassFactory.__faker__.ascii_company_email) + first_name = Use(FrenchDataclassFactory.__faker__.first_name) + last_name = Use(FrenchDataclassFactory.__faker__.last_name) + + +class GroupFactory(TimestampedSQLModelFactory[Group]): + """Group schema factory.""" + + name = Use(FrenchDataclassFactory.__faker__.company) diff --git a/src/api/qualicharge/auth/schemas.py b/src/api/qualicharge/auth/schemas.py new file mode 100644 index 00000000..629e6fbd --- /dev/null +++ b/src/api/qualicharge/auth/schemas.py @@ -0,0 +1,60 @@ +"""QualiCharge authentication schemas.""" + +from typing import TYPE_CHECKING, Optional +from uuid import UUID, uuid4 + +from pydantic import EmailStr +from sqlalchemy.types import String +from sqlmodel import Field, Relationship, SQLModel + +from qualicharge.schemas import BaseTimestampedSQLModel + +if TYPE_CHECKING: + from qualicharge.schemas.core import OperationalUnit + + +# -- Many-to-many relationships +class UserGroup(SQLModel, table=True): + """M2M User-Group intermediate table.""" + + user_id: UUID = Field(foreign_key="user.id", primary_key=True) + group_id: UUID = Field(foreign_key="group.id", primary_key=True) + + +class GroupOperationalUnit(SQLModel, table=True): + """M2M Group-OperationalUnit intermediate table.""" + + group_id: UUID = Field(foreign_key="group.id", primary_key=True) + operational_unit_id: UUID = Field( + foreign_key="operationalunit.id", primary_key=True + ) + + +# -- Core schemas +class User(BaseTimestampedSQLModel, table=True): + """QualiCharge User.""" + + id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True) + username: str = Field(unique=True, max_length=150) + email: EmailStr = Field(sa_type=String) + first_name: Optional[str] = Field(max_length=150) + last_name: Optional[str] = Field(max_length=150) + is_active: bool = False + is_staff: bool = False + is_superuser: bool = False + + # Relationships + groups: list["Group"] = Relationship(back_populates="users", link_model=UserGroup) + + +class Group(BaseTimestampedSQLModel, table=True): + """QualiCharge Group.""" + + id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True) + name: str = Field(unique=True, max_length=150) + + # Relationships + users: list["User"] = Relationship(back_populates="groups", link_model=UserGroup) + operational_units: list["OperationalUnit"] = Relationship( + back_populates="groups", link_model=GroupOperationalUnit + ) diff --git a/src/api/qualicharge/auth/utils.py b/src/api/qualicharge/auth/utils.py new file mode 100644 index 00000000..8f6296bf --- /dev/null +++ b/src/api/qualicharge/auth/utils.py @@ -0,0 +1,21 @@ +"""QualiCharge auth.utils module.""" + +from typing import Sequence, cast + +from sqlalchemy import Column as SAColumn +from sqlalchemy.sql.roles import JoinTargetRole +from sqlmodel import Session as SMSession +from sqlmodel import select + +from qualicharge.schemas.core import OperationalUnit + +from .schemas import Group, User + + +def get_user_operational_units(user: User, session: SMSession) -> Sequence[str]: + """Get user related operational unit codes.""" + return session.exec( + select(OperationalUnit.code) + .join(cast(JoinTargetRole, OperationalUnit.groups)) + .filter(cast(SAColumn, Group.id).in_(group.id for group in user.groups)) + ).all() diff --git a/src/api/qualicharge/migrations/env.py b/src/api/qualicharge/migrations/env.py index f282ebb0..35c54abc 100644 --- a/src/api/qualicharge/migrations/env.py +++ b/src/api/qualicharge/migrations/env.py @@ -9,6 +9,12 @@ from qualicharge.conf import settings # Nota bene: be sure to import all models that need to be migrated here +from qualicharge.auth.schemas import ( # noqa: F401 + User, + UserGroup, + Group, + GroupOperationalUnit, +) from qualicharge.schemas.core import ( # noqa: F401 Amenageur, Enseigne, diff --git a/src/api/qualicharge/migrations/versions/7568f5ff860e_add_user_and_group_schemas.py b/src/api/qualicharge/migrations/versions/7568f5ff860e_add_user_and_group_schemas.py new file mode 100644 index 00000000..b98ec9c5 --- /dev/null +++ b/src/api/qualicharge/migrations/versions/7568f5ff860e_add_user_and_group_schemas.py @@ -0,0 +1,88 @@ +"""add user and group schemas + +Revision ID: 7568f5ff860e +Revises: fda96abb970d +Create Date: 2024-05-20 14:20:28.454872 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "7568f5ff860e" +down_revision: Union[str, None] = "fda96abb970d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "group", + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.CheckConstraint("created_at <= updated_at", name="pre-creation-update"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), + ) + op.create_table( + "user", + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("first_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("last_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_staff", sa.Boolean(), nullable=False), + sa.Column("is_superuser", sa.Boolean(), nullable=False), + sa.CheckConstraint("created_at <= updated_at", name="pre-creation-update"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("username"), + ) + op.create_table( + "groupoperationalunit", + sa.Column("group_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("operational_unit_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.ForeignKeyConstraint( + ["group_id"], + ["group.id"], + ), + sa.ForeignKeyConstraint( + ["operational_unit_id"], + ["operationalunit.id"], + ), + sa.PrimaryKeyConstraint("group_id", "operational_unit_id"), + ) + op.create_table( + "usergroup", + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("group_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.ForeignKeyConstraint( + ["group_id"], + ["group.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("user_id", "group_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("usergroup") + op.drop_table("groupoperationalunit") + op.drop_table("user") + op.drop_table("group") + # ### end Alembic commands ### diff --git a/src/api/qualicharge/schemas/core.py b/src/api/qualicharge/schemas/core.py index 26b8576c..0b1a046f 100644 --- a/src/api/qualicharge/schemas/core.py +++ b/src/api/qualicharge/schemas/core.py @@ -26,6 +26,7 @@ from sqlmodel import Session as SMSession from sqlmodel.main import SQLModelConfig +from qualicharge.auth.schemas import Group, GroupOperationalUnit from qualicharge.exceptions import ObjectDoesNotExist from ..models.dynamic import SessionBase, StatusBase @@ -187,7 +188,11 @@ class OperationalUnit(BaseTimestampedSQLModel, table=True): name: str type: OperationalUnitTypeEnum + # Relationships stations: List["Station"] = Relationship(back_populates="operational_unit") + groups: List["Group"] = Relationship( + back_populates="operational_units", link_model=GroupOperationalUnit + ) def create_stations_fk(self, session: SMSession): """Create linked stations foreign keys.""" diff --git a/src/api/tests/auth/__init__.py b/src/api/tests/auth/__init__.py new file mode 100644 index 00000000..e1441c22 --- /dev/null +++ b/src/api/tests/auth/__init__.py @@ -0,0 +1 @@ +"""QualiCharge auth module tests.""" diff --git a/src/api/tests/auth/test_schemas.py b/src/api/tests/auth/test_schemas.py new file mode 100644 index 00000000..03de2fbe --- /dev/null +++ b/src/api/tests/auth/test_schemas.py @@ -0,0 +1,39 @@ +"""Tests for qualicharge.auth.schemas module.""" + +from sqlmodel import select + +from qualicharge.auth.factories import GroupFactory, UserFactory +from qualicharge.auth.schemas import GroupOperationalUnit, UserGroup +from qualicharge.schemas.core import OperationalUnit + + +def test_create_user_group_operational_units(db_session): + """Test the user to operational unit relationship.""" + UserFactory.__session__ = db_session + GroupFactory.__session__ = db_session + + # Create users and groups + user_one, user_two = UserFactory.create_batch_sync(2) + group_one, group_two = GroupFactory.create_batch_sync(2) + db_session.add(UserGroup(user_id=user_one.id, group_id=group_one.id)) + db_session.add(UserGroup(user_id=user_two.id, group_id=group_two.id)) + + assert group_one.users == [ + user_one, + ] + assert group_two.users == [ + user_two, + ] + + # Link group to an operational unit + code = "FRS63" + operational_unit = db_session.exec( + select(OperationalUnit).where(OperationalUnit.code == code) + ).one() + db_session.add( + GroupOperationalUnit( + group_id=group_one.id, operational_unit_id=operational_unit.id + ) + ) + + assert user_one.groups[0].operational_units[0].id == operational_unit.id diff --git a/src/api/tests/auth/test_utils.py b/src/api/tests/auth/test_utils.py new file mode 100644 index 00000000..e3412e55 --- /dev/null +++ b/src/api/tests/auth/test_utils.py @@ -0,0 +1,51 @@ +"""Tests for qualicharge.auth.utils schemas module.""" + +from random import sample +from typing import cast + +from sqlalchemy import Column as SAColumn +from sqlmodel import select + +from qualicharge.auth.factories import GroupFactory, UserFactory +from qualicharge.auth.schemas import GroupOperationalUnit, UserGroup +from qualicharge.auth.utils import get_user_operational_units +from qualicharge.fixtures.operational_units import data as operational_unit_data +from qualicharge.schemas.core import OperationalUnit + + +def test_user_get_operational_units(db_session): + """Test the User get_operational_units utility.""" + UserFactory.__session__ = db_session + GroupFactory.__session__ = db_session + + # Create user, groups and link them (with operational units) + user = UserFactory.create_sync() + n_groups = 8 + groups = GroupFactory.create_batch_sync(n_groups) + user_n_groups = 2 + user_groups = sample(groups, user_n_groups) + operational_unit_codes = [ + operational_unit.code + for operational_unit in sample(operational_unit_data, n_groups) + ] + operational_units = db_session.exec( + select(OperationalUnit).where( + cast(SAColumn, OperationalUnit.code).in_(operational_unit_codes) + ) + ) + db_session.add_all( + UserGroup(user_id=user.id, group_id=group.id) for group in user_groups + ) + db_session.add_all( + GroupOperationalUnit(group_id=group.id, operational_unit_id=operational_unit.id) + for group, operational_unit in zip(groups, operational_units) + ) + + # Get operational unit codes + user_operational_unit_codes = get_user_operational_units(user, db_session) + assert len(user_operational_unit_codes) == user_n_groups + assert set(user_operational_unit_codes) == { + operational_unit.code + for group in user.groups + for operational_unit in group.operational_units + }