From 251a2297fe795b7af99c904764a13318dba91f02 Mon Sep 17 00:00:00 2001 From: armantovmasyan Date: Wed, 13 Mar 2024 03:42:06 +0300 Subject: [PATCH] feat: major refactor - login both with username and email - add actions interface for all CRUD operations - create routers for all entities --- src/actions/actions.py | 119 +++++++++++++++++++++++++++++++++++++++++ src/actions/schemas.py | 22 ++++++++ src/auth/db.py | 20 +++++++ src/auth/manager.py | 36 +++++++++++-- src/auth/utils.py | 19 ++++++- src/main.py | 10 +++- src/polls/models.py | 4 +- src/polls/router.py | 109 ++++++++++++++++++++++++++++++++++++- src/polls/schemas.py | 25 ++++++++- 9 files changed, 352 insertions(+), 12 deletions(-) create mode 100644 src/actions/actions.py create mode 100644 src/actions/schemas.py create mode 100644 src/auth/db.py diff --git a/src/actions/actions.py b/src/actions/actions.py new file mode 100644 index 0000000..16f9fc8 --- /dev/null +++ b/src/actions/actions.py @@ -0,0 +1,119 @@ +from typing import Any, Generic, Optional, Type, TypeVar, cast + +from sqlalchemy.future import select +from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import Depends, HTTPException + +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel + +from src.database import Base, get_async_session +from src.auth.utils import map_to_datetime + +import src.actions.schemas as schema +import src.polls.models as models + +ModelType = TypeVar("ModelType", bound=Base) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + + +class BaseActions(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + Base class that can be extended by other action classes. + Provides basic CRUD and listing operations. + + :param model: The SQLAlchemy model + :type model: Type[ModelType] + """ + self.model = model + + async def get_all( + self, *, skip: int = 0, limit: int = 100, db: AsyncSession = Depends(get_async_session) + ) -> list[ModelType]: + async with db as session: + query = select(self.model).offset(skip).limit(limit) + result = await session.execute(query) + return cast(list[ModelType], result.scalars().all()) + + async def get(self, id: int, db: AsyncSession = Depends(get_async_session)) -> Optional[ModelType]: + async with db as session: + query = select(self.model).filter(self.model.id == id) + result = await session.execute(query) + obj = result.scalars().first() + if obj: + return obj + else: + raise HTTPException(status_code=404, detail=f"Object with ID {id} not found.") + + async def create(self, *, obj_in: CreateSchemaType, db: AsyncSession = Depends(get_async_session)) -> ModelType: + obj_in_data = map_to_datetime(jsonable_encoder(obj_in)) + db_obj = self.model(**obj_in_data) + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + return db_obj + + async def update( + self, + *, + db_obj: ModelType, + obj_in: UpdateSchemaType | dict[str, Any], + db: AsyncSession = Depends(get_async_session) + ) -> ModelType: + obj_data = jsonable_encoder(db_obj) + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.dict(exclude_unset=True) + + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + + db.add(db_obj) + await db.commit() + await db.refresh(db_obj) + return db_obj + + async def remove(self, *, id: int, db: AsyncSession = Depends(get_async_session)) -> ModelType: + query = select(self.model).where(self.model.id == id) + result = await db.execute(query) + obj = result.scalars().first() + if obj: + await db.delete(obj) + await db.commit() + return obj + else: + raise HTTPException(status_code=404, detail=f"Object with ID {id} not found.") + + +class PollActions(BaseActions[models.Poll, schema.PostCreate, schema.PostUpdate]): + """Poll actions with basic CRUD operations""" + + pass + + +class QuestionActions(BaseActions[models.Question, schema.PostCreate, schema.PostUpdate]): + """Question actions with basic CRUD operations""" + + pass + + +class ChoiceActions(BaseActions[models.Choice, schema.PostCreate, schema.PostUpdate]): + """Choice actions with basic CRUD operations""" + + pass + + +class VoteActions(BaseActions[models.Vote, schema.PostCreate, schema.PostUpdate]): + """Vote actions with basic CRUD operations""" + + pass + + +poll_action = PollActions(models.Poll) +question_action = QuestionActions(models.Question) +choice_action = QuestionActions(models.Choice) +vote_action = QuestionActions(models.Vote) \ No newline at end of file diff --git a/src/actions/schemas.py b/src/actions/schemas.py new file mode 100644 index 0000000..3cfdf7c --- /dev/null +++ b/src/actions/schemas.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pydantic import BaseModel +# TODO: correct schemas for each entity + + +class HTTPError(BaseModel): + detail: str + + +class PostBase(BaseModel): + title: Optional[str] = None + body: Optional[str] = None + + +class PostCreate(PostBase): + title: str + body: str + + +class PostUpdate(PostBase): + pass diff --git a/src/auth/db.py b/src/auth/db.py new file mode 100644 index 0000000..d1e18c0 --- /dev/null +++ b/src/auth/db.py @@ -0,0 +1,20 @@ +from typing import Optional +from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase +from sqlalchemy import select + +from src.auth.schemas import UserRead + + +class CustomUserDatabase(SQLAlchemyUserDatabase): + """ + Класс для работы с базой данных пользователей, расширенный методом поиска по username. + """ + + def __init__(self, session, user_table) -> None: + super().__init__(session, user_table) + + async def get_by_username(self, username: str) -> Optional[UserRead]: + query = select(self.user_table).where(self.user_table.username == username) + result = await self.session.execute(query) + user = result.scalars().first() + return user diff --git a/src/auth/manager.py b/src/auth/manager.py index 12f67ed..0676d98 100644 --- a/src/auth/manager.py +++ b/src/auth/manager.py @@ -1,9 +1,12 @@ from typing import Optional from fastapi import Depends, Request +from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import (BaseUserManager, IntegerIDMixin, exceptions, models, schemas) +from fastapi_users.db import BaseUserDatabase +from src.auth.db import CustomUserDatabase from src.auth.models import User from src.auth.utils import get_user_db from src.config import SECRET_AUTH @@ -13,12 +16,37 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]): reset_password_token_secret = SECRET_AUTH verification_token_secret = SECRET_AUTH + def __init__(self, user_db: BaseUserDatabase, *args, **kwargs): + super().__init__(user_db, *args, **kwargs) + async def on_after_register(self, user: User, request: Optional[Request] = None): # TODO: add email verification - print(f"User {user.id} has registered.") + print(f"User {user.username} has registered.") + + async def authenticate( + self, + credentials: OAuth2PasswordRequestForm + ) -> Optional[User]: + user = None + if "@" in credentials.username: + user = await self.user_db.get_by_email(credentials.username) + else: + user = await self.user_db.get_by_username(credentials.username) + + if user is None: + self.password_helper.hash(credentials.password) + return None + + verified, updated_password_hash = self.password_helper.verify_and_update( + credentials.password, user.hashed_password + ) + if not verified: + return None + + if updated_password_hash is not None: + await self.user_db.update(user, {"hashed_password": updated_password_hash}) - async def get_by_email(self, email: str) -> Optional[User]: - user = await super() + return user async def create( self, @@ -47,5 +75,5 @@ async def create( return created_user -async def get_user_manager(user_db=Depends(get_user_db)): +async def get_user_manager(user_db: CustomUserDatabase = Depends(get_user_db)) -> UserManager: yield UserManager(user_db) diff --git a/src/auth/utils.py b/src/auth/utils.py index 35ba6ea..055dbec 100644 --- a/src/auth/utils.py +++ b/src/auth/utils.py @@ -1,3 +1,6 @@ +import datetime +from typing import Any + from fastapi import Depends from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from sqlalchemy.ext.asyncio import AsyncSession @@ -5,6 +8,20 @@ from src.auth.models import User from src.database import get_async_session +from src.auth.db import CustomUserDatabase + async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(session, User) + yield CustomUserDatabase(session, User) + + +def map_to_datetime(obj: dict[str, Any]) -> dict[str, Any]: + def str_to_datetime(ts: str) -> datetime.datetime: + return datetime.datetime.fromisoformat(ts) + + for k, v in obj.items(): + if "date" in k: + obj[k] = str_to_datetime(v) + break + + return obj diff --git a/src/main.py b/src/main.py index 9a15690..5916f1d 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,6 @@ from fastapi import FastAPI -from src.polls.router import router as router_polls +import uvicorn +from src.polls.router import router_polls, router_questions, router_choices, router_votes from src.auth.base_config import fastapi_users, auth_backend from src.auth.schemas import UserCreate, UserRead @@ -20,3 +21,10 @@ ) app.include_router(router_polls) +app.include_router(router_questions) +app.include_router(router_choices) +app.include_router(router_votes) + + +if __name__ == '__main__': + uvicorn.run(app) diff --git a/src/polls/models.py b/src/polls/models.py index ed323c8..d9b66dc 100644 --- a/src/polls/models.py +++ b/src/polls/models.py @@ -3,14 +3,14 @@ import datetime - +# TODO: Response models class Poll(Base): __tablename__ = "poll" id = Column("id", Integer, primary_key=True) title = Column("title", String, nullable=False, unique=True) description = Column("description", String) - created_by = Column("created_by", String, ForeignKey("user.username"), nullable=False) + created_by = Column("created_by", Integer, ForeignKey("user.id"), nullable=False) start_date = Column("start_date", TIMESTAMP, default=datetime.datetime.now()) end_date = Column("end_date", TIMESTAMP, nullable=False) diff --git a/src/polls/router.py b/src/polls/router.py index c05a0f0..6dcb2c5 100644 --- a/src/polls/router.py +++ b/src/polls/router.py @@ -1,4 +1,109 @@ -from fastapi import APIRouter +import pytz -router = APIRouter() +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from src.auth.models import User +from src.auth.base_config import fastapi_users +from src.auth.schemas import UserRead + +from src.database import get_async_session + +from src.actions.actions import poll_action, question_action, choice_action, vote_action + +import src.polls.models as model +import src.polls.schemas as schema + +router_polls = APIRouter( + prefix="/polls", + tags=["Polls"] +) + +router_questions = APIRouter( + prefix="/questions", + tags=["Polls"] +) + +router_choices = APIRouter( + prefix="/choices", + tags=["Polls"] +) + +router_votes = APIRouter( + prefix="/votes", + tags=["Polls"] +) + + +async def get_current_active_user( + current_user: User = Depends(fastapi_users.current_user(active=True))): + return current_user + + +@router_polls.post("") +async def create_poll( + new_poll: schema.CreatePoll, + session: AsyncSession = Depends(get_async_session), + current_user: UserRead = Depends(get_current_active_user) +): + created_by_id = current_user.id + if new_poll.end_date.tzinfo is not None: + new_poll.end_date = new_poll.end_date.astimezone(pytz.utc).replace(tzinfo=None) + + new_poll_instance = model.Poll( + title=new_poll.title, + description=new_poll.description, + created_by=created_by_id, + end_date=new_poll.end_date, + ) + + return await poll_action.create(db=session, obj_in=new_poll_instance) + + +@router_polls.get("/{poll_id}", response_model=schema.ReadPoll) +async def get_poll(poll_id: int, session: AsyncSession = Depends(get_async_session)): + pass + + +@router_questions.post("/{poll_id}") +async def create_question( + poll_id: int, + new_question: schema.CreateQuestion, + session: AsyncSession = Depends(get_async_session) +): + new_question_instance = model.Question( + poll_id=poll_id, + question_text=new_question.question_text + ) + + return await question_action.create(db=session, obj_in=new_question_instance) + + +@router_choices.post("/{question_id}") +async def create_choice( + question_id: int, + new_choice: schema.CreateChoice, + session: AsyncSession = Depends(get_async_session) +): + new_choice_instance = model.Choice( + question_id=question_id, + choice_text=new_choice.choice_text + ) + + return await choice_action.create(db=session, obj_in=new_choice_instance) + + +@router_votes.post("/{choice_id}") +async def create_vote( + choice_id: int, + new_vote: schema.CreateVote, + session: AsyncSession = Depends(get_async_session), + current_user: UserRead = Depends(get_current_active_user) +): + new_vote_instance = model.Vote( + choice_id=choice_id, + user_id=current_user.id, + vote_timestamp=new_vote.vote_ts + ) + + return await vote_action.create(db=session, obj_in=new_vote_instance) diff --git a/src/polls/schemas.py b/src/polls/schemas.py index a69b5d0..6f1a2b3 100644 --- a/src/polls/schemas.py +++ b/src/polls/schemas.py @@ -1,3 +1,5 @@ +import datetime + from pydantic import BaseModel, constr, FutureDatetime @@ -9,12 +11,31 @@ class CreatePoll(BaseModel): class CreateQuestion(BaseModel): - id: int question_text: constr(min_length=1, max_length=40) class CreateChoice(BaseModel): - id: int choice_text: constr(min_length=1, max_length=20) +class CreateVote(BaseModel): + vote_ts: datetime.datetime + + +class ReadChoice(BaseModel): + id: int + choice_text: str + + +class ReadQuestion(BaseModel): + id: int + question_text: str + choices: list[ReadChoice] + + +class ReadPoll(BaseModel): + id: int + title: str + description: str + questions: list[ReadQuestion] +