Skip to content

Commit

Permalink
Cleaning up, formatting and types
Browse files Browse the repository at this point in the history
  • Loading branch information
gnunicorn committed May 11, 2024
1 parent e3c7cc9 commit 38a4bdd
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 104 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ version = {attr = "synapse_super_invites.__version__"}
[tool.mypy]
strict = true

[[tool.mypy.overrides]]
module = "tests.*"
disable_error_code = ["attr-defined", "index", "union-attr"]

[tool.ruff]
line-length = 88

Expand Down
7 changes: 6 additions & 1 deletion synapse_super_invites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from twisted.web.static import File

from .config import SynapseSuperInvitesConfig, run_alembic
from .resource import RedeemResource, TokenInfoResource, TokensResource, WebAccessResource
from .resource import (
RedeemResource,
TokenInfoResource,
TokensResource,
WebAccessResource,
)

__version__ = "0.8.3"

Expand Down
6 changes: 4 additions & 2 deletions synapse_super_invites/resource/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .redeem import RedeemResource
from .info import TokenInfoResource
from .redeem import RedeemResource
from .tokens import TokensResource
from .web_access import WebAccessResource
from .web_access import WebAccessResource

__all__ = ["RedeemResource", "TokenInfoResource", "TokensResource", "WebAccessResource"]
59 changes: 3 additions & 56 deletions synapse_super_invites/resource/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from sqlalchemy import func, select
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from synapse.http.server import (
DirectServeHtmlResource,
DirectServeJsonResource,
finish_request,
logger,
set_clickjacking_protection_headers,
)
from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import Any, JsonDict, Requester, Tuple # type: ignore[attr-defined]
from synapse.types import JsonDict, Requester

from synapse_super_invites.config import SynapseSuperInvitesConfig
from synapse_super_invites.model import Accepted, Room, Token
from synapse_super_invites.model import Token


def can_edit_token(token: Token, requester: Requester) -> bool:
Expand All @@ -36,53 +30,6 @@ def token_query(token_id: str): # type: ignore[no-untyped-def]
)


class AccessResource(DirectServeHtmlResource):
def __init__(
self,
api: ModuleApi,
):
super().__init__()
self.api = api

def _send_response(
self,
request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
js_bytes = response_object

# The response code must always be set, for logging purposes.
request.setResponseCode(code)

# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return None

request.setHeader(b"Content-Type", b"text/javascript; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(js_bytes),))

# Ensure this content cannot be embedded.
set_clickjacking_protection_headers(request)

request.write(js_bytes)
finish_request(request)

async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# ensure logged int
_requester = await self.api.get_user_by_req(request, allow_guest=False)
access_token = await self.api._auth.get_access_token_from_request(request)
return 200, 'startApp("{t}")'.format(t=access_token)


class SuperInviteResourceBase(DirectServeJsonResource):
def __init__(
self, config: SynapseSuperInvitesConfig, api: ModuleApi, sessions: sessionmaker # type: ignore[type-arg]
Expand Down
41 changes: 25 additions & 16 deletions synapse_super_invites/resource/info.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,48 @@
from sqlalchemy import func, select
from synapse.http.servlet import parse_json_object_from_request, parse_string
from sqlalchemy import select
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Tuple, UserID # type: ignore[attr-defined]

from synapse_super_invites.model import Room, Token, Accepted
from synapse_super_invites.model import Accepted, Token

from .base import SuperInviteResourceBase, can_edit_token, serialize_token, token_query
from .base import SuperInviteResourceBase


class TokenInfoResource(SuperInviteResourceBase):

async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.api.get_user_by_req(request, allow_guest=False)
my_id = str(requester.user)
token_id = parse_string(request, "token", required=True)
with self.db.begin() as session:
token = session.scalar(select(Token).where(
Token.token == token_id # noqa: E711
))
token = session.scalar(
select(Token).where(Token.token == token_id) # noqa: E711
)
if not token:
return 403, {"error": "Token not found", "errcode": "NOT_FOUND"}

if token.deleted_at != None:
return 403, {"error": "Token not longer valid", "errcode": "CANT_REDEEM"}
if token.deleted_at is not None:
return 403, {
"error": "Token not longer valid",
"errcode": "CANT_REDEEM",
}

has_redeemed = session.scalar(
select(Accepted).where(Accepted.user == my_id, Accepted.token == token)
) != None
has_redeemed = (
session.scalar(
select(Accepted).where(
Accepted.user == my_id, Accepted.token == token
)
)
is not None
)

rooms_count = len(token.rooms)
if token.create_dm:
rooms_count += 1

user_id = token.owner
owner_info = await self.api._store.get_profileinfo(UserID.from_string(user_id))
owner_info = await self.api._store.get_profileinfo(
UserID.from_string(user_id)
)

return 200, {
"rooms_count": rooms_count,
Expand All @@ -43,5 +52,5 @@ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDic
"user_id": user_id,
"display_name": owner_info.display_name,
"avatar_url": owner_info.avatar_url,
}
}
},
}
7 changes: 6 additions & 1 deletion synapse_super_invites/resource/redeem.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDi

