Skip to content

Commit

Permalink
style: add types
Browse files Browse the repository at this point in the history
  • Loading branch information
trim21 committed Aug 19, 2024
1 parent 756d654 commit e82f667
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 36 deletions.
4 changes: 2 additions & 2 deletions lint/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def __init__(
self.category = category
self.simple = simple

def patch(self):
def patch(self) -> str:
return "\n".join(difflib.unified_diff(self.origin.splitlines(), self.after.splitlines()))

def __str__(self):
def __str__(self) -> str:
return f"<Patch {self.category} {self.message}>"


Expand Down
22 changes: 11 additions & 11 deletions lint/wiki/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Wiki:

__slots__ = "fields", "type"

def __init__(self, type=None, fields: list[Field] | None = None):
def __init__(self, type: str | None = None, fields: list[Field] | None = None):
self.type = type

if fields is None:
Expand Down Expand Up @@ -81,12 +81,12 @@ def __init__(self, lino: int | None = None, line: str | None = None, message: st


class GlobalPrefixError(WikiSyntaxError):
def __init__(self):
def __init__(self) -> None:
super().__init__(message="missing prefix '{{Infobox' at the start")


class GlobalSuffixError(WikiSyntaxError):
def __init__(self):
def __init__(self) -> None:
super().__init__(message="missing '}}' at the end")


Expand All @@ -106,7 +106,7 @@ class ExpectingSignEqualError(WikiSyntaxError):
pass


def try_parse(s):
def try_parse(s: str) -> Wiki:
"""If failed to parse, return zero value"""
try:
return parse(s)
Expand Down Expand Up @@ -192,7 +192,7 @@ def parse(s: str) -> Wiki:
return w


def read_type(s):
def read_type(s: str) -> str:
try:
i = s.index("\n")
except ValueError:
Expand All @@ -201,7 +201,7 @@ def read_type(s):
return _trim_space(s[len(prefix) : i])


def read_array_item(line):
def read_array_item(line: str) -> tuple[str, str]:
"""Read whole line as an array item, spaces are trimmed.
read_array_item("[简体中文名|鲁鲁修]") => "简体中文名", "鲁鲁修"
Expand All @@ -223,7 +223,7 @@ def read_array_item(line):
return "", _trim_space(content)


def read_start_line(line: str):
def read_start_line(line: str) -> tuple[str, str]:
"""Read line without leading '|' as key value pair, spaces are trimmed.
read_start_line("播放日期 = 2017年4月16日") => 播放日期, 2017年4月16日
Expand All @@ -243,15 +243,15 @@ def read_start_line(line: str):
_space_str = " \t"


def _trim_space(s: str):
def _trim_space(s: str) -> str:
return s.strip()


def _trim_left_space(s: str):
def _trim_left_space(s: str) -> str:
return s.strip()


def _trim_right_space(s: str):
def _trim_right_space(s: str) -> str:
return s.strip()


Expand Down Expand Up @@ -294,7 +294,7 @@ def __render(w: Wiki) -> Generator[str, None, None]:
yield "}}"


def __render_items(s: list[Item]):
def __render_items(s: list[Item]) -> Generator[str, None, None]:
for item in s:
if item.key:
yield f"[{item.key}| {item.value}]"
Expand Down
7 changes: 4 additions & 3 deletions lint/wiki/wiki_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Any

import pytest
import yaml
Expand All @@ -9,13 +10,13 @@
spec_repo_path = Path(r"~\proj\bangumi\wiki-syntax-spec").expanduser().resolve()


def test_read_type():
def test_read_type() -> None:
assert not read_type("{{Infobox\n")
assert read_type("{{Infobox Ta\n") == "Ta"
assert read_type("{{Infobox Ta\n}}") == "Ta"


def as_dict(w: Wiki) -> dict:
def as_dict(w: Wiki) -> dict[str, Any]:
data = []
for f in w.fields:
if isinstance(f.value, list):
Expand All @@ -40,7 +41,7 @@ def as_dict(w: Wiki) -> dict:


@pytest.mark.parametrize("name", valid)
def test_bangumi_wiki(name: str):
def test_bangumi_wiki(name: str) -> None:
file = spec_repo_path.joinpath("tests/valid", name)
wiki_raw = file.read_text()
assert as_dict(parse(wiki_raw)) == yaml.safe_load(file.with_suffix(".yaml").read_text()), name
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,7 @@ lines-after-imports = 2
known-first-party = ["lint", "bgm"]

[tool.mypy]
strict = true
warn_no_return = false
check_untyped_defs = true
warn_return_any = false
7 changes: 4 additions & 3 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from litestar import Response
from litestar.config.csrf import CSRFConfig
from litestar.contrib.jinja import JinjaTemplateEngine
from litestar.datastructures import State
from litestar.exceptions import (
HTTPException,
InternalServerException,
Expand Down Expand Up @@ -55,7 +56,7 @@ class File(NamedTuple):


@litestar.get("/static/{fp:path}", sync_to_thread=False)
def static_file_handler(fp: str) -> Response:
def static_file_handler(fp: str) -> Response[bytes]:
try:
f = static_files[fp]
return Response(
Expand Down Expand Up @@ -141,11 +142,11 @@ async def get_patch(patch_id: uuid.UUID, request: Request) -> Template:
)


def before_req(req: litestar.Request):
def before_req(req: litestar.Request[None, None, State]) -> None:
req.state["now"] = datetime.now(tz=UTC)


def plain_text_exception_handler(_: Request, exc: Exception) -> Response:
def plain_text_exception_handler(_: Request, exc: Exception) -> Template:
"""Default handler for exceptions subclassed from HTTPException."""
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
detail = getattr(exc, "detail", "")
Expand Down
18 changes: 11 additions & 7 deletions server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@
from litestar.types import Empty

from config import BGM_TV_APP_ID, BGM_TV_APP_SECRET, SERVER_BASE_URL
from server.base import http_client
from server.base import Request, http_client
from server.model import User


CALLBACK_URL = f"{SERVER_BASE_URL}/oauth_callback"


async def retrieve_user_from_session(session: dict[str, Any], req: ASGIConnection) -> User | None:
async def retrieve_user_from_session(
session: dict[str, Any], req: ASGIConnection[Any, Any, Any, Any]
) -> User | None:
try:
return __user_from_session(session)
except KeyError:
req.clear_session()


def __user_from_session(session):
def __user_from_session(session: dict[str, Any]) -> User:
return User(
user_id=session["user_id"],
group_id=session["group_id"],
Expand All @@ -41,7 +43,7 @@ def __user_from_session(session):
)


async def refresh(refresh_token) -> dict[str, Any]:
async def refresh(refresh_token: str) -> dict[str, Any]:
async with http_client.post(
"https://bgm.tv/oauth/access_token",
data={
Expand All @@ -57,7 +59,9 @@ async def refresh(refresh_token) -> dict[str, Any]:


class MyAuthenticationMiddleware(SessionAuthMiddleware):
async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
async def authenticate_request(
self, connection: ASGIConnection[Any, Any, Any, Any]
) -> AuthenticationResult:
if not connection.session or connection.scope["session"] is Empty:
# the assignment of 'Empty' forces the session middleware to clear session data.
connection.scope["session"] = Empty
Expand Down Expand Up @@ -90,7 +94,7 @@ def login() -> Redirect:


@litestar.get("/oauth_callback")
async def callback(code: str, request: litestar.Request) -> Redirect:
async def callback(code: str, request: Request) -> Redirect:
async with http_client.post(
"https://bgm.tv/oauth/access_token",
data={
Expand Down Expand Up @@ -131,7 +135,7 @@ async def callback(code: str, request: litestar.Request) -> Redirect:
return Redirect("/")


def require_user_editor(connection: ASGIConnection, _):
def require_user_editor(connection: ASGIConnection[Any, Any, Any, Any], _: Any) -> None:
if not connection.auth:
raise NotAuthorizedException
if not connection.auth.allow_edit:
Expand Down
7 changes: 4 additions & 3 deletions server/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterator
from http.cookies import BaseCookie
from typing import Any

Expand Down Expand Up @@ -31,18 +32,18 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL | None = None)
def filter_cookies(self, request_url: URL) -> BaseCookie[str]:
return BaseCookie()

def __len__(self):
def __len__(self) -> int:
return 0

def __iter__(self):
def __iter__(self) -> Iterator[Any]:
yield from ()


http_client = aiohttp.ClientSession(cookie_jar=DisableCookiesJar())
pg = asyncpg.create_pool(dsn=PG_DSN)


async def pg_pool_startup(*args, **kwargs):
async def pg_pool_startup(*args: Any, **kwargs: Any) -> None:
logger.info("init")
await pg

Expand Down
4 changes: 2 additions & 2 deletions server/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def is_access_token_fresh(self) -> bool:
return True

@property
def allow_edit(self):
def allow_edit(self) -> bool:
return self.group_id in {2, 11}

@property
def allow_admin(self):
def allow_admin(self) -> bool:
return self.group_id in {2}


Expand Down
7 changes: 4 additions & 3 deletions server/review.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import litestar
import orjson
from asyncpg import Record
from asyncpg.pool import PoolConnectionProxy
from litestar import Response
from litestar.enums import RequestEncodingType
Expand Down Expand Up @@ -38,7 +39,7 @@ async def review_patch(
patch_id: str,
request: AuthorizedRequest,
data: Annotated[ReviewPatch, Body(media_type=RequestEncodingType.URL_ENCODED)],
) -> Response:
) -> Response[Any]:
async with pg.acquire() as conn:
async with conn.transaction():
p = await pg.fetchrow(
Expand All @@ -61,7 +62,7 @@ async def review_patch(
raise NotAuthorizedException("暂不支持")


async def __reject_patch(patch: Patch, conn: PoolConnectionProxy, auth: User) -> Redirect:
async def __reject_patch(patch: Patch, conn: PoolConnectionProxy[Record], auth: User) -> Redirect:
await conn.execute(
"""
update patch set
Expand All @@ -78,7 +79,7 @@ async def __reject_patch(patch: Patch, conn: PoolConnectionProxy, auth: User) ->
return Redirect("/")


async def __accept_patch(patch: Patch, conn: PoolConnectionProxy, auth: User) -> Response:
async def __accept_patch(patch: Patch, conn: PoolConnectionProxy[Record], auth: User) -> Redirect:
if not auth.is_access_token_fresh():
return Redirect("/login")

Expand Down
5 changes: 3 additions & 2 deletions server/tmpl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from typing import Any

import jinja2
from jinja2 import pass_context, select_autoescape
Expand All @@ -16,11 +17,11 @@


@pass_context
def rel_time(ctx: Context, value: datetime):
def rel_time(ctx: Context, value: datetime) -> str:
if not isinstance(value, datetime):
raise TypeError("rel_time can be only called with datetime")

req: Request | None = ctx.get("request")
req: Request[Any, Any, Any] | None = ctx.get("request")

if req is None:
return format_duration(datetime.now(tz=UTC) - value)
Expand Down

0 comments on commit e82f667

Please sign in to comment.