diff --git a/server/__init__.py b/server/__init__.py index 51d789f..507bbd2 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -32,12 +32,14 @@ REDIS_DSN, UTC, ) -from server import tmpl -from server.auth import callback, login, require_user_login, session_auth_config +from server import auth, contrib, review, tmpl +from server.auth import require_user_login, session_auth_config from server.base import Request, http_client, pg, pg_pool_startup -from server.contrib import delete_patch, suggest_api, suggest_ui from server.model import Patch, PatchState -from server.review import review_patch +from server.router import Router + + +router = Router() class File(NamedTuple): @@ -57,6 +59,7 @@ class File(NamedTuple): content=file_path.read_bytes(), content_type=mimetypes.guess_type(file)[0] ) + @router @litestar.get("/static/{fp:path}", sync_to_thread=False) def static_file_handler(fp: str) -> Response[bytes]: try: @@ -71,6 +74,7 @@ def static_file_handler(fp: str) -> Response[bytes]: else: + @router @litestar.get("/static/{fp:path}", sync_to_thread=True) def static_file_handler(fp: str) -> Response[bytes]: # fp is '/...', so we need to remove prefix make it relative @@ -92,6 +96,7 @@ async def __fetch_users(rows: list[asyncpg.Record]) -> dict[int, asyncpg.Record] return users +@router @litestar.get("/") async def index(request: Request) -> Template: if not request.auth: @@ -120,6 +125,7 @@ async def index(request: Request) -> Template: ) +@router @litestar.get("/contrib/{user_id:int}", guards=[require_user_login]) async def show_user_contrib(user_id: int, request: Request) -> Template: rows = await pg.fetch( @@ -141,6 +147,7 @@ async def show_user_contrib(user_id: int, request: Request) -> Template: ) +@router @litestar.get("/review/{user_id:int}", guards=[require_user_login]) async def show_user_review(user_id: int, request: Request) -> Template: rows = await pg.fetch( @@ -169,6 +176,7 @@ def __index_row_sorter(r: asyncpg.Record) -> tuple[int, datetime]: return 0, r["updated_at"] +@router @litestar.get("/patch/{patch_id:str}") async def get_patch(patch_id: str, request: Request) -> Template: try: @@ -299,17 +307,10 @@ async def startup_fetch_missing_users(*args: Any, **kwargs: Any) -> None: app = litestar.Litestar( [ - index, - show_user_review, - show_user_contrib, - login, - callback, - suggest_ui, - suggest_api, - get_patch, - delete_patch, - review_patch, - static_file_handler, + *auth.router, + *contrib.router, + *review.router, + *router, ], template_config=TemplateConfig( engine=JinjaTemplateEngine.from_environment(tmpl.engine), diff --git a/server/auth.py b/server/auth.py index 8397c34..ddf3f6d 100644 --- a/server/auth.py +++ b/server/auth.py @@ -17,10 +17,13 @@ from config import BGM_TV_APP_ID, BGM_TV_APP_SECRET, SERVER_BASE_URL from server.base import Request, User, http_client, pg +from server.router import Router CALLBACK_URL = f"{SERVER_BASE_URL}/oauth_callback" +router = Router() + async def retrieve_user_from_session( session: dict[str, Any], req: ASGIConnection[Any, Any, Any, Any] @@ -77,6 +80,7 @@ async def authenticate_request( ) +@router @litestar.get("/login", sync_to_thread=False) def login() -> Redirect: return Redirect( @@ -91,6 +95,7 @@ def login() -> Redirect: ) +@router @litestar.get("/oauth_callback") async def callback(code: str, request: Request) -> Redirect: res = await http_client.post( diff --git a/server/contrib.py b/server/contrib.py index c4d9549..21845c5 100644 --- a/server/contrib.py +++ b/server/contrib.py @@ -19,8 +19,13 @@ from config import TURNSTILE_SECRET_KEY, TURNSTILE_SITE_KEY, UTC from server.base import BadRequestException, Request, http_client, pg from server.model import Patch, Wiki +from server.router import Router +router = Router() + + +@router @litestar.get("/suggest") async def suggest_ui(request: Request, subject_id: int = 0) -> Response[Any]: if subject_id == 0: @@ -52,6 +57,7 @@ class CreateSuggestion: nsfw: str | None = None +@router @litestar.post("/suggest") async def suggest_api( subject_id: int, @@ -142,6 +148,7 @@ async def suggest_api( return Redirect(f"/patch/{pk}") +@router @litestar.post("/api/delete-patch/{patch_id:str}") async def delete_patch(patch_id: str, request: Request) -> Redirect: if not request.auth: diff --git a/server/review.py b/server/review.py index acfa7f4..67777cb 100644 --- a/server/review.py +++ b/server/review.py @@ -20,6 +20,10 @@ from server.auth import require_user_editor from server.base import AuthorizedRequest, BadRequestException, User, http_client, pg from server.model import Patch, PatchState +from server.router import Router + + +router = Router() class React(str, enum.Enum): @@ -37,6 +41,7 @@ def __strip_none(d: dict[str, Any]) -> dict[str, Any]: return {key: value for key, value in d.items() if value is not None} +@router @litestar.post("/api/review-patch/{patch_id:str}", guards=[require_user_editor]) async def review_patch( patch_id: str, diff --git a/server/router.py b/server/router.py new file mode 100644 index 0000000..8885516 --- /dev/null +++ b/server/router.py @@ -0,0 +1,21 @@ +from collections.abc import Iterator + +from litestar.types import AnyCallable +from typing_extensions import TypeVar + + +T = TypeVar("T", bound=AnyCallable) + + +class Router: + """A helper class to collect handlers""" + + def __init__(self) -> None: + self.__handler: list[AnyCallable] = [] + + def __call__(self, fn: T) -> T: + self.__handler.append(fn) + return fn + + def __iter__(self) -> Iterator[AnyCallable]: + yield from self.__handler