if token.create_dm:
dm_data = await self.api.create_room(
my_id, config={"preset": "trusted_private_chat", "invite": [owner], "is_direct": True}
my_id,
config={
"preset": "trusted_private_chat",
"invite": [owner],
"is_direct": True,
},
)
invited_rooms.append(dm_data[0])

Expand Down
4 changes: 3 additions & 1 deletion synapse_super_invites/resource/tokens.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from sqlalchemy import func, select
from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.http.site import SynapseRequest
Expand Down Expand Up @@ -97,7 +99,7 @@ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDi

token_data = serialize_token(token)

registration_token = {}
registration_token: dict[Any, Any] = {}
if as_registration_token:
if not self.config.generate_registration_token:
registration_token["valid"] = False
Expand Down
14 changes: 7 additions & 7 deletions synapse_super_invites/resource/web_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import Any, JsonDict, Tuple # type: ignore[attr-defined]
from synapse.types import Any, Tuple # type: ignore[attr-defined]


class WebAccessResource(DirectServeHtmlResource):
Expand All @@ -29,7 +29,7 @@ def _send_response(
js_bytes = response_object

# The response code must always be set, for logging purposes.
request.setResponseCode(code)
request.setResponseCode(code) # type: ignore[no-untyped-call]

# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
Expand All @@ -40,17 +40,17 @@ def _send_response(
)
return None

request.setHeader(b"Content-Type", b"text/javascript; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (len(js_bytes),))
request.setHeader(b"Content-Type", b"text/javascript; charset=utf-8") # type: ignore[no-untyped-call]
request.setHeader(b"Content-Length", b"%d" % (len(js_bytes),)) # type: ignore[no-untyped-call]

# Ensure this content cannot be embedded.
set_clickjacking_protection_headers(request)

request.write(js_bytes)
request.write(js_bytes) # type: ignore[no-untyped-call]
finish_request(request)

async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, str]:
# ensure logged int
_requester = await self.api.get_user_by_req(request, allow_guest=False)
access_token = await self.api._auth.get_access_token_from_request(request)
access_token = self.api._auth.get_access_token_from_request(request)
return 200, 'startApp("{t}")'.format(t=access_token)
47 changes: 27 additions & 20 deletions tests/test_integrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from matrix_synapse_testutils.unittest import ( # type: ignore[import-untyped]
HomeserverTestCase,
override_config,
Expand Down Expand Up @@ -56,18 +58,24 @@ def create_room(self, user_id: str) -> str:
# create a room with the given access_token, return the roomId
def create_public_room(self, user_id: str) -> str:
room_id: str = self.get_success(
self.module_api.create_room(user_id=user_id, config={"preset": "public_chat", "visibility" : "public"}, ratelimit=False)
self.module_api.create_room(
user_id=user_id,
config={"preset": "public_chat", "visibility": "public"},
ratelimit=False,
)
)[0]
return room_id

def getState(self, room_data, type_key: str, state_key: str | None) -> Dict | None:
for e in reversed(room_data.get('timeline', {}).get('events', [])):
if e.get('type') == type_key:
if state_key != None:
if e.get('state_key') == state_key:
return e.get('content')
def getState(
self, room_data: dict[Any, Any], type_key: str, state_key: str | None
) -> Any | None:
for e in reversed(room_data.get("timeline", {}).get("events", [])):
if e.get("type") == type_key:
if state_key is not None:
if e.get("state_key") == state_key:
return e.get("content")
else:
return e.get('content')
return e.get("content")
return None


Expand Down Expand Up @@ -172,7 +180,6 @@ def test_simple_invite_token_test(self) -> None:
_f_id = self.register_user("flit", "flit")
f_access_token = self.login("flit", "flit")


channel = self.make_request(
"GET",
"/_synapse/client/super_invites/info?token={token}".format(token=token),
Expand All @@ -184,8 +191,8 @@ def test_simple_invite_token_test(self) -> None:
self.assertEqual(channel.json_body["rooms_count"], 3)
self.assertEqual(channel.json_body["create_dm"], False)
self.assertEqual(channel.json_body["has_redeemed"], False)
self.assertEqual(channel.json_body["inviter"]["user_id"], '@meeko:test')
self.assertEqual(channel.json_body["inviter"]["display_name"], 'meeko')
self.assertEqual(channel.json_body["inviter"]["user_id"], "@meeko:test")
self.assertEqual(channel.json_body["inviter"]["display_name"], "meeko")

channel = self.make_request(
"POST",
Expand Down Expand Up @@ -238,7 +245,7 @@ def test_simple_can_join_public_room_test(self) -> None:

# creating five channel
_roomA = self.create_room(m_id)
roomB = self.create_public_room(m_id) # this is public
roomB = self.create_public_room(m_id) # this is public
roomC = self.create_room(m_id)
roomD = self.create_room(m_id)
_roomE = self.create_room(m_id)
Expand Down Expand Up @@ -297,8 +304,8 @@ def test_simple_can_join_public_room_test(self) -> None:
)
# ensure the dm matches what we are expecting
public_room = channel.json_body["rooms"]["join"][roomB]
join_rule = self.getState(public_room, 'm.room.join_rules', None)
self.assertEquals(join_rule['join_rule'], 'public', join_rule)
join_rule = self.getState(public_room, "m.room.join_rules", None)
self.assertEquals(join_rule["join_rule"], "public", join_rule)

@override_config(
{
Expand Down Expand Up @@ -489,14 +496,14 @@ def test_simple_invite_token_only_dm_test(self) -> None:

# ensure the dm matches what we are expecting
dm = channel.json_body["rooms"]["join"][new_dm]
join_rule = self.getState(dm, 'm.room.join_rules', None)
self.assertEquals(join_rule['join_rule'], 'invite', join_rule)
join_rule = self.getState(dm, "m.room.join_rules", None)
self.assertEquals(join_rule["join_rule"], "invite", join_rule)

# and the other has been invited, too
member = self.getState(dm, 'm.room.member', '@meeko:test')
self.assertEquals(member.get('membership'), 'invite', member)
member = self.getState(dm, "m.room.member", "@meeko:test")
self.assertEquals(member.get("membership"), "invite", member)
# to the DM
self.assertEquals(member.get('is_direct'), True, member)
self.assertEquals(member.get("is_direct"), True, member)

@override_config(DEFAULT_CONFIG) # type: ignore[misc]
def test_simple_invite_token_with_dm_test(self) -> None:
Expand Down Expand Up @@ -712,7 +719,7 @@ def test_deletion(self) -> None:
"/_synapse/client/super_invites/info?token={token}".format(token=token),
access_token=f_access_token,
)
self.assertEqual(channel.code, 403, msg=channel.result) # access denied
self.assertEqual(channel.code, 403, msg=channel.result) # access denied

# and it can't be redeemed
channel = self.make_request(
Expand Down

0 comments on commit 38a4bdd

Please sign in to comment.