Skip to content

Commit

Permalink
feat: major refactor
Browse files Browse the repository at this point in the history
- login both with username and email
- add actions interface for all CRUD operations
- create routers for all entities
  • Loading branch information
lilpuzeen committed Mar 13, 2024
1 parent c2166c0 commit 251a229
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 12 deletions.
119 changes: 119 additions & 0 deletions src/actions/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any, Generic, Optional, Type, TypeVar, cast

from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import Depends, HTTPException

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

from src.database import Base, get_async_session
from src.auth.utils import map_to_datetime

import src.actions.schemas as schema
import src.polls.models as models

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class BaseActions(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
Base class that can be extended by other action classes.
Provides basic CRUD and listing operations.
:param model: The SQLAlchemy model
:type model: Type[ModelType]
"""
self.model = model

async def get_all(
self, *, skip: int = 0, limit: int = 100, db: AsyncSession = Depends(get_async_session)
) -> list[ModelType]:
async with db as session:
query = select(self.model).offset(skip).limit(limit)
result = await session.execute(query)
return cast(list[ModelType], result.scalars().all())

async def get(self, id: int, db: AsyncSession = Depends(get_async_session)) -> Optional[ModelType]:
async with db as session:
query = select(self.model).filter(self.model.id == id)
result = await session.execute(query)
obj = result.scalars().first()
if obj:
return obj
else:
raise HTTPException(status_code=404, detail=f"Object with ID {id} not found.")

async def create(self, *, obj_in: CreateSchemaType, db: AsyncSession = Depends(get_async_session)) -> ModelType:
obj_in_data = map_to_datetime(jsonable_encoder(obj_in))
db_obj = self.model(**obj_in_data)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj

async def update(
self,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | dict[str, Any],
db: AsyncSession = Depends(get_async_session)
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)

for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])

db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj

async def remove(self, *, id: int, db: AsyncSession = Depends(get_async_session)) -> ModelType:
query = select(self.model).where(self.model.id == id)
result = await db.execute(query)
obj = result.scalars().first()
if obj:
await db.delete(obj)
await db.commit()
return obj
else:
raise HTTPException(status_code=404, detail=f"Object with ID {id} not found.")


class PollActions(BaseActions[models.Poll, schema.PostCreate, schema.PostUpdate]):
"""Poll actions with basic CRUD operations"""

pass


class QuestionActions(BaseActions[models.Question, schema.PostCreate, schema.PostUpdate]):
"""Question actions with basic CRUD operations"""

pass


class ChoiceActions(BaseActions[models.Choice, schema.PostCreate, schema.PostUpdate]):
"""Choice actions with basic CRUD operations"""

pass


class VoteActions(BaseActions[models.Vote, schema.PostCreate, schema.PostUpdate]):
"""Vote actions with basic CRUD operations"""

pass


poll_action = PollActions(models.Poll)
question_action = QuestionActions(models.Question)
choice_action = QuestionActions(models.Choice)
vote_action = QuestionActions(models.Vote)
22 changes: 22 additions & 0 deletions src/actions/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Optional

from pydantic import BaseModel
# TODO: correct schemas for each entity


class HTTPError(BaseModel):
detail: str


class PostBase(BaseModel):
title: Optional[str] = None
body: Optional[str] = None


class PostCreate(PostBase):
title: str
body: str


class PostUpdate(PostBase):
pass
20 changes: 20 additions & 0 deletions src/auth/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy import select

from src.auth.schemas import UserRead


class CustomUserDatabase(SQLAlchemyUserDatabase):
"""
Класс для работы с базой данных пользователей, расширенный методом поиска по username.
"""

def __init__(self, session, user_table) -> None:
super().__init__(session, user_table)

async def get_by_username(self, username: str) -> Optional[UserRead]:
query = select(self.user_table).where(self.user_table.username == username)
result = await self.session.execute(query)
user = result.scalars().first()
return user
36 changes: 32 additions & 4 deletions src/auth/manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Optional

from fastapi import Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import (BaseUserManager, IntegerIDMixin, exceptions, models,
schemas)
from fastapi_users.db import BaseUserDatabase

from src.auth.db import CustomUserDatabase
from src.auth.models import User
from src.auth.utils import get_user_db
from src.config import SECRET_AUTH
Expand All @@ -13,12 +16,37 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
reset_password_token_secret = SECRET_AUTH
verification_token_secret = SECRET_AUTH

def __init__(self, user_db: BaseUserDatabase, *args, **kwargs):
super().__init__(user_db, *args, **kwargs)

async def on_after_register(self, user: User, request: Optional[Request] = None):
# TODO: add email verification
print(f"User {user.id} has registered.")
print(f"User {user.username} has registered.")

async def authenticate(
self,
credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
user = None
if "@" in credentials.username:
user = await self.user_db.get_by_email(credentials.username)
else:
user = await self.user_db.get_by_username(credentials.username)

if user is None:
self.password_helper.hash(credentials.password)
return None

verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None

if updated_password_hash is not None:
await self.user_db.update(user, {"hashed_password": updated_password_hash})

async def get_by_email(self, email: str) -> Optional[User]:
user = await super()
return user

async def create(
self,
Expand Down Expand Up @@ -47,5 +75,5 @@ async def create(
return created_user


async def get_user_manager(user_db=Depends(get_user_db)):
async def get_user_manager(user_db: CustomUserDatabase = Depends(get_user_db)) -> UserManager:
yield UserManager(user_db)
19 changes: 18 additions & 1 deletion src/auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
import datetime
from typing import Any

from fastapi import Depends
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession

from src.auth.models import User
from src.database import get_async_session

from src.auth.db import CustomUserDatabase


async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User)
yield CustomUserDatabase(session, User)


def map_to_datetime(obj: dict[str, Any]) -> dict[str, Any]:
def str_to_datetime(ts: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(ts)

for k, v in obj.items():
if "date" in k:
obj[k] = str_to_datetime(v)
break

return obj
10 changes: 9 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import FastAPI
from src.polls.router import router as router_polls
import uvicorn
from src.polls.router import router_polls, router_questions, router_choices, router_votes
from src.auth.base_config import fastapi_users, auth_backend
from src.auth.schemas import UserCreate, UserRead

Expand All @@ -20,3 +21,10 @@
)

app.include_router(router_polls)
app.include_router(router_questions)
app.include_router(router_choices)
app.include_router(router_votes)


if __name__ == '__main__':
uvicorn.run(app)
4 changes: 2 additions & 2 deletions src/polls/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

import datetime


# TODO: Response models
class Poll(Base):
__tablename__ = "poll"

id = Column("id", Integer, primary_key=True)
title = Column("title", String, nullable=False, unique=True)
description = Column("description", String)
created_by = Column("created_by", String, ForeignKey("user.username"), nullable=False)
created_by = Column("created_by", Integer, ForeignKey("user.id"), nullable=False)
start_date = Column("start_date", TIMESTAMP, default=datetime.datetime.now())
end_date = Column("end_date", TIMESTAMP, nullable=False)

Expand Down
Loading

0 comments on commit 251a229

Please sign in to comment.