diff --git a/config.example.py b/config.example.py index 78ba16e8..3447811b 100644 --- a/config.example.py +++ b/config.example.py @@ -65,7 +65,7 @@ class Config: #: Postgres credentials POSTGRES = {} - + #: Shared secret for LVSP LVSP_SECRET = "" diff --git a/litecord/auth.py b/litecord/auth.py index 01d24af7..1de30b83 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -112,7 +112,7 @@ async def token_check(to_raise: Literal[False] = ...) -> Optional[int]: ... -async def token_check(to_raise = True) -> Optional[int]: +async def token_check(to_raise=True) -> Optional[int]: """Check token information.""" # first, check if the request info already has a uid user_id = getattr(request, "user_id", None) diff --git a/litecord/blueprints/admin_api/channels.py b/litecord/blueprints/admin_api/channels.py index 0aff4efd..6b5a2022 100644 --- a/litecord/blueprints/admin_api/channels.py +++ b/litecord/blueprints/admin_api/channels.py @@ -33,7 +33,6 @@ from quart import current_app as app, request - bp = Blueprint("channels_admin", __name__) diff --git a/litecord/blueprints/admin_api/guilds.py b/litecord/blueprints/admin_api/guilds.py index fb073b7c..5561a444 100644 --- a/litecord/blueprints/admin_api/guilds.py +++ b/litecord/blueprints/admin_api/guilds.py @@ -128,9 +128,7 @@ async def create_guild(): }, ) guild_id = j.get("id") or app.winter_factory.snowflake() - guild, extra = await handle_guild_create( - user_id, guild_id, {"features": j.get("features")} - ) + guild, extra = await handle_guild_create(user_id, guild_id, {"features": j.get("features")}) return jsonify({**guild, **extra}), 201 @@ -166,9 +164,7 @@ async def update_guild(guild_id: int): if old_unavailable and not new_unavailable: # Guild became available guild = await app.storage.get_guild_full(guild_id) - await app.dispatcher.guild.dispatch( - guild_id, ("GUILD_CREATE", {**guild, "unavailable": False}) - ) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", {**guild, "unavailable": False})) elif not old_unavailable and new_unavailable: # Guild became unavailable await app.dispatcher.guild.dispatch( diff --git a/litecord/blueprints/admin_api/info.py b/litecord/blueprints/admin_api/info.py index 5bdd84ff..94899381 100644 --- a/litecord/blueprints/admin_api/info.py +++ b/litecord/blueprints/admin_api/info.py @@ -40,11 +40,7 @@ async def get_db_url(): if host in ("localhost", "0.0.0.0"): host = app.config["MAIN_URL"] - return jsonify( - { - "url": f"postgres://{db['user']}:{db['password']}@{host}:5432/{db['database']}" - } - ) + return jsonify({"url": f"postgres://{db['user']}:{db['password']}@{host}:5432/{db['database']}"}) @bp.route("/snowflake", methods=["GET"]) diff --git a/litecord/blueprints/admin_api/users.py b/litecord/blueprints/admin_api/users.py index 1de5e92c..ee13ff68 100644 --- a/litecord/blueprints/admin_api/users.py +++ b/litecord/blueprints/admin_api/users.py @@ -44,9 +44,7 @@ async def _create_user(): await admin_check() j = validate(await request.get_json(), USER_CREATE) - user_id, _ = await create_user( - j["username"], j["email"], j["password"], j.get("date_of_birth"), id=j.get("id") - ) + user_id, _ = await create_user(j["username"], j["email"], j["password"], j.get("date_of_birth"), id=j.get("id")) return jsonify(await app.storage.get_user(user_id, True)), 201 diff --git a/litecord/blueprints/attachments.py b/litecord/blueprints/attachments.py index 816aafa6..570bcbe7 100644 --- a/litecord/blueprints/attachments.py +++ b/litecord/blueprints/attachments.py @@ -28,9 +28,7 @@ ATTACHMENTS = Path.cwd() / "attachments" -async def _resize_gif( - attach_id: int, resized_path: Path, width: int, height: int -) -> str: +async def _resize_gif(attach_id: int, resized_path: Path, width: int, height: int) -> str: """Resize a GIF attachment.""" # get original gif bytes diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 668dab1a..44f0a2b0 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -22,7 +22,8 @@ from datetime import datetime, date import itsdangerous import bcrypt -from quart import Blueprint, jsonify, request, current_app as app +from quart import Blueprint, jsonify +from typing import TYPE_CHECKING from logbook import Logger @@ -33,6 +34,11 @@ from litecord.pubsub.user import dispatch_user from .invites import use_invite +if TYPE_CHECKING: + from litecord.typing_hax import app, request +else: + from quart import current_app as app, request + log = Logger(__name__) bp = Blueprint("auth", __name__) @@ -42,9 +48,7 @@ async def check_password(pwd_hash: str, given_password: str) -> bool: pwd_encoded = pwd_hash.encode() given_encoded = given_password.encode() - return await app.loop.run_in_executor( - None, bcrypt.checkpw, given_encoded, pwd_encoded - ) + return await app.loop.run_in_executor(None, bcrypt.checkpw, given_encoded, pwd_encoded) def make_token(user_id, user_pwd_hash) -> str: @@ -86,9 +90,7 @@ async def register(): today = date.today() date_of_birth = datetime.strptime(j["date_of_birth"], "%Y-%m-%d") if ( - today.year - - date_of_birth.year - - ((today.month, today.day) < (date_of_birth.month, date_of_birth.day)) + today.year - date_of_birth.year - ((today.month, today.day) < (date_of_birth.month, date_of_birth.day)) ) < 13: raise ManualFormError( date_of_birth={ @@ -145,9 +147,7 @@ async def _register_with_invite(): today = date.today() date_of_birth = datetime.strptime(data["date_of_birth"], "%Y-%m-%d") if ( - today.year - - date_of_birth.year - - ((today.month, today.day) < (date_of_birth.month, date_of_birth.day)) + today.year - date_of_birth.year - ((today.month, today.day) < (date_of_birth.month, date_of_birth.day)) ) < 13: raise ManualFormError( date_of_birth={ @@ -165,9 +165,7 @@ async def _register_with_invite(): invcode, ) - user_id, pwd_hash = await create_user( - data["username"], data["email"], data["password"], date_of_birth - ) + user_id, pwd_hash = await create_user(data["username"], data["email"], data["password"], date_of_birth) return jsonify({"token": make_token(user_id, pwd_hash)}) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 68f912b6..6604621b 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -110,12 +110,8 @@ async def around_message_search( around_message = await app.storage.get_message(around_id, user_id) around_message = [around_message] if around_message else [] - before_messages = await message_search( - channel_id, halved_limit, before=around_id, order="DESC" - ) - after_messages = await message_search( - channel_id, halved_limit, after=around_id, order="ASC" - ) + before_messages = await message_search(channel_id, halved_limit, before=around_id, order="DESC") + after_messages = await message_search(channel_id, halved_limit, after=around_id, order="ASC") return list(reversed(before_messages)) + around_message + after_messages @@ -132,9 +128,7 @@ async def handle_get_messages(channel_id: int): limit = extract_limit(request, default=50) if "around" in request.args: - messages = await around_message_search( - channel_id, int(request.args["around"]), limit - ) + messages = await around_message_search(channel_id, int(request.args["around"]), limit) else: before, after = query_tuple_from_args(request.args, limit) messages = await message_search(channel_id, limit, before=before, after=after) @@ -216,11 +210,7 @@ async def create_message( mentions = [] mention_roles = [] if data.get("content"): - if ( - allowed_mentions is None - or "users" in allowed_mentions.get("parse", []) - or allowed_mentions.get("users") - ): + if allowed_mentions is None or "users" in allowed_mentions.get("parse", []) or allowed_mentions.get("users"): allowed = (allowed_mentions.get("users") or []) if allowed_mentions else [] if ctype == ChannelType.GROUP_DM: members = await app.db.fetch( @@ -264,9 +254,7 @@ async def create_message( mentions.append(found_id) if actual_guild_id and ( - allowed_mentions is None - or "roles" in allowed_mentions.get("parse", []) - or allowed_mentions.get("roles") + allowed_mentions is None or "roles" in allowed_mentions.get("parse", []) or allowed_mentions.get("roles") ): guild_roles = await app.db.fetch( """ @@ -297,8 +285,7 @@ async def create_message( if ( data.get("message_reference") - and not data.get("flags", 0) & MessageFlags.is_crosspost - == MessageFlags.is_crosspost + and not data.get("flags", 0) & MessageFlags.is_crosspost == MessageFlags.is_crosspost and (allowed_mentions is None or allowed_mentions.get("replied_user", False)) ): reply_id = await app.db.fetchval( @@ -330,9 +317,7 @@ async def create_message( data["tts"], data["everyone_mention"], data["nonce"], - MessageType.DEFAULT.value - if not data.get("message_reference") - else MessageType.REPLY.value, + MessageType.DEFAULT.value if not data.get("message_reference") else MessageType.REPLY.value, data.get("flags") or 0, data.get("embeds") or [], data.get("message_reference") or None, @@ -464,25 +449,16 @@ async def _create_message(channel_id): # guild_id is the dm's peer_id await dm_pre_check(user_id, channel_id, guild_id) - can_everyone = ( - await channel_perm_check(user_id, channel_id, "mention_everyone", False) - and ctype != ChannelType.DM - ) + can_everyone = await channel_perm_check(user_id, channel_id, "mention_everyone", False) and ctype != ChannelType.DM mentions_everyone = ("@everyone" in j["content"]) and can_everyone mentions_here = ("@here" in j["content"]) and can_everyone - is_tts = j.get("tts", False) and await channel_perm_check( - user_id, channel_id, "send_tts_messages", False - ) + is_tts = j.get("tts", False) and await channel_perm_check(user_id, channel_id, "send_tts_messages", False) embeds = [ await fill_embed(embed) - for embed in ( - (j.get("embeds") or []) or [j["embed"]] - if "embed" in j and j["embed"] - else [] - ) + for embed in ((j.get("embeds") or []) or [j["embed"]] if "embed" in j and j["embed"] else []) ] message_id = await create_message( channel_id, @@ -500,10 +476,7 @@ async def _create_message(channel_id): "allowed_mentions": j.get("allowed_mentions"), "sticker_ids": j.get("sticker_ids"), "flags": MessageFlags.suppress_embeds - if ( - j.get("flags", 0) & MessageFlags.suppress_embeds - == MessageFlags.suppress_embeds - ) + if (j.get("flags", 0) & MessageFlags.suppress_embeds == MessageFlags.suppress_embeds) else 0, }, recipient_id=guild_id if ctype == ChannelType.DM else None, @@ -541,9 +514,7 @@ async def _create_message(channel_id): ) if ctype not in (ChannelType.DM, ChannelType.GROUP_DM): - await msg_guild_text_mentions( - payload, guild_id, mentions_everyone, mentions_here - ) + await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here) return jsonify(message_view(payload)) @@ -578,9 +549,7 @@ async def edit_message(channel_id, message_id): old_flags = MessageFlags.from_int(old_message.get("flags", 0)) new_flags = MessageFlags.from_int(int(j["flags"])) - toggle_flag( - old_flags, MessageFlags.suppress_embeds, new_flags.is_suppress_embeds - ) + toggle_flag(old_flags, MessageFlags.suppress_embeds, new_flags.is_suppress_embeds) if old_flags.value != old_message["flags"]: await app.db.execute( @@ -610,11 +579,7 @@ async def edit_message(channel_id, message_id): updated = True embeds = [ await fill_embed(embed) - for embed in ( - (j.get("embeds") or []) or [j["embed"]] - if "embed" in j and j["embed"] - else [] - ) + for embed in ((j.get("embeds") or []) or [j["embed"]] if "embed" in j and j["embed"] else []) ] await app.db.execute( """ @@ -637,9 +602,7 @@ async def edit_message(channel_id, message_id): "channel_id": channel_id, "content": j["content"], "embeds": old_message["embeds"], - "flags": flags - if flags is not None - else old_message.get("flags", 0), + "flags": flags if flags is not None else old_message.get("flags", 0), }, delay=0.2, ) @@ -663,9 +626,7 @@ async def edit_message(channel_id, message_id): await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message)) # now we handle crossposted messages - if updated and ( - message.get("flags", 0) & MessageFlags.crossposted == MessageFlags.crossposted - ): + if updated and (message.get("flags", 0) & MessageFlags.crossposted == MessageFlags.crossposted): async with app.db.acquire() as conn: await pg_set_json(conn) @@ -708,17 +669,13 @@ async def edit_message(channel_id, message_id): "id": id, "channel_id": row["channel_id"], "content": j["content"], - "embeds": embeds - if embeds is not None - else old_message["embeds"], + "embeds": embeds if embeds is not None else old_message["embeds"], }, delay=0.2, ) message = await app.storage.get_message(id) - await app.dispatcher.channel.dispatch( - row["channel_id"], ("MESSAGE_UPDATE", message) - ) + await app.dispatcher.channel.dispatch(row["channel_id"], ("MESSAGE_UPDATE", message)) return jsonify(message_view(message)) @@ -778,9 +735,7 @@ async def _del_msg_fkeys(message_id: int, channel_id: int): ) message = await app.storage.get_message(id) - await app.dispatcher.channel.dispatch( - channel_id, ("MESSAGE_UPDATE", message) - ) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message)) # take the chance and delete all the data from the other tables too! diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py index 37fb85df..7867b6ed 100644 --- a/litecord/blueprints/channel/pins.py +++ b/litecord/blueprints/channel/pins.py @@ -49,9 +49,7 @@ async def _dispatch_pins_update(channel_id: int) -> None: channel_id, ) - timestamp = ( - app.winter_factory.to_datetime(message_id) if message_id is not None else None - ) + timestamp = app.winter_factory.to_datetime(message_id) if message_id is not None else None await app.dispatcher.channel.dispatch( channel_id, ( @@ -114,9 +112,7 @@ async def add_pin(channel_id, message_id): await _dispatch_pins_update(channel_id) - await send_sys_message( - channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id - ) + await send_sys_message(channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id) return "", 204 diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index b88fd2ca..207f2ad2 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -177,9 +177,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): return "", 204 -def emoji_sql( - emoji_type, emoji_id, emoji_name, param_index: int = 4 -) -> Tuple[str, Union[int, str]]: +def emoji_sql(emoji_type, emoji_id, emoji_name, param_index: int = 4) -> Tuple[str, Union[int, str]]: """Extract SQL clauses to search for specific emoji in the message_reactions table.""" param = f"${param_index}" @@ -234,9 +232,7 @@ async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji if ctype in GUILD_CHANS: payload["guild_id"] = str(guild_id) - await app.dispatcher.channel.dispatch( - channel_id, ("MESSAGE_REACTION_REMOVE", payload) - ) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_REMOVE", payload)) @bp.route(f"{BASEPATH}//@me", methods=["DELETE"]) @@ -318,6 +314,4 @@ async def remove_all_reactions(channel_id, message_id): if ctype in GUILD_CHANS: payload["guild_id"] = str(guild_id) - await app.dispatcher.channel.dispatch( - channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload) - ) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload)) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 760eba63..252e2989 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -19,9 +19,9 @@ import secrets import time -from typing import List, Optional +from typing import List, Optional, TYPE_CHECKING -from quart import Blueprint, request, current_app as app, jsonify +from quart import Blueprint, jsonify from logbook import Logger from emoji import EMOJI_DATA @@ -58,6 +58,11 @@ from .webhooks import _dispatch_webhook_update from .guilds import handle_search +if TYPE_CHECKING: + from litecord.typing_hax import app, request +else: + from quart import current_app as app, request + log = Logger(__name__) bp = Blueprint("channels", __name__) @@ -358,9 +363,7 @@ async def close_channel(channel_id): ) chan = await app.storage.get_channel(channel_id, user_id=user_id) - await app.dispatcher.channel.dispatch( - channel_id, ("CHANNEL_UPDATE", chan) - ) + await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan)) return jsonify(chan) else: @@ -523,9 +526,7 @@ async def _update_channel_common(channel_id: int, guild_id: int, j: dict): left_shift = new_pos > current_pos # find all channels that we'll have to shift - shift_block: List[Optional[int]] = ( - chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos] - ) + shift_block: List[Optional[int]] = chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos] shift = -1 if left_shift else 1 @@ -567,9 +568,7 @@ async def _update_text_channel(channel_id: int, j: dict, _user_id: int): channel = await app.storage.get_channel(channel_id) # first do the specific ones related to guild_text_channels - for field in [ - field for field in j.keys() if field in ("topic", "rate_limit_per_user") - ]: + for field in [field for field in j.keys() if field in ("topic", "rate_limit_per_user")]: await app.db.execute( f""" UPDATE guild_text_channels @@ -580,10 +579,9 @@ async def _update_text_channel(channel_id: int, j: dict, _user_id: int): channel_id, ) - if channel["type"] in ( - ChannelType.GUILD_TEXT.value, - ChannelType.GUILD_NEWS.value, - ) and j["type"] in (ChannelType.GUILD_TEXT.value, ChannelType.GUILD_NEWS.value): + if channel["type"] in (ChannelType.GUILD_TEXT.value, ChannelType.GUILD_NEWS.value,) and j[ + "type" + ] in (ChannelType.GUILD_TEXT.value, ChannelType.GUILD_NEWS.value): await app.db.execute( f""" UPDATE channels @@ -652,9 +650,7 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int): ) if "icon" in j: - new_icon = await app.icons.update( - "channel_icon", channel_id, j["icon"], always_icon=True - ) + new_icon = await app.icons.update("channel_icon", channel_id, j["icon"], always_icon=True) await app.db.execute( """ @@ -722,9 +718,7 @@ async def trigger_typing(channel_id): "channel_id": str(channel_id), "user_id": str(user_id), "timestamp": int(time.time()), - "guild_id": str(guild_id) - if ctype not in (ChannelType.DM, ChannelType.GROUP_DM) - else None, + "guild_id": str(guild_id) if ctype not in (ChannelType.DM, ChannelType.GROUP_DM) else None, }, ), ) @@ -822,9 +816,7 @@ async def _search_channel(channel_id): await channel_perm_check(user_id, channel_id, "read_messages") await channel_perm_check(user_id, channel_id, "read_history") - return await handle_search( - await app.storage.guild_from_channel(channel_id), channel_id - ) + return await handle_search(await app.storage.guild_from_channel(channel_id), channel_id) @bp.route("//application-commands/search", methods=["GET"]) @@ -878,9 +870,7 @@ async def _msg_unset_flags(message_id: int, unset_flags: int): await _msg_update_flags(message_id, flags) -@bp.route( - "//messages//suppress-embeds", methods=["POST"] -) +@bp.route("//messages//suppress-embeds", methods=["POST"]) async def suppress_embeds(channel_id: int, message_id: int): """Toggle the embeds in a message. @@ -984,9 +974,7 @@ async def publish_message(channel_id: int, message_id: int): "guild_id": message["guild_id"], "flags": message["flags"], } - await app.dispatcher.channel.dispatch( - channel_id, ("MESSAGE_UPDATE", update_payload) - ) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", update_payload)) # Now we execute all these hooks content = message.get("content", "") @@ -999,7 +987,7 @@ async def publish_message(channel_id: int, message_id: int): continue user = await app.storage.get_user(found_id) - content = content.replace(match.group(0), user["username"] if user else "") + content = content.replace(match.group(0), user.username if user else "") result = { "content": content, @@ -1050,9 +1038,7 @@ async def publish_message(channel_id: int, message_id: int): ) payload = await app.storage.get_message(result_id, include_member=True) - await app.dispatcher.channel.dispatch( - hook["channel_id"], ("MESSAGE_CREATE", payload) - ) + await app.dispatcher.channel.dispatch(hook["channel_id"], ("MESSAGE_CREATE", payload)) app.sched.spawn(process_url_embed(payload)) return jsonify(message_view(message)) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 9b856242..e02d45aa 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -50,9 +50,7 @@ async def guild_check(user_id: int, guild_id: int, raise_err: bool = True) -> bo return True -async def guild_owner_check( - user_id: int, guild_id: int, raise_err: bool = True -) -> bool: +async def guild_owner_check(user_id: int, guild_id: int, raise_err: bool = True) -> bool: """Check if a user is the owner of the guild.""" data = await app.db.fetchrow( """ @@ -73,9 +71,7 @@ async def guild_owner_check( return True -async def channel_check( - user_id, channel_id, *, only: Optional[Union[ChannelType, List[ChannelType]]] = None -): +async def channel_check(user_id, channel_id, *, only: Optional[Union[ChannelType, List[ChannelType]]] = None): """Check if the current user is authorized to read the channel's information.""" chan_type = await app.storage.get_chan_type(channel_id) @@ -141,9 +137,7 @@ async def _max_role_position(guild_id, member_id) -> Optional[int]: ) -async def _validate_target_member( - guild_id: int, user_id: int, target_member_id: int -) -> bool: +async def _validate_target_member(guild_id: int, user_id: int, target_member_id: int) -> bool: owner_id = await app.storage.db.fetchval( """ SELECT owner_id diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py index 76aa9ad0..bd354bfb 100644 --- a/litecord/blueprints/dm_channels.py +++ b/litecord/blueprints/dm_channels.py @@ -17,9 +17,9 @@ """ -from typing import Iterable, Optional +from typing import Iterable, Optional, TYPE_CHECKING -from quart import Blueprint, current_app as app, jsonify +from quart import Blueprint, jsonify from logbook import Logger from litecord.blueprints.auth import token_check @@ -31,6 +31,11 @@ from litecord.system_messages import send_sys_message from litecord.pubsub.user import dispatch_user +if TYPE_CHECKING: + from litecord.typing_hax import app +else: + from quart import current_app as app + log = Logger(__name__) bp = Blueprint("dm_channels", __name__) @@ -148,9 +153,7 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None): await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id) -async def gdm_remove_recipient( - channel_id: int, peer_id: int, silent: Optional[bool] = False, *, user_id=None -): +async def gdm_remove_recipient(channel_id: int, peer_id: int, silent: Optional[bool] = False, *, user_id=None): """Remove a member from a GDM. Dispatches: @@ -182,9 +185,7 @@ async def gdm_remove_recipient( author_id = peer_id if user_id is None else user_id if not silent: - await send_sys_message( - channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id - ) + await send_sys_message(channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id) async def gdm_destroy(channel_id): @@ -257,9 +258,7 @@ async def add_to_group_dm(dm_chan, peer_id): # given channel is a gdm # other_id is the peer of the dm if the given channel is a dm - ctype, other_id = await channel_check( - user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM] - ) + ctype, other_id = await channel_check(user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM]) # check relationship with the given user id # and the user id making the request diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index a8c1919d..82527e4f 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -47,9 +47,7 @@ async def get_dms(): async def jsonify_dm(dm_id: int, user_id: int): dm_chan = await app.storage.get_dm(dm_id, user_id) - self_user_index = index_by_func( - lambda user: user["id"] == str(user_id), dm_chan["recipients"] - ) + self_user_index = index_by_func(lambda user: user["id"] == str(user_id), dm_chan["recipients"]) if request.discord_api_version > 7: assert self_user_index is not None diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index 4768dacc..e03012d4 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -68,9 +68,7 @@ async def create_channel(guild_id): } ) - if channel_type == ChannelType.GUILD_NEWS and not app.storage.has_feature( - guild_id, "NEWS" - ): + if channel_type == ChannelType.GUILD_NEWS and not app.storage.has_feature(guild_id, "NEWS"): raise ManualFormError( type={ "code": "BASE_TYPE_CHOICES", @@ -144,19 +142,13 @@ async def modify_channel_pos(guild_id): j = validate({"channels": raw_j}, CHANNEL_UPDATE_POSITION) j = j["channels"] - channels = { - int(chan["id"]): chan for chan in await app.storage.get_channel_data(guild_id) - } + channels = {int(chan["id"]): chan for chan in await app.storage.get_channel_data(guild_id)} channel_tree = {} for chan in j: conn = await app.db.acquire() _id = int(chan["id"]) - if ( - _id in channels - and "parent_id" in chan - and (chan["parent_id"] is None or chan["parent_id"] in channels) - ): + if _id in channels and "parent_id" in chan and (chan["parent_id"] is None or chan["parent_id"] in channels): channels[_id]["parent_id"] = chan["parent_id"] await conn.execute( """ @@ -180,11 +172,7 @@ async def modify_channel_pos(guild_id): _channel_ids = list(map(lambda chan: int(chan["id"]), _channels)) print(_key, _channel_ids) _channel_positions = {chan["position"]: int(chan["id"]) for chan in _channels} - _change_list = list( - filter( - lambda chan: "position" in chan and int(chan["id"]) in _channel_ids, j - ) - ) + _change_list = list(filter(lambda chan: "position" in chan and int(chan["id"]) in _channel_ids, j)) _swap_pairs = gen_pairs(_change_list, _channel_positions) await _do_channel_updates(guild_id, _swap_pairs) diff --git a/litecord/blueprints/guild/emoji.py b/litecord/blueprints/guild/emoji.py index 62a2abe1..5a332b32 100644 --- a/litecord/blueprints/guild/emoji.py +++ b/litecord/blueprints/guild/emoji.py @@ -112,9 +112,7 @@ async def _put_emoji(guild_id): ) if icon is None: - raise ManualFormError( - image={"code": "IMAGE_INVALID", "message": "Invalid image data"} - ) + raise ManualFormError(image={"code": "IMAGE_INVALID", "message": "Invalid image data"}) # TODO: better way to detect animated emoji rather than just gifs, # maybe a list perhaps? diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index e4847f11..b9e96bd8 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -224,9 +224,7 @@ async def modify_guild_member(guild_id, member_id): partial["nick"] = j["nick"] await app.lazy_guild.pres_update(guild_id, member_id, partial) - await app.dispatcher.guild.dispatch( - guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}) - ) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})) return member @@ -242,6 +240,7 @@ async def update_nickname(guild_id): j = validate(await request.get_json(), SELF_MEMBER_UPDATE) member = await app.storage.get_member(guild_id, user_id) user = await app.storage.get_user(user_id, True) + assert user is not None presence_dict = {} if to_update(j, member, "nick"): @@ -258,7 +257,7 @@ async def update_nickname(guild_id): presence_dict["nick"] = j["nick"] or None if to_update(j, member, "avatar"): - if not j["avatar"] or user["premium_type"] == PremiumType.TIER_2: + if not j["avatar"] or user.premium_type == PremiumType.TIER_2: new_icon = await app.icons.update( "member_avatar", f"{guild_id}_{user_id}", @@ -280,10 +279,8 @@ async def update_nickname(guild_id): presence_dict["avatar"] = new_icon.icon_hash if to_update(j, member, "banner"): - if not j["banner"] or user["premium_type"] == PremiumType.TIER_2: - new_icon = await app.icons.update( - "member_banner", f"{guild_id}_{user_id}", j["banner"], always_icon=True - ) + if not j["banner"] or user.premium_type == PremiumType.TIER_2: + new_icon = await app.icons.update("member_banner", f"{guild_id}_{user_id}", j["banner"], always_icon=True) await app.db.execute( """ @@ -298,7 +295,7 @@ async def update_nickname(guild_id): presence_dict["banner"] = new_icon.icon_hash if to_update(j, member, "bio"): - if not j["bio"] or user["premium_type"] == PremiumType.TIER_2: + if not j["bio"] or user.premium_type == PremiumType.TIER_2: await app.db.execute( """ UPDATE members @@ -329,18 +326,14 @@ async def update_nickname(guild_id): # call pres_update for nick changes, etc. if presence_dict: await app.lazy_guild.pres_update(guild_id, user_id, presence_dict) - await app.dispatcher.guild.dispatch( - guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}) - ) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})) # We inject the guild_id into the payload because the profiles endpoint needs it member["guild_id"] = str(guild_id) return jsonify(member) -@bp.route( - "//members//roles/", methods=["PUT"] -) +@bp.route("//members//roles/", methods=["PUT"]) async def add_member_role(guild_id, member_id, role_id): user_id = await token_check() await guild_perm_check(user_id, guild_id, "manage_roles") @@ -385,9 +378,7 @@ async def add_member_role(guild_id, member_id, role_id): return "", 204 -@bp.route( - "//members//roles/", methods=["DELETE"] -) +@bp.route("//members//roles/", methods=["DELETE"]) async def remove_member_role(guild_id, member_id, role_id): user_id = await token_check() await guild_perm_check(user_id, guild_id, "manage_roles") diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index ac1e6697..1cbfd1bf 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -188,9 +188,7 @@ async def prune_members(user_id, guild_id, member_ids): # calculate permissions against each pruned member, don't prune # if permissions don't allow it for member_id in member_ids: - has_permissions = await guild_perm_check( - user_id, guild_id, "kick_members", member_id, raise_err=False - ) + has_permissions = await guild_perm_check(user_id, guild_id, "kick_members", member_id, raise_err=False) if not has_permissions: continue diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index b59abdd1..8b1b5bb5 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -74,9 +74,7 @@ async def _role_update_dispatch(role_id: int, guild_id: int): await maybe_lazy_guild_dispatch(guild_id, "role_position_update", role) - await app.dispatcher.guild.dispatch( - guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}) - ) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role})) return role_view(role) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index c6375c80..c3617ac0 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -83,9 +83,7 @@ async def guild_create_roles_prep(guild_id: int, roles: list) -> dict: # from the 2nd and forward, # should be treated as new roles for role in roles[1:]: - cr = await create_role( - guild_id, role.pop("name"), default_perms=default_perms, **role - ) + cr = await create_role(guild_id, role.pop("name"), default_perms=default_perms, **role) if role.get("id") is not None: role_map[role["id"]] = int(cr["id"]) @@ -136,9 +134,7 @@ async def _general_guild_icon(scope: str, guild_id: int, icon: Optional[str], ** async def put_guild_icon(guild_id: int, icon: Optional[str]): """Insert a guild icon on the icon database.""" - return await _general_guild_icon( - "guild_icon", guild_id, icon, size=(1024, 1024), always_icon=True - ) + return await _general_guild_icon("guild_icon", guild_id, icon, size=(1024, 1024), always_icon=True) async def handle_search(guild_id: Optional[int], channel_id: Optional[int] = None): @@ -175,10 +171,7 @@ async def handle_search(guild_id: Optional[int], channel_id: Optional[int] = Non can_read = [channel for channel in j["channel_id"] if channel in can_read] if j.get("mentions"): extra += f" AND content = ANY(${len(args) + 1}::text[])" - args.append( - [f"%<@{id}>%" for id in j["mentions"]] - + [f"%<@!{id}>%" for id in j["mentions"]] - ) + args.append([f"%<@{id}>%" for id in j["mentions"]] + [f"%<@!{id}>%" for id in j["mentions"]]) if j.get("link_hostname"): extra += f" AND content = ANY(${len(args) + 1}::text[])" args.append( @@ -187,9 +180,7 @@ async def handle_search(guild_id: Optional[int], channel_id: Optional[int] = Non ) if j.get("embed_provider"): extra += f" AND embeds::text == ANY(${len(args) + 1}::text[])" - args.append( - ['%"provider": {"name": %s%' % provider for provider in j["embed_provider"]] - ) + args.append(['%"provider": {"name": %s%' % provider for provider in j["embed_provider"]]) if j.get("embed_type"): extra += f" AND embeds::text == ANY(${len(args) + 1}::text[])" args.append(['%"type": %s%' % type for type in j["embed_type"]]) @@ -287,9 +278,7 @@ async def create_guild(): return jsonify(guild), 201 -async def handle_guild_create( - user_id: int, guild_id: int, extra_j: Optional[dict] = None -) -> Tuple[dict, dict]: +async def handle_guild_create(user_id: int, guild_id: int, extra_j: Optional[dict] = None) -> Tuple[dict, dict]: j = validate(await request.get_json(), GUILD_CREATE) extra_j = extra_j or {} @@ -354,9 +343,7 @@ async def handle_guild_create( # create a single #general channel. general_id = guild_id - await create_guild_channel( - guild_id, general_id, ChannelType.GUILD_TEXT, name="general" - ) + await create_guild_channel(guild_id, general_id, ChannelType.GUILD_TEXT, name="general") role_map = {} if j.get("roles"): @@ -484,22 +471,14 @@ async def handle_guild_update(guild_id: int, check: bool = True): if to_update(j, guild, "icon"): await _guild_update_icon("guild_icon", guild_id, j["icon"], size=(1024, 1024)) - if to_update(j, guild, "splash") and await app.storage.has_feature( - guild_id, "INVITE_SPLASH" - ): + if to_update(j, guild, "splash") and await app.storage.has_feature(guild_id, "INVITE_SPLASH"): await _guild_update_icon("guild_splash", guild_id, j["splash"]) - if to_update(j, guild, "banner") and await app.storage.has_feature( - guild_id, "BANNER" - ): + if to_update(j, guild, "banner") and await app.storage.has_feature(guild_id, "BANNER"): await _guild_update_icon("guild_banner", guild_id, j["banner"]) - if to_update(j, guild, "discovery_splash") and await app.storage.has_feature( - guild_id, "DISCOVERABLE" - ): - await _guild_update_icon( - "guild_discovery_splash", guild_id, j["discovery_splash"] - ) + if to_update(j, guild, "discovery_splash") and await app.storage.has_feature(guild_id, "DISCOVERABLE"): + await _guild_update_icon("guild_discovery_splash", guild_id, j["discovery_splash"]) if "features" in j: features = await app.storage.guild_features(guild_id) or [] @@ -583,9 +562,7 @@ async def handle_guild_update(guild_id: int, check: bool = True): await create_guild_channel( guild_id, chan_id, - ChannelType.GUILD_TEXT - if field != "afk_channel_id" - else ChannelType.GUILD_VOICE, + ChannelType.GUILD_TEXT if field != "afk_channel_id" else ChannelType.GUILD_VOICE, name=default_channel_map[field], ) @@ -595,9 +572,7 @@ async def handle_guild_update(guild_id: int, check: bool = True): await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_CREATE", chan)) elif chan["guild_id"] != str(guild_id): - raise ManualFormError( - **{field: {"code": "INVALID_CHANNEL", "message": "Channel is invalid."}} - ) + raise ManualFormError(**{field: {"code": "INVALID_CHANNEL", "message": "Channel is invalid."}}) elif chan is None: await app.db.execute( diff --git a/litecord/blueprints/icons.py b/litecord/blueprints/icons.py index bf8a7879..361ae3e5 100644 --- a/litecord/blueprints/icons.py +++ b/litecord/blueprints/icons.py @@ -18,13 +18,18 @@ """ from os.path import splitext - +from typing import TYPE_CHECKING import aiohttp -from quart import Blueprint, current_app as app, send_file, redirect, make_response +from quart import Blueprint, send_file, redirect, make_response from litecord.embed.sanitizer import make_md_req_url from litecord.embed.schemas import EmbedURL +if TYPE_CHECKING: + from litecord.typing_hax import app, request +else: + from quart import current_app as app, request + bp = Blueprint("images", __name__) @@ -98,17 +103,13 @@ async def _get_avatar_decoration(user_id, avatar_file): @bp.route("/guilds//users//avatars/") async def _get_member_avatar(guild_id, user_id, avatar_file): avatar_hash, ext = splitext_(avatar_file) - return await send_icon( - "member_avatar", f"{guild_id}_{user_id}", avatar_hash, ext=ext - ) + return await send_icon("member_avatar", f"{guild_id}_{user_id}", avatar_hash, ext=ext) @bp.route("/guilds//users//banners/") async def _get_member_banner(guild_id, user_id, banner_file): avatar_hash, ext = splitext_(banner_file) - return await send_icon( - "member_banner", f"{guild_id}_{user_id}", avatar_hash, ext=ext - ) + return await send_icon("member_banner", f"{guild_id}_{user_id}", avatar_hash, ext=ext) # @bp.route('/app-icons//.') diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 57d0d592..0a62e6a9 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -231,9 +231,7 @@ async def get_invite(invite_code: str): if not inv: raise UnknownInvite() - if request.args.get("with_counts", type=str_bool) or request.args.get( - "with_expiration", type=str_bool - ): + if request.args.get("with_counts", type=str_bool) or request.args.get("with_expiration", type=str_bool): extra = await app.storage.get_invite_extra( invite_code, request.args.get("with_counts", type=str_bool), diff --git a/litecord/blueprints/misc.py b/litecord/blueprints/misc.py index 9cdaaeb6..35d768cc 100644 --- a/litecord/blueprints/misc.py +++ b/litecord/blueprints/misc.py @@ -17,7 +17,8 @@ """ -from quart import Blueprint, jsonify, request, current_app as app +from quart import Blueprint, jsonify +from typing import TYPE_CHECKING import json import secrets @@ -31,6 +32,11 @@ from litecord.blueprints.checks import guild_perm_check from litecord.errors import ManualFormError +if TYPE_CHECKING: + from litecord.typing_hax import app, request +else: + from quart import current_app as app, request + bp = Blueprint("science", __name__) try: @@ -60,9 +66,7 @@ async def experiments(): user_id = await token_check(False) if not user_id and not request.headers.get("X-Fingerprint"): - ret[ - "fingerprint" - ] = f"{app.winter_factory.snowflake()}.{secrets.token_urlsafe(32)}" + ret["fingerprint"] = f"{app.winter_factory.snowflake()}.{secrets.token_urlsafe(32)}" if request.args.get("with_guild_experiments", type=str_bool): ret["guild_experiments"] = await app.storage.get_guild_experiments() @@ -150,7 +154,7 @@ async def partners_apply(): guild_id, ) user = await app.storage.get_user(owner_id) - flags = UserFlags.from_int(user["flags"]) + flags = UserFlags.from_int(user.flags) toggle_flag(flags, UserFlags.partner, True) await app.db.execute( diff --git a/litecord/blueprints/relationships.py b/litecord/blueprints/relationships.py index 9bff6d32..04da162e 100644 --- a/litecord/blueprints/relationships.py +++ b/litecord/blueprints/relationships.py @@ -17,8 +17,9 @@ """ -from quart import Blueprint, jsonify, request, current_app as app +from quart import Blueprint, jsonify from asyncpg import UniqueViolationError +from typing import TYPE_CHECKING from ..auth import token_check from ..schemas import validate, RELATIONSHIP, RELATIONSHIP_UPDATE, SPECIFIC_FRIEND @@ -26,6 +27,10 @@ from litecord.errors import BadRequest from litecord.pubsub.user import dispatch_user +if TYPE_CHECKING: + from litecord.typing_hax import app, request +else: + from quart import current_app as app, request bp = Blueprint("relationship", __name__) @@ -57,9 +62,7 @@ async def _sub_friend(user, peer): await _dispatch_single_pres(int(peer["id"]), user_pres) -async def make_friend( - user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value -): +async def make_friend(user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value): _friend = RelationshipType.FRIEND.value _block = RelationshipType.BLOCK.value @@ -111,9 +114,7 @@ async def make_friend( _friend, ) - await dispatch_user( - peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}) - ) + await dispatch_user(peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)})) await _unsub_friend(user_id, peer_id) @@ -348,18 +349,14 @@ async def remove_relationship(peer_id: int): # if there wasnt any mutual friendship before, # assume they were requests of INCOMING # and OUTGOING. - user_del_type = ( - RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend - ) + user_del_type = RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend await _dispatch( user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}), ) - peer_del_type = ( - RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend - ) + peer_del_type = RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend await _dispatch( peer_id, @@ -381,9 +378,7 @@ async def remove_relationship(peer_id: int): _block, ) - await _dispatch( - user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}) - ) + await _dispatch(user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block})) await _unsub_friend(user_id, peer_id) diff --git a/litecord/blueprints/static.py b/litecord/blueprints/static.py index 7021a600..34251d49 100644 --- a/litecord/blueprints/static.py +++ b/litecord/blueprints/static.py @@ -61,11 +61,9 @@ def _get_environment(app): "API_ENDPOINT": f"//{app.config['MAIN_URL']}/api", "API_VERSION": 9, "WEBAPP_ENDPOINT": f"//{app.config['MAIN_URL']}", - "GATEWAY_ENDPOINT": ("wss://" if app.config["IS_SSL"] else "ws://") - + app.config["WEBSOCKET_URL"], + "GATEWAY_ENDPOINT": ("wss://" if app.config["IS_SSL"] else "ws://") + app.config["WEBSOCKET_URL"], "CDN_HOST": f"//{app.config['MAIN_URL']}", - "ASSET_ENDPOINT": ("https://" if app.config["IS_SSL"] else "http://") - + app.config["MAIN_URL"], + "ASSET_ENDPOINT": ("https://" if app.config["IS_SSL"] else "http://") + app.config["MAIN_URL"], "MEDIA_PROXY_ENDPOINT": f"//{app.config['MEDIA_PROXY']}", "WIDGET_ENDPOINT": f"//{app.config['MAIN_URL']}/widget", "INVITE_HOST": f"{app.config['MAIN_URL']}/invite", @@ -82,10 +80,7 @@ def _get_environment(app): "REMOTE_AUTH_ENDPOINT": "//remote-auth-gateway.discord.gg", "SENTRY_TAGS": {"buildId": "7ea92cf", "buildType": "normal"}, "MIGRATION_SOURCE_ORIGIN": "https://discordapp.com", - "MIGRATION_DESTINATION_ORIGIN": ( - "https://" if app.config["IS_SSL"] else "http://" - ) - + app.config["MAIN_URL"], + "MIGRATION_DESTINATION_ORIGIN": ("https://" if app.config["IS_SSL"] else "http://") + app.config["MAIN_URL"], "HTML_TIMESTAMP": int(time.time() * 1000), "ALGOLIA_KEY": "aca0d7082e4e63af5ba5917d5e96bed0", } @@ -140,9 +135,7 @@ async def _load_build( except KeyError: return "Build not found", 404 - async with aiohttp.request( - "GET", f"https://api.discord.sale/builds/{value}" - ) as resp: + async with aiohttp.request("GET", f"https://api.discord.sale/builds/{value}") as resp: if not 300 > resp.status >= 200: try: info = BUILDS[value] @@ -187,9 +180,7 @@ async def _load_build( resp = await make_response(await render_template(file, **kwargs)) if clear_override: resp.set_cookie("buildOverride", "", expires=0) - if not default and not ( - request.cookies.get("buildOverride") and not clear_override - ): + if not default and not (request.cookies.get("buildOverride") and not clear_override): resp.set_cookie( "buildOverride", await generate_build_override_cookie( @@ -270,16 +261,10 @@ async def _proxy_asset(asset, default: bool = False): try: async with aopen(f"assets/{asset}", "rb") as f: data = await f.read() - response = await make_response( - data, 200, {"content-type": guess_content_type(asset)} - ) + response = await make_response(data, 200, {"content-type": guess_content_type(asset)}) except FileNotFoundError: - async with aiohttp.request( - "GET", f"https://canary.discord.com/assets/{asset}" - ) as resp: - if ( - not 300 > resp.status >= 200 - ): # Fallback to the Wayback Machine if the asset is not found + async with aiohttp.request("GET", f"https://canary.discord.com/assets/{asset}") as resp: + if not 300 > resp.status >= 200: # Fallback to the Wayback Machine if the asset is not found async with aiohttp.request( "GET", f"http://web.archive.org/web/0im_/discordapp.com/assets/{asset}", @@ -325,13 +310,11 @@ async def _proxy_asset(asset, default: bool = False): # Various regexes .replace( r'RegExp("^https://(?:ptb\\.|canary\\.)?(discordapp|discord)\\.com/__development/link?[\\S]+$"', - r'RegExp("^https://%s/__development/link?[\\S]+$"' - % host.replace(".", r"\\."), + r'RegExp("^https://%s/__development/link?[\\S]+$"' % host.replace(".", r"\\."), ) .replace( r"/^((https:\/\/)?(discord\.gg\/)|(discord\.com\/)(invite\/)?)?[A-Za-z0-9]{8,8}$/", - r"/^((https:\/\/)?(%s\/)(invite\/)?)?[A-Za-z0-9]{8,8}$/" - % host.replace(".", r"\."), + r"/^((https:\/\/)?(%s\/)(invite\/)?)?[A-Za-z0-9]{8,8}$/" % host.replace(".", r"\."), ) .replace('+"|discordapp.com|discord.com)$"', f'+"{host})$"') .replace( @@ -417,16 +400,12 @@ async def generate_build_override_link(data: dict) -> str: expiration = ( datetime.fromtimestamp(2147483647) if not j["meta"]["ttl_seconds"] - else ( - datetime.now(tz=timezone.utc) + timedelta(seconds=j["meta"]["ttl_seconds"]) - ) + else (datetime.now(tz=timezone.utc) + timedelta(seconds=j["meta"]["ttl_seconds"])) ) data = { "targetBuildOverride": j["overrides"], "releaseChannel": j["meta"]["release_channel"], - "validForUserIds": [ - str(id) for id in j["meta"].get("valid_for_user_ids") or [] - ], + "validForUserIds": [str(id) for id in j["meta"].get("valid_for_user_ids") or []], "allowLoggedOut": j["meta"]["allow_logged_out"], "expiresAt": format_date_time(time.mktime(expiration.timetuple())), } @@ -453,9 +432,7 @@ async def create_override_link(): if not await is_staff(user_id): return "The maze wasn't meant for you", 403 - return jsonify( - {"url": await generate_build_override_link(await request.get_json())} - ) + return jsonify({"url": await generate_build_override_link(await request.get_json())}) @bp.route("/__development/link", methods=["GET"]) @@ -495,9 +472,7 @@ async def use_overrride_link(): expires_at = datetime(*parsedate(info["expiresAt"])[:6], tzinfo=timezone.utc) # type: ignore if datetime.now(tz=timezone.utc) > expires_at: - return { - "message": "This link has expired. You will need to get a new one." - }, 400 + return {"message": "This link has expired. You will need to get a new one."}, 400 if not info["allowLoggedOut"]: token = j.get("token") @@ -515,9 +490,7 @@ async def use_overrride_link(): resp = jsonify({"message": "Build overrides have been successfully applied!"}) resp.set_cookie( "buildOverride", - await generate_build_override_cookie( - info["targetBuildOverride"], info["expiresAt"] - ), + await generate_build_override_cookie(info["targetBuildOverride"], info["expiresAt"]), expires=expires_at, ) return resp @@ -557,9 +530,7 @@ async def set_build_overrides(): resp = jsonify({"message": "Build overrides have been successfully applied!"}) resp.set_cookie( "buildOverride", - await generate_build_override_cookie( - j["overrides"], format_date_time(2147483647) - ), + await generate_build_override_cookie(j["overrides"], format_date_time(2147483647)), expires=2147483647, ) return resp diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index b8c88750..a890956d 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -184,9 +184,7 @@ async def get_payment_source(user_id: int, source_id: int) -> dict: derow["default"] = derow.pop("default_") derow["billing_address"] = ( - json.loads(derow["billing_address"]) - if isinstance(derow["billing_address"], str) - else derow["billing_address"] + json.loads(derow["billing_address"]) if isinstance(derow["billing_address"], str) else derow["billing_address"] ) source = { @@ -512,11 +510,7 @@ async def _create_subscription(): @bp.route("/@me/billing/subscriptions/", methods=["GET"]) async def _get_subscription(subscription_id): await token_check() - return jsonify( - await get_subscription( - subscription_id or int((await request.get_json())["subscription_id"]) - ) - ) + return jsonify(await get_subscription(subscription_id or int((await request.get_json())["subscription_id"]))) @bp.route("/@me/billing/subscriptions/", methods=["DELETE"]) diff --git a/litecord/blueprints/user/settings.py b/litecord/blueprints/user/settings.py index 0410e115..133c6096 100644 --- a/litecord/blueprints/user/settings.py +++ b/litecord/blueprints/user/settings.py @@ -175,9 +175,7 @@ async def get_note(target_id: int): if note is None: raise NotFound(10013) - return jsonify( - {"user_id": str(user_id), "note_user_id": str(target_id), "note": note} - ) + return jsonify({"user_id": str(user_id), "note_user_id": str(target_id), "note": note}) @bp.route("/@me/notes/", methods=["PUT"]) @@ -209,8 +207,6 @@ async def put_note(target_id: int): note, ) - await dispatch_user( - user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note}) - ) + await dispatch_user(user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note})) return "", 204 diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 1bf2e4fb..70230809 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -89,8 +89,7 @@ async def query_users(): async def get_me(): """Get the current user's information.""" user_id = await token_check() - user = await app.storage.get_user(user_id, True) - return jsonify(user) + return await app.storage.get_user(user_id, True) @bp.route("/", methods=["GET"]) @@ -100,7 +99,7 @@ async def get_other(target_id): other = await app.storage.get_user(target_id) if not other: raise NotFound(10013) - return jsonify(other) + return other.to_json(secure=False) async def _try_username_patch(user_id, new_username: str) -> str: @@ -161,7 +160,7 @@ async def _try_discrim_patch(user_id, new_discrim: str): raise BadRequest(30006) -async def _check_pass(j, user): +async def _check_pass(j, user, password_hash: str): # Do not do password checks on unclaimed accounts if user["email"] is None: return @@ -174,8 +173,7 @@ async def _check_pass(j, user): } ) - phash = user["password_hash"] - if not await check_password(phash, j["password"]): + if not await check_password(password_hash, j["password"]): raise ManualFormError( password={ "code": "PASSWORD_DOES_NOT_MATCH", @@ -195,8 +193,8 @@ async def patch_me(): async def handle_user_update(user_id: int, check_password: bool = True): j = validate(await request.get_json(), USER_UPDATE) user = await app.storage.get_user(user_id, True) - - user["password_hash"] = await app.db.fetchval( + assert user is not None + password_hash = await app.db.fetchval( """ SELECT password_hash FROM users @@ -204,31 +202,32 @@ async def handle_user_update(user_id: int, check_password: bool = True): """, user_id, ) + user_dict = user.to_json() - if to_update(j, user, "username"): + if to_update(j, user_dict, "username"): if check_password: - await _check_pass(j, user) + await _check_pass(j, user, password_hash) discrim = await _try_username_patch(user_id, j["username"]) - user["username"] = j["username"] - user["discriminator"] = discrim + user.username = j["username"] + user.discriminator = discrim - if to_update(j, user, "discriminator"): + if to_update(j, user_dict, "discriminator"): if check_password: - await _check_pass(j, user) + await _check_pass(j, user, password_hash) try: new_discrim = "%04d" % int(j["discriminator"]) except (ValueError, TypeError): pass else: - if new_discrim != user["discriminator"]: + if new_discrim != user.discriminator: await _try_discrim_patch(user_id, new_discrim) - user["discriminator"] = new_discrim + user.discriminator = new_discrim - if to_update(j, user, "email"): + if to_update(j, user_dict, "email"): if check_password: - await _check_pass(j, user) + await _check_pass(j, user, password_hash) await app.db.execute( """ @@ -239,8 +238,8 @@ async def handle_user_update(user_id: int, check_password: bool = True): j["email"], user_id, ) - user["email"] = j["email"] - user["verified"] = False + user.email = j["email"] + user.verified = False # only update if values are different # from what the user gave. @@ -252,16 +251,14 @@ async def handle_user_update(user_id: int, check_password: bool = True): # IconManager.update will take care of validating # the value once put()-ing - if to_update(j, user, "avatar"): + if to_update(j, user_dict, "avatar"): mime, _ = parse_data_uri(j["avatar"]) no_gif = False - if mime == "image/gif" and user["premium_type"] == PremiumType.NONE: + if mime == "image/gif" and user.premium_type == PremiumType.NONE: no_gif = True - new_icon = await app.icons.update( - "user_avatar", user_id, j["avatar"], size=(1024, 1024), always_icon=True - ) + new_icon = await app.icons.update("user_avatar", user_id, j["avatar"], size=(1024, 1024), always_icon=True) await app.db.execute( """ @@ -269,14 +266,12 @@ async def handle_user_update(user_id: int, check_password: bool = True): SET avatar = $1 WHERE id = $2 """, - new_icon.icon_hash.lstrip("a_") - if (no_gif and new_icon.icon_hash) - else new_icon.icon_hash, + new_icon.icon_hash.lstrip("a_") if (no_gif and new_icon.icon_hash) else new_icon.icon_hash, user_id, ) - if to_update(j, user, "avatar_decoration"): - if not j["avatar_decoration"] or user["premium_type"] == PremiumType.TIER_2: + if to_update(j, user_dict, "avatar_decoration"): + if not j["avatar_decoration"] or user.premium_type == PremiumType.TIER_2: new_icon = await app.icons.update( "user_avatar_decoration", user_id, @@ -294,11 +289,9 @@ async def handle_user_update(user_id: int, check_password: bool = True): user_id, ) - if to_update(j, user, "banner"): - if not j["banner"] or user["premium_type"] == PremiumType.TIER_2: - new_icon = await app.icons.update( - "user_banner", user_id, j["banner"], always_icon=True - ) + if to_update(j, user_dict, "banner"): + if not j["banner"] or user.premium_type == PremiumType.TIER_2: + new_icon = await app.icons.update("user_banner", user_id, j["banner"], always_icon=True) await app.db.execute( """ @@ -310,7 +303,7 @@ async def handle_user_update(user_id: int, check_password: bool = True): user_id, ) - if to_update(j, user, "bio"): + if to_update(j, user_dict, "bio"): await app.db.execute( """ UPDATE users @@ -321,7 +314,7 @@ async def handle_user_update(user_id: int, check_password: bool = True): user_id, ) - if to_update(j, user, "pronouns"): + if to_update(j, user_dict, "pronouns"): await app.db.execute( """ UPDATE users @@ -341,7 +334,7 @@ async def handle_user_update(user_id: int, check_password: bool = True): except ValueError: pass - if to_update(j, user, "accent_color"): + if to_update(j, user_dict, "accent_color"): await app.db.execute( """ UPDATE users @@ -352,8 +345,8 @@ async def handle_user_update(user_id: int, check_password: bool = True): user_id, ) - if to_update(j, user, "theme_colors"): - if not j["theme_colors"] or user["premium_type"] == PremiumType.TIER_2: + if to_update(j, user_dict, "theme_colors"): + if not j["theme_colors"] or user.premium_type == PremiumType.TIER_2: await app.db.execute( """ UPDATE users @@ -368,7 +361,7 @@ async def handle_user_update(user_id: int, check_password: bool = True): if "new_password" in j and j["new_password"]: if check_password: - await _check_pass(j, user) + await _check_pass(j, user, password_hash) new_hash = await hash_data(j["new_password"]) await app.db.execute( @@ -382,18 +375,14 @@ async def handle_user_update(user_id: int, check_password: bool = True): ) if j.get("flags"): - old_flags = UserFlags.from_int(user["flags"]) + old_flags = UserFlags.from_int(user.flags) new_flags = UserFlags.from_int(j["flags"]) - toggle_flag( - old_flags, UserFlags.premium_dismissed, new_flags.is_premium_dismissed - ) - toggle_flag( - old_flags, UserFlags.unread_urgent_system, new_flags.is_unread_urgent_system - ) + toggle_flag(old_flags, UserFlags.premium_dismissed, new_flags.is_premium_dismissed) + toggle_flag(old_flags, UserFlags.unread_urgent_system, new_flags.is_unread_urgent_system) toggle_flag(old_flags, UserFlags.disable_premium, new_flags.is_disable_premium) - if old_flags.value != user["flags"]: + if old_flags.value != user.flags: await app.db.execute( """ UPDATE users @@ -432,8 +421,6 @@ async def handle_user_update(user_id: int, check_password: bool = True): user_id, ) - user.pop("password_hash") - _, private_user = await mass_user_update(user_id) return private_user @@ -528,9 +515,7 @@ async def get_library(): return jsonify([]) -async def map_guild_ids_to_mutual_list( - mutual_guild_ids: List[int], peer_id: int -) -> List[dict]: +async def map_guild_ids_to_mutual_list(mutual_guild_ids: List[int], peer_id: int) -> List[dict]: mutual_result = [] # ascending sorting @@ -555,7 +540,6 @@ async def get_profile(peer_id: int): """Get a user's profile.""" user_id = await token_check() peer = await app.storage.get_user(peer_id) - if not peer: raise NotFound(10013) @@ -563,7 +547,7 @@ async def get_profile(peer_id: int): friends = await app.user_storage.are_friends_with(user_id, peer_id) staff = await is_staff(user_id) - # don't return a proper card if no guilds are being shared (bypassed by starf) + # don't return a proper card if no guilds are being shared (bypassed by staff) if not mutual_guilds and not friends and not staff: raise MissingAccess() @@ -589,8 +573,8 @@ async def get_profile(peer_id: int): ) result = { - "user": peer, - "user_profile": peer, + "user": peer.to_json(secure=False), + "user_profile": peer.to_json(secure=False), "connected_accounts": [], "premium_type": PLAN_ID_TO_TYPE.get(plan_id), "premium_since": timestamp_(peer_premium), @@ -599,9 +583,7 @@ async def get_profile(peer_id: int): } if request.args.get("with_mutual_guilds", type=str_bool) in (None, True): - result["mutual_guilds"] = await map_guild_ids_to_mutual_list( - mutual_guilds, peer_id - ) + result["mutual_guilds"] = await map_guild_ids_to_mutual_list(mutual_guilds, peer_id) if request.args.get("guild_id", type=int): guild_id = int(request.args["guild_id"]) @@ -614,13 +596,12 @@ async def get_profile(peer_id: int): result["guild_member"] = result["guild_member_profile"] = member_data result["guild_member_profile"]["guild_id"] = str(guild_id) # Husk - if peer["bot"] and not peer["system"]: + if peer.bot and not peer.system: result["application"] = { - "id": peer["id"], + "id": peer.id, "flags": 8667136, "popular_application_command_ids": [], - "verified": peer["flags"] & UserFlags.verified_bot - == UserFlags.verified_bot, + "verified": peer.flags & UserFlags.verified_bot == UserFlags.verified_bot, } return jsonify(result) @@ -730,9 +711,7 @@ async def _get_tinder_score_affinity_users(): # We make semi-accurate affinities by using relationships and private channels friends = await app.user_storage.get_friend_ids(user_id) dms = await app.user_storage.get_dms(user_id) - dm_recipients = [ - r["id"] for dm in dms for r in dm["recipients"] if int(r["id"]) != user_id - ] + dm_recipients = [r["id"] for dm in dms for r in dm["recipients"] if int(r["id"]) != user_id] return jsonify( { "user_affinities": list(set(map(str, friends + dm_recipients))), diff --git a/litecord/blueprints/voice.py b/litecord/blueprints/voice.py index cde0db8b..031a10d1 100644 --- a/litecord/blueprints/voice.py +++ b/litecord/blueprints/voice.py @@ -17,14 +17,19 @@ """ -from typing import Optional +from typing import Optional, TYPE_CHECKING from collections import Counter from random import choice -from quart import Blueprint, jsonify, current_app as app +from quart import Blueprint, jsonify from litecord.blueprints.auth import token_check +if TYPE_CHECKING: + from litecord.typing_hax import app +else: + from quart import current_app as app + bp = Blueprint("voice", __name__) diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py index 8f306002..906ffe1d 100644 --- a/litecord/blueprints/webhooks.py +++ b/litecord/blueprints/webhooks.py @@ -19,10 +19,10 @@ import secrets import hashlib -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, TYPE_CHECKING import asyncpg -from quart import Blueprint, jsonify, current_app as app, request +from quart import Blueprint, jsonify from litecord.auth import token_check from litecord.blueprints.checks import ( @@ -59,12 +59,15 @@ from litecord.enums import MessageType from litecord.images import STATIC_IMAGE_MIMES +if TYPE_CHECKING: + from litecord.typing_hax import app, request +else: + from quart import current_app as app, request + bp = Blueprint("webhooks", __name__) -async def get_webhook( - webhook_id: int, *, secure: bool = True -) -> Optional[Dict[str, Any]]: +async def get_webhook(webhook_id: int, *, secure: bool = True) -> Optional[Dict[str, Any]]: """Get a webhook data""" row = await app.db.fetchrow( """ @@ -316,9 +319,7 @@ async def modify_webhook(webhook_id: int): if "channel_id" in j: chan = await app.storage.get_channel(j["channel_id"]) - if (j["channel_id"] and not chan) or ( - chan and chan["guild_id"] != str(guild_id) - ): + if (j["channel_id"] and not chan) or (chan and chan["guild_id"] != str(guild_id)): raise NotFound(10003) await _update_webhook(webhook_id, j) @@ -525,11 +526,7 @@ async def execute_webhook(webhook_id: int, webhook_token): "everyone_mention": mentions_everyone or mentions_here, "embeds": [ await fill_embed(embed) - for embed in ( - (j.get("embeds") or []) or [j["embed"]] - if "embed" in j and j["embed"] - else [] - ) + for embed in ((j.get("embeds") or []) or [j["embed"]] if "embed" in j and j["embed"] else []) ], "info": {"name": j.get("username", webhook["name"]), "avatar": avatar}, }, @@ -578,11 +575,7 @@ async def get_webhook_message(webhook_id, webhook_token, message_id): await webhook_token_check(webhook_id, webhook_token) payload = await app.storage.get_message(message_id) - if ( - not payload - or not payload["webhook_id"] - or int(payload["webhook_id"]) != webhook_id - ): + if not payload or not payload["webhook_id"] or int(payload["webhook_id"]) != webhook_id: raise NotFound(10008) return jsonify(message_view(payload)) @@ -597,11 +590,7 @@ async def update_webhook_message(webhook_id, webhook_token, message_id): _, channel_id = await webhook_token_check(webhook_id, webhook_token) old_message = await app.storage.get_message(message_id) - if ( - not old_message - or not old_message["webhook_id"] - or int(old_message["webhook_id"]) != webhook_id - ): + if not old_message or not old_message["webhook_id"] or int(old_message["webhook_id"]) != webhook_id: raise NotFound(10008) j = validate(await request.get_json(), WEBHOOK_MESSAGE_UPDATE) @@ -624,11 +613,7 @@ async def update_webhook_message(webhook_id, webhook_token, message_id): updated = True embeds = [ await fill_embed(embed) - for embed in ( - (j.get("embeds") or []) or [j["embed"]] - if "embed" in j and j["embed"] - else [] - ) + for embed in ((j.get("embeds") or []) or [j["embed"]] if "embed" in j and j["embed"] else []) ] await app.db.execute( """ @@ -683,11 +668,7 @@ async def delete_webhook_message(webhook_id, webhook_token, message_id): guild_id, channel_id = await webhook_token_check(webhook_id, webhook_token) payload = await app.storage.get_message(message_id) - if ( - not payload - or not payload["webhook_id"] - or int(payload["webhook_id"]) != webhook_id - ): + if not payload or not payload["webhook_id"] or int(payload["webhook_id"]) != webhook_id: raise NotFound(10008) await _del_msg_fkeys(message_id, channel_id) diff --git a/litecord/cache.py b/litecord/cache.py new file mode 100644 index 00000000..f78db2e2 --- /dev/null +++ b/litecord/cache.py @@ -0,0 +1,22 @@ +from typing import Dict +from dataclasses import fields as get_fields +import asyncpg + +from litecord.models import PartialUser + + +class CacheManager: + # User Id: Model + users: Dict[int, PartialUser] + + async def load(self, db: asyncpg.Pool): + fields = [field.name for field in get_fields(PartialUser)] + raw_users = await db.fetchrow( + f""" + SELECT {','.join(fields)} FROM users; + """, + ) + self.users = {raw_user["id"]: PartialUser(**raw_user) for raw_user in raw_users} + + def cache_user(self, user: PartialUser): + self.users[user.id] = user diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index a834e628..a6c74dec 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -147,9 +147,7 @@ async def create_role(guild_id, name: str, **kwargs): # we need to update the lazy guild handlers for the newly created group await maybe_lazy_guild_dispatch(guild_id, "new_role", role) - await app.dispatcher.guild.dispatch( - guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}) - ) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role})) return role_view(role) @@ -208,9 +206,7 @@ async def _subscribe_users_new_channel(guild_id: int, channel_id: int) -> None: await app.dispatcher.channel.sub(channel_id, session_id) -async def create_guild_channel( - guild_id: int, channel_id: int, ctype: ChannelType, **kwargs -): +async def create_guild_channel(guild_id: int, channel_id: int, ctype: ChannelType, **kwargs): """Create a channel in a guild.""" await app.db.execute( """ @@ -236,9 +232,7 @@ async def create_guild_channel( parent_id = kwargs.get("parent_id") or None - banner = await app.icons.put( - "channel_banner", channel_id, kwargs.get("banner"), always_icon=True - ) + banner = await app.icons.put("channel_banner", channel_id, kwargs.get("banner"), always_icon=True) # all channels go to guild_channels await app.db.execute( @@ -263,9 +257,7 @@ async def create_guild_channel( # This needs to be last, because it depends on users being already sub'd if "permission_overwrites" in kwargs: - await process_overwrites( - guild_id, channel_id, kwargs["permission_overwrites"] or [] - ) + await process_overwrites(guild_id, channel_id, kwargs["permission_overwrites"] or []) async def _del_from_table(table: str, user_id: int): @@ -353,9 +345,7 @@ async def create_guild_settings(guild_id: int, user_id: int): ) -async def add_member( - guild_id: int, user_id: int, *, basic: bool = False, skip_check: bool = False -): +async def add_member(guild_id: int, user_id: int, *, basic: bool = False, skip_check: bool = False): """Add a user to a guild. If `basic` is set to true, side-effects from member adding won't be @@ -377,20 +367,14 @@ async def add_member( features = await app.storage.guild_features(guild_id) or [] user = await app.storage.get_user(user_id, True) - - if ( - "INTERNAL_EMPLOYEE_ONLY" in features - and user["flags"] & UserFlags.staff != UserFlags.staff - ): + assert user is not None + if "INTERNAL_EMPLOYEE_ONLY" in features and user.flags & UserFlags.staff != UserFlags.staff: raise Forbidden(20017) if "INVITES_DISABLED" in features: raise Forbidden(40008) - if ( - nsfw_level in (NSFWLevel.RESTRICTED, NSFWLevel.EXPLICIT) - and not user["nsfw_allowed"] - ): + if nsfw_level in (NSFWLevel.RESTRICTED, NSFWLevel.EXPLICIT) and not user.nsfw_allowed: raise Forbidden(20024) await app.db.execute( @@ -425,15 +409,11 @@ async def add_member( guild_id, ) if system_channel_id: - await send_sys_message( - system_channel_id, MessageType.GUILD_MEMBER_JOIN, user_id - ) + await send_sys_message(system_channel_id, MessageType.GUILD_MEMBER_JOIN, user_id) # tell current members a new member came up member = await app.storage.get_member(guild_id, user_id) - await app.dispatcher.guild.dispatch( - guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}) - ) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}})) # pubsub changes for new member await app.lazy_guild.new_member(guild_id, user_id) diff --git a/litecord/common/interop.py b/litecord/common/interop.py index bacfed68..579bc970 100644 --- a/litecord/common/interop.py +++ b/litecord/common/interop.py @@ -23,6 +23,7 @@ else: from quart import current_app as app, request + def guild_view(guild_data: dict) -> dict: # Do all the below if applicable if request.discord_api_version < 8: @@ -48,17 +49,9 @@ def channel_view(channel_data: dict) -> dict: for overwrite in channel_data["permission_overwrites"]: overwrite["type"] = "role" if overwrite["type"] == 0 else "member" overwrite["allow_new"] = overwrite.get("allow", "0") - overwrite["allow"] = ( - (int(overwrite["allow"]) & ((2 << 31) - 1)) - if overwrite.get("allow") - else 0 - ) + overwrite["allow"] = (int(overwrite["allow"]) & ((2 << 31) - 1)) if overwrite.get("allow") else 0 overwrite["deny_new"] = overwrite.get("deny", "0") - overwrite["deny"] = ( - (int(overwrite["deny"]) & ((2 << 31) - 1)) - if overwrite.get("deny") - else 0 - ) + overwrite["deny"] = (int(overwrite["deny"]) & ((2 << 31) - 1)) if overwrite.get("deny") else 0 return channel_data diff --git a/litecord/common/messages.py b/litecord/common/messages.py index e1e91067..18903a80 100644 --- a/litecord/common/messages.py +++ b/litecord/common/messages.py @@ -75,19 +75,13 @@ async def msg_create_request() -> tuple: def msg_create_check_content(payload: dict, files: list): """Check if there is actually any content being sent to us.""" content = payload["content"] or "" - embeds = ( - (payload.get("embeds") or []) or [payload["embed"]] - if "embed" in payload and payload["embed"] - else [] - ) + embeds = (payload.get("embeds") or []) or [payload["embed"]] if "embed" in payload and payload["embed"] else [] sticker_ids = payload.get("sticker_ids") if not content and not embeds and not sticker_ids and not files: raise BadRequest(50006) -async def msg_add_attachment( - message_id: int, channel_id: int, author_id: Optional[int], attachment_file -) -> int: +async def msg_add_attachment(message_id: int, channel_id: int, author_id: Optional[int], attachment_file) -> int: """Add an attachment to a message. Parameters @@ -186,9 +180,7 @@ async def msg_add_attachment( return attachment_id -async def msg_guild_text_mentions( - payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool -): +async def msg_guild_text_mentions(payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool): """Calculates mention data side-effects.""" channel_id = int(payload["channel_id"]) diff --git a/litecord/common/users.py b/litecord/common/users.py index 0a595186..8c2600ad 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -24,6 +24,8 @@ from asyncpg import UniqueViolationError from logbook import Logger +from litecord.models import PartialUser, User + from ..auth import hash_data from ..errors import BadRequest, ManualFormError from ..presence import BasePresence @@ -38,7 +40,7 @@ log = Logger(__name__) -async def mass_user_update(user_id: int) -> Tuple[dict, dict]: +async def mass_user_update(user_id: int) -> Tuple[PartialUser, User]: """Dispatch a USER_UPDATE to the user itself Dispatches GUILD_MEMBER_UPDATE for others sharing guilds with the user Dispatches PRESENCE_UPDATE for friends outside of guilds @@ -50,8 +52,8 @@ async def mass_user_update(user_id: int) -> Tuple[dict, dict]: lists are they subscribed to. """ public_user = await app.storage.get_user(user_id) - private_user = await app.storage.get_user(user_id, secure=True) - + private_user = await app.storage.get_user(user_id, full=True) + assert public_user and private_user # The user who initiated the profile change should also get possible guild events await dispatch_user(user_id, ("USER_UPDATE", private_user)) @@ -68,7 +70,7 @@ async def mass_user_update(user_id: int) -> Tuple[dict, dict]: presence = app.presence.fetch_self_presence(user_id) # usually this presence should be partial, but there should be no major issue with a full one - await app.presence.dispatch_friends_pres(int(public_user["id"]), presence) + await app.presence.dispatch_friends_pres(int(public_user.id), presence) for guild_id in guild_ids: await app.lazy_guild.update_user(guild_id, user_id) diff --git a/litecord/embed/messages.py b/litecord/embed/messages.py index 7ae403e9..f9548c12 100644 --- a/litecord/embed/messages.py +++ b/litecord/embed/messages.py @@ -92,9 +92,7 @@ async def msg_update_embeds(payload, new_embeds): if "flags" in payload: update_payload["flags"] = payload["flags"] - await app.dispatcher.channel.dispatch( - channel_id, ("MESSAGE_UPDATE", update_payload) - ) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", update_payload)) def is_media_url(url) -> bool: diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py index 79416d95..4e92f7a5 100644 --- a/litecord/embed/sanitizer.py +++ b/litecord/embed/sanitizer.py @@ -109,9 +109,7 @@ def proxify(url) -> str: return make_md_req_url("img", url) -async def _md_client_req( - scope: str, url, *, ret_resp=False -) -> Optional[Union[Tuple, Dict, List[Dict]]]: +async def _md_client_req(scope: str, url, *, ret_resp=False) -> Optional[Union[Tuple, Dict, List[Dict]]]: """Makes a request to the mediaproxy. This has common code between all the main mediaproxy request functions diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index f031a23f..ef5dbcd6 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -150,10 +150,7 @@ async def dispatch(self, event_type: str, event_data: Any) -> None: if event_type in ("MESSAGE_CREATE", "MESSAGE_UPDATE"): data.pop("reactions", None) data["referenced_message"] = data.get("referenced_message") or None - if ( - data.get("type") in (19, 20, 23) - and self.ws.ws_properties.version < 8 - ): + if data.get("type") in (19, 20, 23) and self.ws.ws_properties.version < 8: data["type"] = 0 if not content_allowed(str(self.user_id), self.intents, data): @@ -166,9 +163,7 @@ async def dispatch(self, event_type: str, event_data: Any) -> None: if data["referenced_message"] and not content_allowed( str(self.user_id), self.intents, data["referenced_message"] ): - data["referenced_message"].update( - {"content": "", "embeds": [], "attachments": []} - ) + data["referenced_message"].update({"content": "", "embeds": [], "attachments": []}) elif ( event_type.startswith("GUILD_ROLE_") @@ -179,36 +174,28 @@ async def dispatch(self, event_type: str, event_data: Any) -> None: data["permissions_new"] = data["permissions"] data["permissions"] = int(data["permissions"]) & ((2 << 31) - 1) - elif ( - event_type.startswith("CHANNEL_") - ): + elif event_type.startswith("CHANNEL_"): if data.get("type") == 3: - idx = index_by_func(lambda user: user["id"] == str(self.user_id), data["recipients"]) + idx = index_by_func( + lambda user: user["id"] == str(self.user_id), + data["recipients"], + ) if idx is not None: data["recipients"].pop(idx) if data.get("permission_overwrites") and self.ws.ws_properties.version < 8: for overwrite in data["permission_overwrites"]: - overwrite["type"] = ( - "role" if overwrite["type"] == 0 else "member" - ) + overwrite["type"] = "role" if overwrite["type"] == 0 else "member" overwrite["allow_new"] = overwrite.get("allow", "0") overwrite["allow"] = ( - (int(overwrite["allow"]) & ((2 << 31) - 1)) - if overwrite.get("allow") - else 0 + (int(overwrite["allow"]) & ((2 << 31) - 1)) if overwrite.get("allow") else 0 ) overwrite["deny_new"] = overwrite.get("deny", "0") overwrite["deny"] = ( - (int(overwrite["deny"]) & ((2 << 31) - 1)) - if overwrite.get("deny") - else 0 + (int(overwrite["deny"]) & ((2 << 31) - 1)) if overwrite.get("deny") else 0 ) - elif ( - event_type in ("GUILD_CREATE", "GUILD_UPDATE") - and self.ws.ws_properties.version < 8 - ): + elif event_type in ("GUILD_CREATE", "GUILD_UPDATE") and self.ws.ws_properties.version < 8: for role in data.get("roles", []): role["permissions_new"] = role["permissions"] role["permissions"] = int(role["permissions"]) & ((2 << 31) - 1) @@ -218,13 +205,9 @@ async def dispatch(self, event_type: str, event_data: Any) -> None: "id": overwrite["id"], "type": "role" if overwrite["type"] == 0 else "member", "allow_new": overwrite.get("allow", "0"), - "allow": (int(overwrite["allow"]) & ((2 << 31) - 1)) - if overwrite.get("allow") - else 0, + "allow": (int(overwrite["allow"]) & ((2 << 31) - 1)) if overwrite.get("allow") else 0, "deny_new": overwrite.get("deny", "0"), - "deny": (int(overwrite["deny"]) & ((2 << 31) - 1)) - if overwrite.get("deny") - else 0, + "deny": (int(overwrite["deny"]) & ((2 << 31) - 1)) if overwrite.get("deny") else 0, } for overwrite in channel["permission_overwrites"] ] diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index e4e51039..62c80db0 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -233,9 +233,7 @@ def close(self): """Close the state manager.""" self.closed = True - async def fetch_user_states_for_channel( - self, channel_id: int, user_id: int - ) -> List[GatewayState]: + async def fetch_user_states_for_channel(self, channel_id: int, user_id: int) -> List[GatewayState]: """Get a list of gateway states for a user that can receive events on a certain channel.""" # TODO optimize this with an in-memory store guild_id = await app.storage.guild_from_channel(channel_id) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 646d94d7..f399c9aa 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -76,9 +76,7 @@ log = Logger(__name__) -WebsocketProperties = collections.namedtuple( - "WebsocketProperties", "version encoding compress zctx zsctx tasks" -) +WebsocketProperties = collections.namedtuple("WebsocketProperties", "version encoding compress zctx zsctx tasks") def _complete_users_list(user_id: str, base_ready, user_ready, ws_properties) -> dict: @@ -104,17 +102,13 @@ def _complete_users_list(user_id: str, base_ready, user_ready, ws_properties) -> for private_channel in ready["private_channels"]: if private_channel["type"] == 1: - self_user_index = index_by_func( - lambda user: user["id"] == str(user_id), private_channel["recipients"] - ) + self_user_index = index_by_func(lambda user: user["id"] == str(user_id), private_channel["recipients"]) if ws_properties.version > 7: assert self_user_index is not None private_channel["recipients"].pop(self_user_index) else: if self_user_index == 0: - private_channel["recipients"].append( - private_channel["recipients"].pop(0) - ) + private_channel["recipients"].append(private_channel["recipients"].pop(0)) # if ws_properties.version >= 9: # private_channel["recipient_ids"] = [recipient["id"] for recipient in private_channel["recipients"]], @@ -130,7 +124,9 @@ async def _compute_supplemental(app, base_ready, user_ready, users_to_send: dict "lazy_private_channels": [], } - supplemental["merged_presences"]["friends"] = [{**presence, "last_modified": 0} for presence in user_ready["presences"]] + supplemental["merged_presences"]["friends"] = [ + {**presence, "last_modified": 0} for presence in user_ready["presences"] + ] for guild in base_ready["guilds"]: if not guild.get("unavailable"): @@ -207,9 +203,7 @@ def _set_encoders(self): async def _chunked_send(self, data: bytes, chunk_size: int): """Split data in chunk_size-big chunks and send them over the websocket.""" - log.debug( - "zlib-stream: sending {} bytes into {}-byte chunks", len(data), chunk_size - ) + log.debug("zlib-stream: sending {} bytes into {}-byte chunks", len(data), chunk_size) # we send the entire iterator as per websockets documentation # to pretent setting FIN when we don't want to @@ -238,9 +232,7 @@ async def _zlib_stream_send(self, encoded): await self._chunked_send(data, 1024) async def _zstd_stream_send(self, encoded): - compressor = self.ws_properties.zsctx.stream_writer( - WebsocketFileHandler(self.ws) - ) + compressor = self.ws_properties.zsctx.stream_writer(WebsocketFileHandler(self.ws)) compressor.write(encoded) compressor.flush(zstd.FLUSH_FRAME) @@ -273,21 +265,12 @@ async def send(self, payload: Dict[str, Any]): await self._zlib_stream_send(want_bytes(encoded)) elif self.ws_properties.compress == "zstd-stream": await self._zstd_stream_send(want_bytes(encoded)) - elif ( - self.state - and self.state.compress - and len(encoded) > 8192 - and self.ws_properties.encoding != "etf" - ): + elif self.state and self.state.compress and len(encoded) > 8192 and self.ws_properties.encoding != "etf": # TODO determine better conditions to trigger a compress set # by identify await self.ws.send(zlib.compress(want_bytes(encoded))) else: - await self.ws.send( - want_bytes(encoded) - if self.ws_properties.encoding == "etf" - else want_string(encoded) - ) + await self.ws.send(want_bytes(encoded) if self.ws_properties.encoding == "etf" else want_string(encoded)) async def send_op(self, op_code: int, data: Any): """Send a packet but just the OP code information is filled in.""" @@ -314,16 +297,12 @@ def _hb_start(self, interval: int): if task: task.cancel() - self.ws_properties.tasks["heartbeat"] = app.sched.spawn( - task_wrapper("hb wait", self._hb_wait(interval)) - ) + self.ws_properties.tasks["heartbeat"] = app.sched.spawn(task_wrapper("hb wait", self._hb_wait(interval))) async def _send_hello(self): """Send the OP 10 Hello packet over the websocket.""" # random heartbeat intervals - await self.send_op( - OP.HELLO, {"heartbeat_interval": 41250, "_trace": ["litecord"]} - ) + await self.send_op(OP.HELLO, {"heartbeat_interval": 41250, "_trace": ["litecord"]}) self._hb_start(41250) @@ -346,9 +325,7 @@ async def dispatch_raw(self, event: str, data: Any): try: await self.send(payload) except websockets.exceptions.ConnectionClosed: - log.warning( - "Failed to dispatch {!r} to {}", event.upper, self.state.session_id - ) + log.warning("Failed to dispatch {!r} to {}", event.upper, self.state.session_id) async def _make_guild_list(self) -> List[Dict[str, Any]]: assert self.state is not None @@ -357,9 +334,7 @@ async def _make_guild_list(self) -> List[Dict[str, Any]]: if self.state.bot: return [{"id": row, "unavailable": True} for row in guild_ids] - return await self.storage.get_guilds( - guild_ids, self.state.user_id, True, large=self.state.large - ) + return await self.storage.get_guilds(guild_ids, self.state.user_id, True, large=self.state.large) async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]): """Dispatch GUILD_CREATE information.""" @@ -370,9 +345,7 @@ async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]): return guild_ids = [int(g["id"]) for g in unavailable_guilds] - guilds = await self.storage.get_guilds( - guild_ids, self.state.user_id, True, large=self.state.large - ) + guilds = await self.storage.get_guilds(guild_ids, self.state.user_id, True, large=self.state.large) for guild in guilds: await self.dispatch_raw("GUILD_CREATE", {**guild, "unavailable": False}) @@ -382,11 +355,7 @@ async def _user_ready(self, *, settings=None) -> dict: assert self.state is not None user_id = self.state.user_id relationships = await self.user_storage.get_relationships(user_id) - friend_users = [ - r["user"] - for r in relationships - if r["type"] == RelationshipType.FRIEND.value - ] + friend_users = [r["user"] for r in relationships if r["type"] == RelationshipType.FRIEND.value] friend_presences = await self.app.presence.friend_presences(friend_users) settings = settings or await self.user_storage.get_user_settings(user_id) @@ -440,15 +409,13 @@ async def dispatch_ready(self, **kwargs): user_id = self.state.user_id user = await self.storage.get_user(user_id, True) - + assert user user_ready = {} if not self.state.bot: # user, fetch info user_ready = await self._user_ready(**kwargs) - private_channels = await self.user_storage.get_dms( - user_id - ) + await self.user_storage.get_gdms(user_id) + private_channels = await self.user_storage.get_dms(user_id) + await self.user_storage.get_gdms(user_id) base_ready = { "v": self.ws_properties.version, @@ -472,14 +439,13 @@ async def dispatch_ready(self, **kwargs): # pass users_to_send to ready_supplemental so that its easier to # cross-reference things - full_ready_data, users_to_send = _complete_users_list( - user["id"], base_ready, user_ready, self.ws_properties - ) - ready_supplemental = await _compute_supplemental( - self.app, base_ready, user_ready, users_to_send - ) + full_ready_data, users_to_send = _complete_users_list(str(user.id), base_ready, user_ready, self.ws_properties) + ready_supplemental = await _compute_supplemental(self.app, base_ready, user_ready, users_to_send) - full_ready_data["merged_members"] = [[member for member in members if member["user"]["id"] == user["id"]] for members in ready_supplemental["merged_members"]] + full_ready_data["merged_members"] = [ + [member for member in members if member["user"]["id"] == user.id] + for members in ready_supplemental["merged_members"] + ] if self.ws_properties.version < 6: # Extremely old client compat for guild in full_ready_data["guilds"]: @@ -608,14 +574,10 @@ async def update_presence( if not self.state: return - if not override_ratelimit and self._check_ratelimit( - "presence", self.state.session_id - ): + if not override_ratelimit and self._check_ratelimit("presence", self.state.session_id): return - settings = settings or await self.user_storage.get_user_settings( - self.state.user_id - ) + settings = settings or await self.user_storage.get_user_settings(self.state.user_id) presence = BasePresence(status=(settings["status"] or "online"), game=None) @@ -997,9 +959,7 @@ async def handle_6(self, payload: Dict[str, Any]): await self._resume(range(seq, state.seq)) - async def _req_guild_members( - self, guild_id, user_ids: List[int], query: str, limit: int, presences: bool - ): + async def _req_guild_members(self, guild_id, user_ids: List[int], query: str, limit: int, presences: bool): try: guild_id = int(guild_id) except (TypeError, ValueError): @@ -1020,9 +980,7 @@ async def _req_guild_members( # ASSUMPTION: requesting user_ids means we don't do query. if user_ids: - log.debug( - "req guild members: getting {} users in gid {}", len(user_ids), guild_id - ) + log.debug("req guild members: getting {} users in gid {}", len(user_ids), guild_id) members = await self.storage.get_member_multi(guild_id, user_ids) mids = [int(m["user"]["id"]) for m in members] @@ -1228,9 +1186,7 @@ async def handle_14(self, payload: Dict[str, Any]): member_list = await app.lazy_guild.get_gml(chan_id) - perms = await get_permissions( - self.state.user_id, chan_id, storage=self.storage - ) + perms = await get_permissions(self.state.user_id, chan_id, storage=self.storage) if not perms.bits.read_messages: # ignore requests to unknown channels @@ -1325,9 +1281,7 @@ async def _check_conns(self, user_id): # TODO why is this inneficient? states = self.app.state_manager.user_states(user_id) if not any(s.ws for s in states): - await self.app.presence.dispatch_pres( - user_id, BasePresence(status="offline") - ) + await self.app.presence.dispatch_pres(user_id, BasePresence(status="offline")) async def run(self): """Wrap :meth:`listen_messages` inside diff --git a/litecord/images.py b/litecord/images.py index 09a2c3bc..754f4483 100644 --- a/litecord/images.py +++ b/litecord/images.py @@ -85,9 +85,7 @@ def as_path(self) -> Optional[str]: return None ext = get_ext(self.mime) - return str( - IMAGE_FOLDER / f"{self.fs_hash if self.icon_hash else self.key}.{ext}" - ) + return str(IMAGE_FOLDER / f"{self.fs_hash if self.icon_hash else self.key}.{ext}") @property def as_pathlib(self) -> Optional[Path]: @@ -180,9 +178,7 @@ def parse_data_uri(string) -> tuple: given_mime = "image/png" elif raw_data[0:3] == b"\xff\xd8\xff" or raw_data[6:10] in (b"JFIF", b"Exif"): given_mime = "image/jpeg" - elif raw_data.startswith( - (b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61") - ): + elif raw_data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")): given_mime = "image/gif" elif raw_data.startswith(b"RIFF") and raw_data[8:12] == b"WEBP": given_mime = "image/webp" @@ -318,9 +314,7 @@ async def _convert_ext(self, icon: Icon, target: str): target_mime = get_mime(target) log.info("converting from {} to {}", icon.mime, target_mime) - target_path = ( - IMAGE_FOLDER / f"{icon.fs_hash if icon.fs_hash else icon.key}.{target}" - ) + target_path = IMAGE_FOLDER / f"{icon.fs_hash if icon.fs_hash else icon.key}.{target}" if target_path.exists(): return Icon(icon.key, icon.icon_hash, target_mime) @@ -449,7 +443,7 @@ async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon: """ DELETE FROM icons WHERE key=$1 """, - str(key) + str(key), ) await self.storage.db.execute( """ @@ -463,10 +457,7 @@ async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon: ) # write it off to fs - icon_path = ( - IMAGE_FOLDER - / f"{icon_hash.split('.')[-1] if icon_hash else key}.{extension}" - ) + icon_path = IMAGE_FOLDER / f"{icon_hash.split('.')[-1] if icon_hash else key}.{extension}" if not icon_path.exists(): icon_path.write_bytes(raw_data) @@ -532,9 +523,7 @@ async def update(self, scope: str, key: str, new_icon_data: str, **kwargs) -> Ic """, old_icon.fs_hash, ) - if ( - hits and len(hits) <= 1 - ): # if we have more than one hit, we can't delete it + if hits and len(hits) <= 1: # if we have more than one hit, we can't delete it await self.delete(old_icon) return icon diff --git a/litecord/json.py b/litecord/json.py index 3170b0fc..81a98441 100644 --- a/litecord/json.py +++ b/litecord/json.py @@ -33,12 +33,12 @@ def default(self, value: Any): if isinstance(value, (Decimal, UUID)): return str(value) + if hasattr(value, "to_json"): + return value.to_json() + if is_dataclass(value): return asdict(value) - if hasattr(value, "to_json"): - return value.to_json - return super().default(value) diff --git a/litecord/models.py b/litecord/models.py new file mode 100644 index 00000000..1293af22 --- /dev/null +++ b/litecord/models.py @@ -0,0 +1,75 @@ +from datetime import datetime, date +from typing import Optional +from dataclasses import dataclass, asdict, field + +from litecord.enums import PremiumType + + +@dataclass +class PartialUser: + id: int + username: str + discriminator: str + avatar: Optional[str] + avatar_decoration: Optional[str] + flags: int + bot: bool + system: bool + + def to_json(self): + json = asdict(self) + json["id"] = str(self.id) + json["public_flags"] = json.pop("flags") + return json + + +@dataclass +class User(PartialUser): + banner: Optional[str] + bio: str + accent_color: Optional[str] + pronouns: str + theme_colors: Optional[str] + premium_since: Optional[datetime] + premium_type: Optional[int] + email: Optional[str] + verified: bool + mfa_enabled: bool + date_of_birth: Optional[date] + phone: Optional[str] + + def to_json(self, secure=True): + json = super().to_json() + json["flags"] = json["public_flags"] + json["premium"] = json.pop("premium_since") is not None + json["banner_color"] = hex(json["accent_color"]).replace("0x", "#") if json["accent_color"] else None + + # dob is never to be sent, its only used for nsfw_allowed + dob = json.pop("date_of_birth") + + if secure: + json["desktop"] = json["mobile"] = False + json["phone"] = json["phone"] if json["phone"] else None + + today = date.today() + + json["nsfw_allowed"] = ( + ((today.year - dob.year - ((today.month, today.day) < (dob.month, dob.day))) >= 18) if dob else True + ) + + else: + for field in ("email", "verified", "mfa_enabled", "phone", "premium_type"): + json.pop(field) + return json + + @property + def nsfw_allowed(self): + if not self.date_of_birth: + return True + today = date.today() + + return ( + today.year + - self.date_of_birth.year + - ((today.month, today.day) < (self.date_of_birth.month, self.date_of_birth.day)) + ) >= 18 diff --git a/litecord/permissions.py b/litecord/permissions.py index faf31c8c..41737fb4 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -210,9 +210,7 @@ def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions: return Permissions(result) -def overwrite_find_mix( - perms: Permissions, overwrites: dict, target_id: int -) -> Permissions: +def overwrite_find_mix(perms: Permissions, overwrites: dict, target_id: int) -> Permissions: """Mix a given permission with a given overwrite. Returns the given permission if an overwrite is not found. @@ -240,9 +238,7 @@ def overwrite_find_mix( return perms -async def role_permissions( - guild_id: int, role_id: int, channel_id: int, storage=None -) -> Permissions: +async def role_permissions(guild_id: int, role_id: int, channel_id: int, storage=None) -> Permissions: """Get the permissions for a role, in relation to a channel""" if not storage: storage = app.storage @@ -332,6 +328,4 @@ async def get_permissions(member_id: int, channel_id, *, storage=None) -> Permis base_perms = await base_permissions(member_id, guild_id, storage) - return await compute_overwrites( - base_perms, member_id, channel_id, guild_id, storage - ) + return await compute_overwrites(base_perms, member_id, channel_id, guild_id, storage) diff --git a/litecord/presence.py b/litecord/presence.py index 47a10341..2f3bb092 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from litecord.typing_hax import app + from litecord.storage import Storage else: from quart import current_app as app @@ -120,13 +121,11 @@ class PresenceManager: """ def __init__(self, app): - self.storage = app.storage + self.storage: Storage = app.storage self.user_storage = app.user_storage self.state_manager = app.state_manager - async def guild_presences( - self, members: dict, guild_id: int - ) -> List[Dict[Any, str]]: + async def guild_presences(self, members: dict, guild_id: int) -> List[Dict[Any, str]]: """Fetch all presences in a guild.""" # this works via fetching all connected GatewayState on a guild # then fetching its respective member and merging that info with @@ -149,9 +148,7 @@ async def guild_presences( return presences - async def dispatch_guild_pres( - self, guild_id: int, user_id: int, presence: BasePresence - ): + async def dispatch_guild_pres(self, guild_id: int, user_id: int, presence: BasePresence): """Dispatch a Presence update to an entire guild.""" member = await self.storage.get_member(guild_id, user_id) @@ -209,9 +206,7 @@ def _session_check(session_id): # everyone not in lazy guild mode # gets a PRESENCE_UPDATE - await app.dispatcher.guild.dispatch_filter( - guild_id, _session_check, ("PRESENCE_UPDATE", event_payload) - ) + await app.dispatcher.guild.dispatch_filter(guild_id, _session_check, ("PRESENCE_UPDATE", event_payload)) return in_lazy @@ -226,9 +221,7 @@ async def dispatch_friends_pres(self, user_id: int, presence: BasePresence) -> N ("PRESENCE_UPDATE", {**presence.partial_dict, **{"user": user}}), ) - async def dispatch_friends_pres_filter( - self, user: dict, filter_function, presence: BasePresence - ): + async def dispatch_friends_pres_filter(self, user: dict, filter_function, presence: BasePresence): """ Same as dispatch_friends_pres but passes a filter function Takes in a whole public user object instead of a user id @@ -295,6 +288,6 @@ async def friend_presences(self, friends: Iterable[dict]) -> List[Presence]: res = [] for user in friends: presence = self.fetch_friend_presence(int(user["id"])) - res.append({**presence.partial_dict, 'user': user}) + res.append({**presence.partial_dict, "user": user}) return res diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index a307b30b..c4bf9f25 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -29,7 +29,7 @@ from litecord.typing_hax import app else: from quart import current_app as app - + log = Logger(__name__) @@ -69,8 +69,6 @@ async def _dispatch(session_id: str) -> None: await asyncio.gather(*(_dispatch(sid) for sid in session_ids)) - log.info( - "Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions) - ) + log.info("Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions)) return sessions diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index f68c7292..aa94ad3c 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -82,9 +82,7 @@ async def clear(self, key: K) -> None: """Clear a key from the backend.""" ... - async def dispatch_filter( - self, key: K, filter_function: Callable[[K], bool], event: EventType - ) -> List[str]: + async def dispatch_filter(self, key: K, filter_function: Callable[[K], bool], event: EventType) -> List[str]: """Selectively dispatch to the list of subscribers. Function must return a list of separate identifiers for composability. diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 7042893f..71fe1dc4 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -56,9 +56,7 @@ def can_dispatch(event_type, event_data, state) -> bool: class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): """Guild backend for Pub/Sub.""" - async def sub_user( - self, guild_id: int, user_id: int - ) -> Tuple[List[GatewayState], List[int]]: + async def sub_user(self, guild_id: int, user_id: int) -> Tuple[List[GatewayState], List[int]]: states = app.state_manager.fetch_states(user_id, guild_id) asyncio.gather(*(self.sub(guild_id, state.session_id) for state in states)) @@ -68,25 +66,23 @@ async def sub_user( guild_chan_ids = await app.storage.get_channel_ids(guild_id) channel_ids = [] + async def sub_channel(channel_id): perms = await get_permissions(user_id, channel_id) if perms.bits.read_messages: channel_ids.append(channel_id) + await asyncio.gather(*(sub_channel(chan_id) for chan_id in guild_chan_ids)) return states, channel_ids - async def unsub_user( - self, guild_id: int, user_id: int - ) -> Tuple[List[GatewayState], List[int]]: + async def unsub_user(self, guild_id: int, user_id: int) -> Tuple[List[GatewayState], List[int]]: states = app.state_manager.fetch_states(user_id, guild_id) asyncio.gather(*(self.unsub(guild_id, state.session_id) for state in states)) guild_chan_ids = await app.storage.get_channel_ids(guild_id) return states, guild_chan_ids - async def dispatch_filter( - self, guild_id: int, filter_function, event: GatewayEvent - ): + async def dispatch_filter(self, guild_id: int, filter_function, event: GatewayEvent): session_ids = self.state[guild_id] sessions: List[str] = [] event_type, event_data = event diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 8cbb742d..1c4d8fd6 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -28,7 +28,18 @@ import asyncio from collections import defaultdict -from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set, TYPE_CHECKING +from typing import ( + Any, + List, + Dict, + Union, + Optional, + Iterable, + Iterator, + Tuple, + Set, + TYPE_CHECKING, +) from dataclasses import dataclass, asdict, field from logbook import Logger @@ -119,9 +130,7 @@ def __bool__(self): list_dict = asdict(self) # ignore the bool status of overwrites - return all( - bool(list_dict[k]) for k in ("groups", "data", "presences", "members") - ) + return all(bool(list_dict[k]) for k in ("groups", "data", "presences", "members")) def __iter__(self): """Iterate over all groups in the correct order. @@ -224,9 +233,7 @@ async def everyone_allow(gml) -> bool: If the role can't access the list, then the list keeps its list ID. """ - everyone_perms = await role_permissions( - gml.guild_id, gml.guild_id, gml.channel_id, storage=gml.storage - ) + everyone_perms = await role_permissions(gml.guild_id, gml.guild_id, gml.channel_id, storage=gml.storage) return bool(everyone_perms.bits.read_messages) @@ -370,9 +377,7 @@ def _can_read_chan(self, group: GroupInfo) -> bool: # then the final perms for that role if # any overwrite exists in the channel - final_perms = overwrite_find_mix( - role_perms, self.list.overwrites, int(group.gid) - ) + final_perms = overwrite_find_mix(role_perms, self.list.overwrites, int(group.gid)) # update the group's permissions # with the mixed ones @@ -406,9 +411,7 @@ async def _get_role_groups(self) -> List[GroupInfo]: ) hoisted = [ - GroupInfo( - row["id"], row["name"], row["position"], Permissions(row["permissions"]) - ) + GroupInfo(row["id"], row["name"], row["position"], Permissions(row["permissions"])) for row in roledata if row["hoist"] ] @@ -441,20 +444,14 @@ async def _get_group_for_member( # get the member's permissions relative to the channel # (accounting for channel overwrites) - member_perms = await get_permissions( - member_id, self.channel_id, storage=self.storage - ) + member_perms = await get_permissions(member_id, self.channel_id, storage=self.storage) if not member_perms.bits.read_messages: return None # if the member is offline, we # default give them the offline group. - group_id = ( - "offline" - if status == "offline" - else self._calc_member_group(member_roles, status) - ) + group_id = "offline" if status == "offline" else self._calc_member_group(member_roles, status) return group_id @@ -464,9 +461,7 @@ async def _list_fill_groups(self, members: List[dict]): member_id = int(member["user"]["id"]) presence = self.list.presences[member_id] - group_id = await self._get_group_for_member( - member_id, presence["roles"], presence["status"] - ) + group_id = await self._get_group_for_member(member_id, presence["roles"], presence["status"]) # skip members that don't have any group assigned. # (members without read messages) @@ -604,9 +599,7 @@ def _get_state(self, session_id: str) -> Optional[GatewayState]: except KeyError: return None - async def _dispatch_sess( - self, session_ids: Iterable[str], operations: List[Operation] - ): + async def _dispatch_sess(self, session_ids: Iterable[str], operations: List[Operation]): """Dispatch a GUILD_MEMBER_LIST_UPDATE to the given session ids.""" @@ -621,9 +614,7 @@ async def _dispatch_sess( payload = { "id": self.list_id, "guild_id": str(self.guild_id), - "groups": [ - {"id": str(group.gid), "count": count} for group, count in groups - ], + "groups": [{"id": str(group.gid), "count": count} for group, count in groups], "ops": [operation.to_dict for operation in operations], "member_count": member_count, "online_count": member_count - offline_count, @@ -661,11 +652,7 @@ async def _resync(self, session_ids: List[str], item_index: int) -> List[str]: try: # get the only range where the group is in - role_range = next( - (r_min, r_max) - for r_min, r_max in ranges - if r_min <= item_index <= r_max - ) + role_range = next((r_min, r_max) for r_min, r_max in ranges if r_min <= item_index <= r_max) except StopIteration: log.debug( "ignoring sess_id={}, no range for item {}, {}", @@ -712,9 +699,7 @@ async def shard_query(self, session_id: str, ranges: list): # if everyone can read the channel, # we direct the request to the 'everyone' gml instance # instead of the current one. - everyone_perms = await role_permissions( - self.guild_id, self.guild_id, self.channel_id, storage=self.storage - ) + everyone_perms = await role_permissions(self.guild_id, self.guild_id, self.channel_id, storage=self.storage) if everyone_perms.bits.read_messages and list_id != "everyone": everyone_gml = await app.lazy_guild.get_gml(self.guild_id) @@ -733,11 +718,7 @@ async def shard_query(self, session_id: str, ranges: list): self.state[session_id].add((start, end)) - ops.append( - Operation( - "SYNC", {"range": [start, end], "items": self.items[start:end]} - ) - ) + ops.append(Operation("SYNC", {"range": [start, end], "items": self.items[start:end]})) # send SYNCs to the state that requested await self._dispatch_sess([session_id], ops) @@ -788,9 +769,7 @@ def _is_subbed(self, item_index, session_id: str) -> bool: def _get_subs(self, item_index: int) -> Iterable[str]: """Get the list of subscribed states to a given item.""" - return filter( - lambda sess_id: self._is_subbed(item_index, sess_id), self.state.keys() - ) + return filter(lambda sess_id: self._is_subbed(item_index, sess_id), self.state.keys()) async def _pres_update_simple(self, user_id: int): """Handler for simple presence updates. @@ -810,13 +789,9 @@ async def _pres_update_simple(self, user_id: int): # simple update means we just give an UPDATE # operation - return await self._dispatch_sess( - session_ids, [Operation("UPDATE", {"index": item_index, "item": item})] - ) + return await self._dispatch_sess(session_ids, [Operation("UPDATE", {"index": item_index, "item": item})]) - async def _pres_update_complex( - self, user_id: int, old_group: GroupID, rel_index: int, new_group: GroupID - ): + async def _pres_update_complex(self, user_id: int, old_group: GroupID, rel_index: int, new_group: GroupID): """Move a member between groups. Parameters @@ -905,13 +880,9 @@ async def _pres_update_complex( # ) # merge both results together - return await self._resync(session_ids_old, old_user_index) + await self._resync( - session_ids_new, new_user_index - ) + return await self._resync(session_ids_old, old_user_index) + await self._resync(session_ids_new, new_user_index) - async def _pres_update_remove( - self, user_id: int, old_group: GroupID, old_index: int - ): + async def _pres_update_remove(self, user_id: int, old_group: GroupID, old_index: int): log.debug( "removal update: uid={} old={} rel_idx={} new={}", user_id, @@ -946,9 +917,7 @@ async def new_member(self, user_id: int): self.list.members[user_id] = member # find a group for the newcomer - group_id = await self._get_group_for_member( - user_id, member["roles"], pres["status"] - ) + group_id = await self._get_group_for_member(user_id, member["roles"], pres["status"]) if group_id is None: log.warning("lazy: not adding uid {}, no group", user_id) @@ -1012,9 +981,7 @@ async def remove_member(self, user_id: int): log.warning("lazy: unknown member uid {}", user_id) return - group_id = await self._get_group_for_member( - user_id, member["roles"], pres["status"] - ) + group_id = await self._get_group_for_member(user_id, member["roles"], pres["status"]) if not group_id: log.warning("lazy: unknown group uid {}", user_id) @@ -1139,9 +1106,7 @@ async def pres_update(self, user_id: int, partial_presence: Presence): # channel. return await self._pres_update_remove(user_id, old_group, old_index) else: - return await self._pres_update_complex( - user_id, old_group, old_index, new_group - ) + return await self._pres_update_complex(user_id, old_group, old_index, new_group) async def new_role(self, role: dict): """Add a new role to the list. @@ -1154,9 +1119,7 @@ async def new_role(self, role: dict): group_id = int(role["id"]) - new_group = GroupInfo( - group_id, role["name"], role["position"], Permissions(role["permissions"]) - ) + new_group = GroupInfo(group_id, role["name"], role["position"], Permissions(role["permissions"])) # check if new role has good perms await self._fetch_overwrites() @@ -1247,14 +1210,10 @@ async def role_pos_update(self, role: dict): # TODO: maybe this can be more efficient? # we could self.list.groups.insert... but I don't know. # I'm taking the safe route right now by using sorted() - new_groups = sorted( - self.list.groups, key=lambda group: group.position, reverse=True - ) + new_groups = sorted(self.list.groups, key=lambda group: group.position, reverse=True) log.debug( - "resorted groups from role pos upd " - "rid={} rpos={} (gid={}, cid={}) " - "res={}", + "resorted groups from role pos upd " "rid={} rpos={} (gid={}, cid={}) " "res={}", role_id, group.position, self.guild_id, @@ -1265,9 +1224,7 @@ async def role_pos_update(self, role: dict): self.list.groups = new_groups new_index = self._get_group_item_index(role_id) - return await self._resync(old_sessions, old_index) + await self._resync_by_item( - new_index - ) + return await self._resync(old_sessions, old_index) + await self._resync_by_item(new_index) async def role_update(self, role: dict): """Update a role. @@ -1318,9 +1275,7 @@ async def role_update(self, role: dict): return await self.role_delete(role_id) if not role["hoist"]: - log.debug( - "role_update promote to role_delete " "call rid={} (no hoist)", role_id - ) + log.debug("role_update promote to role_delete " "call rid={} (no hoist)", role_id) return await self.role_delete(role_id) async def role_delete(self, role_id: int, deleted: bool = False): @@ -1345,14 +1300,10 @@ async def role_delete(self, role_id: int, deleted: bool = False): # using a filter object would cause problems # as we only resync AFTER we delete the group - sess_ids_resync = ( - list(self._get_subs(role_item_index)) if role_item_index is not None else [] - ) + sess_ids_resync = list(self._get_subs(role_item_index)) if role_item_index is not None else [] # remove the group info off the list - groups_index = index_by_func( - lambda group: group.gid == role_id, self.list.groups - ) + groups_index = index_by_func(lambda group: group.gid == role_id, self.list.groups) if groups_index is not None: del self.list.groups[groups_index] diff --git a/litecord/pubsub/member.py b/litecord/pubsub/member.py index d477fc30..e8b0fcad 100644 --- a/litecord/pubsub/member.py +++ b/litecord/pubsub/member.py @@ -26,9 +26,8 @@ else: from quart import current_app as app, request -async def dispatch_member( - guild_id: int, user_id: int, event: GatewayEvent -) -> List[str]: + +async def dispatch_member(guild_id: int, user_id: int, event: GatewayEvent) -> List[str]: states = app.state_manager.fetch_states(user_id, guild_id) # if no states were found, we should unsub the user from the guild diff --git a/litecord/pubsub/user.py b/litecord/pubsub/user.py index 5c0a25d5..c677f4a4 100644 --- a/litecord/pubsub/user.py +++ b/litecord/pubsub/user.py @@ -27,6 +27,7 @@ else: from quart import current_app as app + async def dispatch_user_filter( user_id: int, filter_func: Optional[Callable[[str], bool]], event_data: GatewayEvent ) -> List[str]: diff --git a/litecord/pubsub/utils.py b/litecord/pubsub/utils.py index 614553df..f1aadf19 100644 --- a/litecord/pubsub/utils.py +++ b/litecord/pubsub/utils.py @@ -24,9 +24,7 @@ log = logging.getLogger(__name__) -async def send_event_to_states( - states: List[GatewayState], event_data: Tuple[str, Any] -) -> List[str]: +async def send_event_to_states(states: List[GatewayState], event_data: Tuple[str, Any]) -> List[str]: """Dispatch an event to a list of states.""" res = [] diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py index 1860e5bb..93a213cc 100644 --- a/litecord/ratelimits/handler.py +++ b/litecord/ratelimits/handler.py @@ -27,6 +27,7 @@ else: from quart import current_app as app, request + async def _check_bucket(bucket): retry_after = bucket.update_rate_limit() @@ -35,9 +36,7 @@ async def _check_bucket(bucket): if retry_after: request.retry_after = retry_after - raise Ratelimited( - **{"retry_after": int(retry_after * 1000), "global": request.bucket_global} - ) + raise Ratelimited(**{"retry_after": int(retry_after * 1000), "global": request.bucket_global}) async def _handle_global(ratelimit): diff --git a/litecord/schemas.py b/litecord/schemas.py index 69caac98..2da31c24 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -181,9 +181,7 @@ def _validate_type_rgb_str_color(self, value: str) -> bool: else: return True - def _validate_type_recipients( - self, value: Union[List[Union[int, str]], Union[int, str]] - ): + def _validate_type_recipients(self, value: Union[List[Union[int, str]], Union[int, str]]): return ( all(self._validate_type_snowflake(v) for v in value) if isinstance(value, list) @@ -289,9 +287,7 @@ def _format_message(self, field, error): info = self.messages.get(error.code, self.messages[0x00]) return { "code": info["code"].format(constraint=error.constraint).upper(), - "message": info["message"].format( - *error.info, constraint=error.constraint, field=field, value=error.value - ), + "message": info["message"].format(*error.info, constraint=error.constraint, field=field, value=error.value), } diff --git a/litecord/storage.py b/litecord/storage.py index 99e49803..ab508750 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -18,7 +18,19 @@ """ import asyncio -from typing import List, Dict, Any, Optional, TypedDict, cast, Iterable, TYPE_CHECKING +from typing import ( + List, + Dict, + Any, + Optional, + TypedDict, + cast, + Iterable, + TYPE_CHECKING, + Union, + overload, + Literal, +) from xml.etree.ElementInclude import include import aiohttp @@ -35,7 +47,7 @@ partial_emoji, PartialEmoji, ) - +from litecord.models import PartialUser, User from litecord.types import timestamp_ from litecord.json import pg_set_json from litecord.presence import PresenceManager @@ -108,30 +120,17 @@ async def execute_with_json(self, query: str, *args) -> str: async def parse_user(self, duser: dict, secure: bool) -> dict: duser["premium"] = duser.pop("premium_since") is not None duser["public_flags"] = duser["flags"] - duser["banner_color"] = ( - hex(duser["accent_color"]).replace("0x", "#") - if duser["accent_color"] - else None - ) + duser["banner_color"] = hex(duser["accent_color"]).replace("0x", "#") if duser["accent_color"] else None if secure: - duser["desktop"] = True + duser["desktop"] = False duser["mobile"] = False duser["phone"] = duser["phone"] if duser["phone"] else None today = date.today() born = duser.pop("date_of_birth") duser["nsfw_allowed"] = ( - ( - ( - today.year - - born.year - - ((today.month, today.day) < (born.month, born.day)) - ) - >= 18 - ) - if born - else True + ((today.year - born.year - ((today.month, today.day) < (born.month, born.day))) >= 18) if born else True ) plan_id = await self.db.fetchval( @@ -148,12 +147,25 @@ async def parse_user(self, duser: dict, secure: bool) -> dict: return duser - async def get_user(self, user_id, secure: bool = False) -> Optional[Dict[str, Any]]: + @overload + async def get_user(self, user_id: int, full: Literal[True] = True) -> Optional[User]: + ... + + @overload + async def get_user(self, user_id: int, full: Literal[False] = False) -> Optional[PartialUser]: + ... + + async def get_user(self, user_id: int, full: bool = False) -> Union[Optional[User], Optional[PartialUser]]: """Get a single user payload.""" - user_id = int(user_id) + + # Look for user in cache IF full user isn't needed + if not full: + cache_user = self.app.cache.users.get(user_id, None) + if cache_user is not None: + return cache_user fields = [ - "id::text", + "id", "username", "discriminator", "avatar", @@ -167,13 +179,13 @@ async def get_user(self, user_id, secure: bool = False) -> Optional[Dict[str, An "pronouns", "avatar_decoration", "theme_colors", + "email", + "verified", + "mfa_enabled", + "date_of_birth", + "phone", ] - if secure: - fields.extend( - ["email", "verified", "mfa_enabled", "date_of_birth", "phone"] - ) - user_row = await self.db.fetchrow( f""" SELECT {','.join(fields)} @@ -186,7 +198,17 @@ async def get_user(self, user_id, secure: bool = False) -> Optional[Dict[str, An if not user_row: return None - return await self.parse_user(dict(user_row), secure) + plan_id = await self.db.fetchval( + """ + SELECT payment_gateway_plan_id + FROM user_subscriptions + WHERE status = 1 + AND user_id = $1 + """, + int(user_id), + ) + + return User(**user_row, premium_type=PLAN_ID_TO_TYPE.get(cast(str, plan_id))) async def get_users( self, @@ -194,7 +216,7 @@ async def get_users( secure: bool = False, extra_clause: str = "", where_clause: str = "WHERE id = ANY($1::bigint[])", - args: Optional[List[Any]] = None, + args: Optional[Iterable[Any]] = None, ) -> List[dict]: """Get many user payloads.""" fields = [ @@ -215,9 +237,7 @@ async def get_users( ] if secure: - fields.extend( - ["email", "verified", "mfa_enabled", "date_of_birth", "phone"] - ) + fields.extend(["email", "verified", "mfa_enabled", "date_of_birth", "phone"]) users_rows = await self.db.fetch( f""" @@ -228,9 +248,7 @@ async def get_users( *(args or [user_ids if user_ids else []]), ) - return await asyncio.gather( - *(self.parse_user(dict(user_row), secure) for user_row in users_rows) - ) + return await asyncio.gather(*(self.parse_user(dict(user_row), secure) for user_row in users_rows)) async def search_user(self, username: str, discriminator: str) -> int: """Search a user""" @@ -298,23 +316,19 @@ async def parse_guild( # hardcoding these since: # - we aren't discord # - the limit for guilds is unknown and heavily dependant on the hardware - drow["max_presences"] = drow["max_members"] = drow[ - "max_video_channel_users" - ] = drow["max_stage_video_channel_users"] = 1000000 + drow["max_presences"] = drow["max_members"] = drow["max_video_channel_users"] = drow[ + "max_stage_video_channel_users" + ] = 1000000 # TODO drow["preferred_locale"] = "en-US" - drow["guild_scheduled_events"] = drow["embedded_activities"] = drow[ - "connections" - ] = drow["stickers"] = [] + drow["guild_scheduled_events"] = drow["embedded_activities"] = drow["connections"] = drow["stickers"] = [] if full: return {**drow, **await self.get_guild_extra(guild_id, user_id, large)} return drow - async def get_guild( - self, guild_id: int, user_id: Optional[int] = None - ) -> Optional[Dict]: + async def get_guild(self, guild_id: int, user_id: Optional[int] = None) -> Optional[Dict]: """Get guild payload.""" unavailable = self.app.guild_store.get(guild_id, "unavailable", False) if unavailable: @@ -366,9 +380,7 @@ async def get_guilds( *(args or [guild_ids if guild_ids else []]), ) - return await asyncio.gather( - *(self.parse_guild(dict(row), user_id, full, large) for row in rows) - ) + return await asyncio.gather(*(self.parse_guild(dict(row), user_id, full, large) for row in rows)) async def get_member_role_ids(self, guild_id: int, member_id: int) -> List[int]: """Get a list of role IDs that are on a member.""" @@ -402,9 +414,7 @@ async def get_member_role_ids(self, guild_id: int, member_id: int) -> List[int]: return roles - async def get_member( - self, guild_id, member_id, with_user: bool = True - ) -> Optional[Dict[str, Any]]: + async def get_member(self, guild_id, member_id, with_user: bool = True) -> Optional[Dict[str, Any]]: row = await self.db.fetchrow( """ SELECT user_id, nickname AS nick, joined_at, @@ -445,9 +455,7 @@ async def get_member( return drow - async def get_member_multi( - self, guild_id: int, user_ids: List[int] - ) -> List[Dict[str, Any]]: + async def get_member_multi(self, guild_id: int, user_ids: List[int]) -> List[Dict[str, Any]]: """Get member information about multiple users in a guild.""" members = [] @@ -460,9 +468,7 @@ async def get_member_multi( return members - async def get_members( - self, guild_id: int, with_user: bool = True - ) -> Dict[int, Dict[str, Any]]: + async def get_members(self, guild_id: int, with_user: bool = True) -> Dict[int, Dict[str, Any]]: """Get member information on a guild.""" members_basic = await self.db.fetch( """ @@ -591,9 +597,7 @@ async def get_chan_type(self, channel_id: int) -> Optional[int]: channel_id, ) - async def chan_overwrites( - self, channel_id: int, safe: bool = True - ) -> List[Dict[str, Any]]: + async def chan_overwrites(self, channel_id: int, safe: bool = True) -> List[Dict[str, Any]]: overwrite_rows = await self.db.fetch( f""" SELECT target_type, target_role, target_user, allow{'::text' if safe else ''}, deny{'::text' if safe else ''} @@ -634,13 +638,11 @@ async def gdm_recipient_ids(self, channel_id: int) -> List[int]: return [r["member_id"] for r in user_ids] - async def _gdm_recipients( - self, channel_id: int, reference_id: Optional[int] = None - ) -> List[Dict]: + async def _gdm_recipients(self, channel_id: int, reference_id: Optional[int] = None) -> List[PartialUser]: """Get the list of users that are recipients of the given Group DM.""" recipients = await self.gdm_recipient_ids(channel_id) - res = [] + res: List[PartialUser] = [] for user_id in recipients: if user_id == reference_id: @@ -731,9 +733,7 @@ async def get_channel(self, channel_id: int, **kwargs) -> Optional[Dict[str, Any drow["last_message_id"] = await self.chan_last_message_str(channel_id) return drow - raise RuntimeError( - f"Data Inconsistency: Channel type {ctype} is not properly handled" - ) + raise RuntimeError(f"Data Inconsistency: Channel type {ctype} is not properly handled") async def get_channel_ids(self, guild_id: int) -> List[int]: """Get all channel IDs in a guild.""" @@ -782,9 +782,7 @@ async def get_channel_data(self, guild_id) -> List[Dict]: return channels - async def get_role( - self, role_id: int, guild_id: Optional[int] = None - ) -> Optional[Dict[str, Any]]: + async def get_role(self, role_id: int, guild_id: Optional[int] = None) -> Optional[Dict[str, Any]]: """get a single role's information.""" guild_field = "AND guild_id = $2" if guild_id else "" @@ -826,9 +824,7 @@ async def get_role_data(self, guild_id: int) -> List[Dict[str, Any]]: return list(map(dict, roledata)) - async def guild_voice_states( - self, guild_id: int, user_id=None - ) -> List[Dict[str, Any]]: + async def guild_voice_states(self, guild_id: int, user_id=None) -> List[Dict[str, Any]]: """Get a list of voice states for the given guild.""" channel_ids = await self.get_channel_ids(guild_id) if not user_id: @@ -848,9 +844,7 @@ async def guild_voice_states( return res - async def get_guild_extra( - self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None - ) -> Dict: + async def get_guild_extra(self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None) -> Dict: """Get extra information about a guild.""" res = {} @@ -1004,14 +998,10 @@ async def parse_message( res["type"] = res.pop("message_type") res["content"] = res["content"] or "" res["pinned"] = bool(res["pinned"]) - res["mention_roles"] = ( - [str(r) for r in res["mention_roles"]] if res["mention_roles"] else [] - ) + res["mention_roles"] = [str(r) for r in res["mention_roles"]] if res["mention_roles"] else [] guild_id = res["guild_id"] - is_crosspost = ( - res["flags"] & MessageFlags.is_crosspost == MessageFlags.is_crosspost - ) + is_crosspost = res["flags"] & MessageFlags.is_crosspost == MessageFlags.is_crosspost attachments = list(res["attachments"]) if res["attachments"] else [] reactions = list(res["reactions"]) if res["reactions"] else [] @@ -1020,9 +1010,7 @@ async def parse_message( res["guild_id"] = str(guild_id) if guild_id else None if res.get("message_reference") and not is_crosspost and include_member: - message = await self.get_message( - int(res["message_reference"]["message_id"]), user_id, include_member - ) + message = await self.get_message(int(res["message_reference"]["message_id"]), user_id, include_member) res["referenced_message"] = message async def _get_user(user_id): @@ -1031,6 +1019,7 @@ async def _get_user(user_id): except KeyError: user = await self.get_user(user_id) if include_member and user and guild_id: + user = user.to_json() member = await self.get_member(guild_id, user_id, False) if member: user["member"] = member @@ -1102,11 +1091,7 @@ async def _get_user(user_id): # TODO: content_type proto = "https" if self.app.config["IS_SSL"] else "http" main_url = self.app.config["MAIN_URL"] - attachment["url"] = ( - f"{proto}://{main_url}/attachments/" - f"{a_channel_id}/{a_message_id}/" - f"{filename}" - ) + attachment["url"] = f"{proto}://{main_url}/attachments/" f"{a_channel_id}/{a_message_id}/" f"{filename}" attachment["proxy_url"] = attachment["url"] if attachment["height"] is None: attachment.pop("height") @@ -1211,10 +1196,7 @@ async def get_messages( user_cache = {} return await asyncio.gather( - *( - self.parse_message(dict(row), user_id, include_member, user_cache) - for row in rows - ) + *(self.parse_message(dict(row), user_id, include_member, user_cache) for row in rows) ) async def get_invite(self, invite_code: str) -> Optional[Dict]: @@ -1276,11 +1258,7 @@ async def get_invite(self, invite_code: str) -> Optional[Dict]: if chan is None: return None - dinv["channel"] = ( - {"id": chan["id"], "name": chan["name"], "type": chan["type"]} - if chan - else None - ) + dinv["channel"] = {"id": chan["id"], "name": chan["name"], "type": chan["type"]} if chan else None dinv["type"] = 0 if guild else (1 if chan else 2) @@ -1289,9 +1267,7 @@ async def get_invite(self, invite_code: str) -> Optional[Dict]: return dinv - async def get_invite_extra( - self, invite_code: str, counts: bool = True, expiry: bool = False - ) -> dict: + async def get_invite_extra(self, invite_code: str, counts: bool = True, expiry: bool = False) -> dict: """Extra information about the invite, such as approximate guild and presence counts.""" data = {} @@ -1319,9 +1295,7 @@ async def get_invite_extra( ) data["expires_at"] = ( - timestamp_(erow["created_at"] + timedelta(seconds=erow["max_age"])) - if erow["max_age"] > 0 - else None + timestamp_(erow["created_at"] + timedelta(seconds=erow["max_age"])) if erow["max_age"] > 0 else None ) return data @@ -1342,12 +1316,9 @@ async def get_invite_metadata(self, invite_code: str) -> Optional[Dict[str, Any] return None dinv = dict(invite) - inviter = await self.get_user(invite["inviter"]) - dinv["inviter"] = inviter + dinv["inviter"] = await self.get_user(invite["inviter"]) dinv["expires_at"] = ( - timestamp_(invite["created_at"] + timedelta(seconds=invite["max_age"])) - if invite["max_age"] > 0 - else None + timestamp_(invite["created_at"] + timedelta(seconds=invite["max_age"])) if invite["max_age"] > 0 else None ) dinv["created_at"] = timestamp_(invite["created_at"]) diff --git a/litecord/system_messages.py b/litecord/system_messages.py index 73b4a3e9..f33c0031 100644 --- a/litecord/system_messages.py +++ b/litecord/system_messages.py @@ -165,9 +165,7 @@ async def _handle_gdm_icon_edit(channel_id, author_id): return new_id -async def send_sys_message( - channel_id: int, m_type: MessageType, *args, **kwargs -) -> int: +async def send_sys_message(channel_id: int, m_type: MessageType, *args, **kwargs) -> int: """Send a system message. The handler for a given message type MUST return an integer, that integer diff --git a/litecord/typing_hax.py b/litecord/typing_hax.py index 6bf0f9d1..154a3141 100644 --- a/litecord/typing_hax.py +++ b/litecord/typing_hax.py @@ -19,6 +19,8 @@ from .voice.manager import VoiceManager from .jobs import JobManager from .errors import BadRequest +from .cache import CacheManager + class Request(_Request): @@ -27,7 +29,7 @@ class Request(_Request): bucket_global: RatelimitBucket retry_after: Optional[int] user_id: Optional[int] - + def on_json_loading_failed(self, error: Exception) -> Any: raise BadRequest(50109) @@ -37,6 +39,7 @@ class LitecordApp(Quart): session: ClientSession db: Pool sched: JobManager + cache: CacheManager winter_factory: SnowflakeFactory loop: AbstractEventLoop @@ -61,7 +64,7 @@ def __init__( ) self.config.from_object(config_path) self.config["MAX_CONTENT_LENGTH"] = 500 * 1024 * 1024 # 500 MB - + def init_managers(self): # Init singleton classes self.session = ClientSession() @@ -78,6 +81,7 @@ def init_managers(self): self.guild_store = GuildMemoryStore() self.lazy_guild = LazyGuildManager() self.voice = VoiceManager(self) + @property def is_debug(self) -> bool: return self.config.get("DEBUG", False) diff --git a/litecord/user_storage.py b/litecord/user_storage.py index 63000757..017d5de5 100644 --- a/litecord/user_storage.py +++ b/litecord/user_storage.py @@ -129,9 +129,7 @@ async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]: ) # only need their ids - incoming_friends = [ - r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals - ] + incoming_friends = [r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals] # only fetch blocks we did, # not fetching the ones people did to us @@ -191,11 +189,7 @@ async def get_friend_ids(self, user_id: int) -> List[int]: """Get all friend IDs for a user.""" rels = await self.get_relationships(user_id) - return [ - int(r["user"]["id"]) - for r in rels - if r["type"] == RelationshipType.FRIEND.value - ] + return [int(r["user"]["id"]) for r in rels if r["type"] == RelationshipType.FRIEND.value] async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: """Get all DM channels for a user, including group DMs. diff --git a/litecord/utils.py b/litecord/utils.py index 76ddbed0..7145295b 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -94,28 +94,17 @@ def mmh3(inp_str: str, seed: int = 0): i = 0 while i < bytecount: - k1 = ( - (key[i] & 0xFF) - | ((key[i + 1] & 0xFF) << 8) - | ((key[i + 2] & 0xFF) << 16) - | ((key[i + 3] & 0xFF) << 24) - ) + k1 = (key[i] & 0xFF) | ((key[i + 1] & 0xFF) << 8) | ((key[i + 2] & 0xFF) << 16) | ((key[i + 3] & 0xFF) << 24) i += 4 - k1 = ( - (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) - ) & 0xFFFFFFFF + k1 = ((((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16))) & 0xFFFFFFFF k1 = (k1 << 15) | (_u(k1) >> 17) - k1 = ( - (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) - ) & 0xFFFFFFFF + k1 = ((((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16))) & 0xFFFFFFFF h1 ^= k1 h1 = (h1 << 13) | (_u(h1) >> 19) - h1b = ( - (((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16)) - ) & 0xFFFFFFFF + h1b = ((((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16))) & 0xFFFFFFFF h1 = ((h1b & 0xFFFF) + 0x6B64) + ((((_u(h1b) >> 16) + 0xE654) & 0xFFFF) << 16) k1 = 0 @@ -139,16 +128,9 @@ def mmh3(inp_str: str, seed: int = 0): h1 ^= len(key) h1 ^= _u(h1) >> 16 - h1 = ( - ((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16) - ) & 0xFFFFFFFF + h1 = (((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16)) & 0xFFFFFFFF h1 ^= _u(h1) >> 13 - h1 = ( - ( - ((h1 & 0xFFFF) * 0xC2B2AE35) - + ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16) - ) - ) & 0xFFFFFFFF + h1 = ((((h1 & 0xFFFF) * 0xC2B2AE35) + ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16))) & 0xFFFFFFFF h1 ^= _u(h1) >> 16 return _u(h1) >> 0 @@ -184,9 +166,7 @@ def maybe_int(val: Any) -> Union[int, Any]: return val -async def maybe_lazy_guild_dispatch( - guild_id: int, event: str, role, force: bool = False -): +async def maybe_lazy_guild_dispatch(guild_id: int, event: str, role, force: bool = False): # sometimes we want to dispatch an event # even if the role isn't hoisted diff --git a/litecord/voice/manager.py b/litecord/voice/manager.py index 693a560f..3c00c4ae 100644 --- a/litecord/voice/manager.py +++ b/litecord/voice/manager.py @@ -17,10 +17,9 @@ """ -from typing import Tuple, Dict, List +from typing import Tuple, Dict, List, TYPE_CHECKING from collections import defaultdict from dataclasses import fields -from quart import current_app as app, request from logbook import Logger @@ -29,6 +28,10 @@ from litecord.voice.state import VoiceState from litecord.voice.lvsp_manager import LVSPManager +if TYPE_CHECKING: + from litecord.typing_hax import app, request, LitecordApp +else: + from quart import current_app as app, request VoiceKey = Tuple[int, int] log = Logger(__name__) @@ -45,7 +48,7 @@ def _construct_state(state_dict: dict) -> VoiceState: class VoiceManager: """Main voice manager class.""" - def __init__(self, app): + def __init__(self, app: LitecordApp): self.app = app # double dict, first key is guild/channel id, second key is user id @@ -72,7 +75,7 @@ async def can_join(self, user_id: int, channel_id: int) -> int: # hacky user_limit but should work, as channels not # in guilds won't have that field. is_full = states >= channel.get("user_limit", 100) - is_bot = (await self.app.storage.get_user(user_id))["bot"] + is_bot = (await self.app.storage.get_user(user_id)).bot is_manager = perms.bits.manage_channels # if the channel is full AND: @@ -181,9 +184,7 @@ async def _start_voice_guild(self, voice_key: VoiceKey, data: dict): channel_id = int(data["channel_id"]) existing_states = self.states[voice_key] - channel_exists = any( - state.channel_id == channel_id for state in existing_states - ) + channel_exists = any(state.channel_id == channel_id for state in existing_states) if not channel_exists: await self._create_ctx_guild(guild_id, channel_id) diff --git a/manage/cmd/migration/command.py b/manage/cmd/migration/command.py index 87ebd0bf..5af05d1f 100644 --- a/manage/cmd/migration/command.py +++ b/manage/cmd/migration/command.py @@ -271,8 +271,6 @@ async def migrate_cmd(app, _args): def setup(subparser): - migrate_parser = subparser.add_parser( - "migrate", help="Run migration tasks", description=migrate_cmd.__doc__ - ) + migrate_parser = subparser.add_parser("migrate", help="Run migration tasks", description=migrate_cmd.__doc__) migrate_parser.set_defaults(func=migrate_cmd) diff --git a/manage/cmd/users.py b/manage/cmd/users.py index e54533de..c79a9e78 100644 --- a/manage/cmd/users.py +++ b/manage/cmd/users.py @@ -278,17 +278,13 @@ def setup(subparser): addbot_parser.add_argument("password", help="Password of the bot") addbot_parser.set_defaults(func=addbot) - setflag_parser = subparser.add_parser( - "setflag", help="Set a flag for a user", description=set_flag.__doc__ - ) + setflag_parser = subparser.add_parser("setflag", help="Set a flag for a user", description=set_flag.__doc__) setflag_parser.add_argument("username", help="Username of the user") setflag_parser.add_argument("discriminator", help="Discriminator of the user") setflag_parser.add_argument("flag_name", help="The flag to set"), setflag_parser.set_defaults(func=set_flag) - unsetflag_parser = subparser.add_parser( - "unsetflag", help="Unset a flag for a user", description=unset_flag.__doc__ - ) + unsetflag_parser = subparser.add_parser("unsetflag", help="Unset a flag for a user", description=unset_flag.__doc__) unsetflag_parser.add_argument("username", help="Username of the user") unsetflag_parser.add_argument("discriminator", help="Discriminator of the user") unsetflag_parser.add_argument("flag_name", help="The flag to unset"), @@ -307,9 +303,7 @@ def setup(subparser): token_parser.add_argument("user_id", help="ID of the user") token_parser.set_defaults(func=generate_bot_token) - set_password_user_parser = subparser.add_parser( - "setpass", help="Set the password of a user" - ) + set_password_user_parser = subparser.add_parser("setpass", help="Set the password of a user") set_password_user_parser.add_argument("username", help="Username of the user") set_password_user_parser.add_argument("discriminator", help="Discriminator of the user") set_password_user_parser.add_argument("password", help="New password for the user") diff --git a/manage/main.py b/manage/main.py index 457673f3..a1ae89ff 100644 --- a/manage/main.py +++ b/manage/main.py @@ -53,11 +53,7 @@ class FakeApp: def make_app(self) -> Quart: app = Quart(__name__) app.config.from_object(self.config) - fields = [ - field - for (field, _val) in inspect.getmembers(self) - if not field.startswith("__") - ] + fields = [field for (field, _val) in inspect.getmembers(self) if not field.startswith("__")] for field in fields: setattr(app, field, getattr(self, field)) diff --git a/pyproject.toml b/pyproject.toml index b08376d3..4ab54178 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ emoji = "<3.0.0" [tool.poetry.dev-dependencies] +[tool.black] +line-length = 120 + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/run.py b/run.py index 75d57290..181734c5 100755 --- a/run.py +++ b/run.py @@ -120,7 +120,7 @@ def make_app(): app = LitecordApp(__name__) - + if app.is_debug: log.info("on debug") handler.level = logbook.DEBUG @@ -282,6 +282,7 @@ async def init_app_db(app_: LitecordApp): app_.sched = JobManager(context_func=app.app_context) app.init_managers() + async def api_index(app_: LitecordApp): to_find = {} found = [] @@ -312,7 +313,7 @@ async def api_index(app_: LitecordApp): path = path.replace("peer.id", "user.id") methods = rule.methods - if not methods: + if not methods: continue for method in methods: pathname = to_find.get((path, method)) @@ -377,9 +378,7 @@ async def app_before_serving(): # start gateway websocket # voice websocket is handled by the voice server - ws_fut = start_websocket( - app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler - ) + ws_fut = start_websocket(app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler) await ws_fut @@ -439,5 +438,6 @@ def handle_413(_): async def handle_500(_): return jsonify({"message": "500: Internal Server Error", "code": 0}), 500 + if __name__ == "__main__": app.run() diff --git a/tests/common.py b/tests/common.py index d1480a53..bcd23e5b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -150,9 +150,7 @@ async def delete(self): async def refetch(self) -> "WrappedGuild": async with self.test_cli.app.app_context(): - guild = await self.test_cli.app.storage.get_guild_full( - self.id, user_id=self.test_cli.user["id"] - ) + guild = await self.test_cli.app.storage.get_guild_full(self.id, user_id=self.test_cli.user["id"]) return WrappedGuild.from_json(self.test_cli, guild) @classmethod @@ -170,9 +168,7 @@ def from_json(cls, test_cli, rjson): "widget_channel_id": int_(rjson["widget_channel_id"]), "system_channel_id": int_(rjson["system_channel_id"]), "rules_channel_id": int_(rjson["rules_channel_id"]), - "public_updates_channel_id": int_( - rjson["public_updates_channel_id"] - ), + "public_updates_channel_id": int_(rjson["public_updates_channel_id"]), }, }, ) @@ -379,9 +375,7 @@ async def create_guild_channel( channel_id = self.app.winter_factory.snowflake() async with self.app.app_context(): - await create_guild_channel( - guild_id, channel_id, type, **{**{"name": name}, **kwargs} - ) + await create_guild_channel(guild_id, channel_id, type, **{**{"name": name}, **kwargs}) channel_data = await self.app.storage.get_channel(channel_id) return self.add_resource(WrappedGuildChannel.from_json(self, channel_data)) diff --git a/tests/test_admin_api/test_guilds.py b/tests/test_admin_api/test_guilds.py index 846fd754..1dd666fc 100644 --- a/tests/test_admin_api/test_guilds.py +++ b/tests/test_admin_api/test_guilds.py @@ -55,9 +55,7 @@ async def test_guild_update(test_cli_staff): # would be overkill to test the side-effects, so... I'm not # testing them. Yes, I know its a bad idea, but if someone has an easier # way to write that, do send an MR. - resp = await test_cli_staff.patch( - f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True} - ) + resp = await test_cli_staff.patch(f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True}) assert resp.status_code == 200 rjson = await resp.json @@ -92,9 +90,7 @@ async def test_guild_delete(test_cli_staff): async def test_guild_create_voice(test_cli_staff): region_id = secrets.token_hex(6) region_name = secrets.token_hex(6) - resp = await test_cli_staff.put( - "/api/v6/admin/voice/regions", json={"id": region_id, "name": region_name} - ) + resp = await test_cli_staff.put("/api/v6/admin/voice/regions", json={"id": region_id, "name": region_name}) assert resp.status_code == 200 rjson = await resp.json assert isinstance(rjson, list) diff --git a/tests/test_admin_api/test_instance_invites.py b/tests/test_admin_api/test_instance_invites.py index 6bd423a5..02528b8f 100644 --- a/tests/test_admin_api/test_instance_invites.py +++ b/tests/test_admin_api/test_instance_invites.py @@ -48,9 +48,7 @@ async def test_inv_delete_invalid(test_cli_staff): async def test_create_invite(test_cli_staff): """Test the creation of an instance invite, then listing it, then deleting it.""" - resp = await test_cli_staff.put( - "/api/v6/admin/instance/invites", json={"max_uses": 1} - ) + resp = await test_cli_staff.put("/api/v6/admin/instance/invites", json={"max_uses": 1}) assert resp.status_code == 200 rjson = await resp.json diff --git a/tests/test_admin_api/test_users.py b/tests/test_admin_api/test_users.py index 6fb7cd85..eed8f647 100644 --- a/tests/test_admin_api/test_users.py +++ b/tests/test_admin_api/test_users.py @@ -44,9 +44,7 @@ async def test_list_users(test_cli_staff): @pytest.mark.asyncio async def test_find_single_user(test_cli_staff): - user = await test_cli_staff.create_user( - username="test_user" + secrets.token_hex(2), email=email() - ) + user = await test_cli_staff.create_user(username="test_user" + secrets.token_hex(2), email=email()) resp = await _search(test_cli_staff, username=user.name) assert resp.status_code == 200 @@ -123,9 +121,7 @@ async def test_user_update(test_cli_staff): user = await test_cli_staff.create_user() # set them as partner flag - resp = await test_cli_staff.patch( - f"/api/v6/admin/users/{user.id}", json={"flags": UserFlags.partner} - ) + resp = await test_cli_staff.patch(f"/api/v6/admin/users/{user.id}", json={"flags": UserFlags.partner}) assert resp.status_code == 200 rjson = await resp.json diff --git a/tests/test_channels.py b/tests/test_channels.py index 4d2a70fc..3629239f 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -96,9 +96,7 @@ async def test_channel_message_send_on_new_channel(test_cli_user): async def test_channel_message_delete(test_cli_user): guild = await test_cli_user.create_guild() channel = await test_cli_user.create_guild_channel(guild_id=guild.id) - message = await test_cli_user.create_message( - guild_id=guild.id, channel_id=channel.id - ) + message = await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) resp = await test_cli_user.delete( f"/api/v6/channels/{channel.id}/messages/{message.id}", @@ -115,9 +113,7 @@ async def test_channel_message_delete_different_author(test_cli_user): async with test_cli_user.app.app_context(): await add_member(guild.id, user.id) - message = await test_cli_user.create_message( - guild_id=guild.id, channel_id=channel.id, author_id=user.id - ) + message = await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id, author_id=user.id) resp = await test_cli_user.delete( f"/api/v6/channels/{channel.id}/messages/{message.id}", @@ -131,9 +127,7 @@ async def test_channel_message_bulk_delete(test_cli_user): channel = await test_cli_user.create_guild_channel(guild_id=guild.id) messages = [] for _ in range(10): - messages.append( - await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) - ) + messages.append(await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id)) resp = await test_cli_user.post( f"/api/v6/channels/{channel.id}/messages/bulk-delete", diff --git a/tests/test_guild.py b/tests/test_guild.py index f776ad6c..a769b758 100644 --- a/tests/test_guild.py +++ b/tests/test_guild.py @@ -31,9 +31,7 @@ async def test_guild_create(test_cli_user): g_name = secrets.token_hex(5) # stage 1: create - resp = await test_cli_user.post( - "/api/v6/guilds", json={"name": g_name, "region": None} - ) + resp = await test_cli_user.post("/api/v6/guilds", json={"name": g_name, "region": None}) assert resp.status_code == 200 rjson = await resp.json diff --git a/tests/test_invites.py b/tests/test_invites.py index dbfeec69..84a845e0 100644 --- a/tests/test_invites.py +++ b/tests/test_invites.py @@ -24,9 +24,7 @@ async def _create_invite(test_cli_user, guild, channel, max_uses=0): - resp = await test_cli_user.post( - f'/api/v9/channels/{channel["id"]}/invites', json={"max_uses": max_uses} - ) + resp = await test_cli_user.post(f'/api/v9/channels/{channel["id"]}/invites', json={"max_uses": max_uses}) assert resp.status_code == 200 rjson = await resp.json @@ -111,9 +109,7 @@ async def test_leave_join_invite_cycle(test_cli_user): assert any(incoming_guild["id"] == str(guild.id) for incoming_guild in rjson) - resp = await test_cli_user.delete( - f"/api/v6/users/@me/guilds/{guild.id}", as_user=user - ) + resp = await test_cli_user.delete(f"/api/v6/users/@me/guilds/{guild.id}", as_user=user) assert resp.status_code == 204 resp = await test_cli_user.get("/api/v6/users/@me/guilds", as_user=user) @@ -134,9 +130,7 @@ async def test_invite_max_uses(test_cli_user): # join and leave await _join_invite(test_cli_user, invite, user) - resp = await test_cli_user.delete( - f"/api/v6/users/@me/guilds/{guild.id}", as_user=user - ) + resp = await test_cli_user.delete(f"/api/v6/users/@me/guilds/{guild.id}", as_user=user) assert resp.status_code == 204 resp = await test_cli_user.post(f'/api/v9/invites/{invite["code"]}', as_user=user) diff --git a/tests/test_messages.py b/tests/test_messages.py index 874a514f..1874714f 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -27,9 +27,7 @@ async def test_message_listing(test_cli_user): channel = await test_cli_user.create_guild_channel(guild_id=guild.id) messages = [] for _ in range(10): - messages.append( - await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) - ) + messages.append(await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id)) # assert all messages we just created can be refetched if we give the # middle message to the 'around' parameter @@ -73,9 +71,7 @@ async def test_message_listing(test_cli_user): async def test_message_update(test_cli_user): guild = await test_cli_user.create_guild() channel = await test_cli_user.create_guild_channel(guild_id=guild.id) - message = await test_cli_user.create_message( - guild_id=guild.id, channel_id=channel.id - ) + message = await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) resp = await test_cli_user.patch( f"/api/v6/channels/{channel.id}/messages/{message.id}", @@ -94,9 +90,7 @@ async def test_message_update(test_cli_user): async def test_message_pinning(test_cli_user): guild = await test_cli_user.create_guild() channel = await test_cli_user.create_guild_channel(guild_id=guild.id) - message = await test_cli_user.create_message( - guild_id=guild.id, channel_id=channel.id - ) + message = await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) resp = await test_cli_user.put(f"/api/v6/channels/{channel.id}/pins/{message.id}") assert resp.status_code == 204 @@ -107,9 +101,7 @@ async def test_message_pinning(test_cli_user): assert len(rjson) == 1 assert rjson[0]["id"] == str(message.id) - resp = await test_cli_user.delete( - f"/api/v6/channels/{channel.id}/pins/{message.id}" - ) + resp = await test_cli_user.delete(f"/api/v6/channels/{channel.id}/pins/{message.id}") assert resp.status_code == 204 resp = await test_cli_user.get(f"/api/v6/channels/{channel.id}/pins") diff --git a/tests/test_reactions.py b/tests/test_reactions.py index f98bf155..c2db562d 100644 --- a/tests/test_reactions.py +++ b/tests/test_reactions.py @@ -26,20 +26,14 @@ async def test_reaction_flow(test_cli_user): guild = await test_cli_user.create_guild() channel = await test_cli_user.create_guild_channel(guild_id=guild.id) - message = await test_cli_user.create_message( - guild_id=guild.id, channel_id=channel.id - ) + message = await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) reaction = urllib.parse.quote("\N{THINKING FACE}") - resp = await test_cli_user.put( - f"/api/v6/channels/{channel.id}/messages/{message.id}/reactions/{reaction}/@me" - ) + resp = await test_cli_user.put(f"/api/v6/channels/{channel.id}/messages/{message.id}/reactions/{reaction}/@me") assert resp.status_code == 204 - resp = await test_cli_user.get( - f"/api/v6/channels/{channel.id}/messages/{message.id}/reactions/{reaction}" - ) + resp = await test_cli_user.get(f"/api/v6/channels/{channel.id}/messages/{message.id}/reactions/{reaction}") assert resp.status_code == 200 rjson = await resp.json assert len(rjson) == 1 diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index f253df6f..a2d1cec9 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -26,9 +26,7 @@ async def test_webhook_flow(test_cli_user): guild = await test_cli_user.create_guild() channel = await test_cli_user.create_guild_channel(guild_id=guild.id) - resp = await test_cli_user.post( - f"/api/v6/channels/{channel.id}/webhooks", json={"name": "awooga"} - ) + resp = await test_cli_user.post(f"/api/v6/channels/{channel.id}/webhooks", json={"name": "awooga"}) assert resp.status_code == 200 rjson = await resp.json assert rjson["channel_id"] == str(channel.id) @@ -46,7 +44,5 @@ async def test_webhook_flow(test_cli_user): assert resp.status_code == 204 refetched_channel = await channel.refetch() - message = await test_cli_user.app.storage.get_message( - refetched_channel.last_message_id - ) + message = await test_cli_user.app.storage.get_message(refetched_channel.last_message_id) assert message["author"]["id"] == webhook_id diff --git a/tests/test_websocket.py b/tests/test_websocket.py index ee8d3fe9..be3f33b6 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -102,14 +102,10 @@ async def recv(self, *, expect=Message, process_event: bool = True): assert self.ws.state is ConnectionState.REMOTE_CLOSING await self.send(event.response()) if process_event: - raise websockets.ConnectionClosed( - RcvdWrapper(event.code, event.reason), None - ) + raise websockets.ConnectionClosed(RcvdWrapper(event.code, event.reason), None) if expect is not None and not isinstance(event, expect): - raise AssertionError( - f"Expected {expect!r} websocket event, got {type(event)!r}" - ) + raise AssertionError(f"Expected {expect!r} websocket event, got {type(event)!r}") # this keeps compatibility with code written for aaugustin/websockets if expect is Message and process_event: @@ -228,9 +224,7 @@ async def get_gw(test_cli, version: int) -> str: return gw_json["url"] -async def gw_start( - test_cli, *, version: int = 6, etf=False, compress: Optional[str] = None -): +async def gw_start(test_cli, *, version: int = 6, etf=False, compress: Optional[str] = None): """Start a websocket connection""" gw_url = await get_gw(test_cli, version) @@ -269,9 +263,7 @@ async def test_ready(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send( - conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} - ) + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}) # try to get a ready try: @@ -307,9 +299,7 @@ async def test_ready_fields(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send( - conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} - ) + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}) try: await extract_and_verify_ready(conn) @@ -322,9 +312,7 @@ async def test_ready_fields(test_cli_user): async def test_ready_v9(test_cli_user): conn = await gw_start(test_cli_user.cli, version=9) await _json(conn) - await _json_send( - conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} - ) + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}) try: ready = await _json(conn) @@ -353,9 +341,7 @@ async def test_heartbeat(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send( - conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} - ) + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}) # ignore ready data ready = await _json(conn) @@ -391,9 +377,7 @@ async def test_resume(test_cli_user): # get the hello frame but ignore it await _json(conn) - await _json_send( - conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} - ) + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}) try: ready = await _json(conn) @@ -464,9 +448,7 @@ async def test_resume(test_cli_user): async def test_ready_bot(test_cli_bot): conn = await gw_start(test_cli_bot.cli) await _json(conn) # ignore hello - await _json_send( - conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}} - ) + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}}) try: await extract_and_verify_ready(conn)