diff --git a/changelog.d/19152.feature b/changelog.d/19152.feature
new file mode 100644
index 00000000000..833c1084fb5
--- /dev/null
+++ b/changelog.d/19152.feature
@@ -0,0 +1 @@
+Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow the update action (`send`/`cancel`/`restart`) to be specified in the request path instead of the request body.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 7e3117b8d6b..79b2a0c528e 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -58,6 +58,7 @@
 from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
+from synapse.storage.databases.main.delayed_events import DelayedEventsStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
@@ -273,6 +274,7 @@ class Store(
     RelationsWorkerStore,
     EventFederationWorkerStore,
     SlidingSyncStore,
+    DelayedEventsStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py
index 3342420d7d1..de21e3abbb7 100644
--- a/synapse/handlers/delayed_events.py
+++ b/synapse/handlers/delayed_events.py
@@ -21,6 +21,7 @@
 from synapse.api.errors import ShadowBanError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag
 from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
@@ -29,11 +30,9 @@
 )
 from synapse.storage.databases.main.delayed_events import (
     DelayedEventDetails,
-    DelayID,
     EventType,
     StateKey,
     Timestamp,
-    UserLocalpart,
 )
 from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
@@ -399,96 +398,63 @@ def on_added(self, next_send_ts: int) -> None:
         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at(next_send_ts)
 
-    async def cancel(self, requester: Requester, delay_id: str) -> None:
+    async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
         """
         Cancels the scheduled delivery of the matching delayed event.
 
-        Args:
-            requester: The owner of the delayed event to act on.
-            delay_id: The ID of the delayed event to act on.
-
         Raises:
             NotFoundError: if no matching delayed event could be found.
""" assert self._is_master await self._delayed_event_mgmt_ratelimiter.ratelimit( - requester, - (requester.user.to_string(), requester.device_id), + None, request.getClientAddress().host ) await make_deferred_yieldable(self._initialized_from_db) - next_send_ts = await self._store.cancel_delayed_event( - delay_id=delay_id, - user_localpart=requester.user.localpart, - ) + next_send_ts = await self._store.cancel_delayed_event(delay_id) if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) - async def restart(self, requester: Requester, delay_id: str) -> None: + async def restart(self, request: SynapseRequest, delay_id: str) -> None: """ Restarts the scheduled delivery of the matching delayed event. - Args: - requester: The owner of the delayed event to act on. - delay_id: The ID of the delayed event to act on. - Raises: NotFoundError: if no matching delayed event could be found. """ assert self._is_master await self._delayed_event_mgmt_ratelimiter.ratelimit( - requester, - (requester.user.to_string(), requester.device_id), + None, request.getClientAddress().host ) await make_deferred_yieldable(self._initialized_from_db) next_send_ts = await self._store.restart_delayed_event( - delay_id=delay_id, - user_localpart=requester.user.localpart, - current_ts=self._get_current_ts(), + delay_id, self._get_current_ts() ) if self._next_send_ts_changed(next_send_ts): self._schedule_next_at(next_send_ts) - async def send(self, requester: Requester, delay_id: str) -> None: + async def send(self, request: SynapseRequest, delay_id: str) -> None: """ Immediately sends the matching delayed event, instead of waiting for its scheduled delivery. - Args: - requester: The owner of the delayed event to act on. - delay_id: The ID of the delayed event to act on. - Raises: NotFoundError: if no matching delayed event could be found. """ assert self._is_master - # Use standard request limiter for sending delayed events on-demand, - # as an on-demand send is similar to sending a regular event. - await self._request_ratelimiter.ratelimit(requester) + await self._delayed_event_mgmt_ratelimiter.ratelimit( + None, request.getClientAddress().host + ) await make_deferred_yieldable(self._initialized_from_db) - event, next_send_ts = await self._store.process_target_delayed_event( - delay_id=delay_id, - user_localpart=requester.user.localpart, - ) + event, next_send_ts = await self._store.process_target_delayed_event(delay_id) if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) - await self._send_event( - DelayedEventDetails( - delay_id=DelayID(delay_id), - user_localpart=UserLocalpart(requester.user.localpart), - room_id=event.room_id, - type=event.type, - state_key=event.state_key, - origin_server_ts=event.origin_server_ts, - content=event.content, - device_id=event.device_id, - ) - ) + await self._send_event(event) async def _send_on_timeout(self) -> None: self._next_delayed_event_call = None @@ -611,9 +577,7 @@ async def _send_event( finally: # TODO: If this is a temporary error, retry. 
Otherwise, consider notifying clients of the failure
             try:
-                await self._store.delete_processed_delayed_event(
-                    event.delay_id, event.user_localpart
-                )
+                await self._store.delete_processed_delayed_event(event.delay_id)
             except Exception:
                 logger.exception("Failed to delete processed delayed event")
 
diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py
index 80abacbc9d6..69d1013e728 100644
--- a/synapse/rest/client/delayed_events.py
+++ b/synapse/rest/client/delayed_events.py
@@ -47,14 +47,11 @@ class UpdateDelayedEventServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.auth = hs.get_auth()
         self.delayed_events_handler = hs.get_delayed_events_handler()
 
     async def on_POST(
         self, request: SynapseRequest, delay_id: str
     ) -> tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-
         body = parse_json_object_from_request(request)
         try:
             action = str(body["action"])
@@ -75,11 +72,65 @@ async def on_POST(
             )
 
         if enum_action == _UpdateDelayedEventAction.CANCEL:
-            await self.delayed_events_handler.cancel(requester, delay_id)
+            await self.delayed_events_handler.cancel(request, delay_id)
         elif enum_action == _UpdateDelayedEventAction.RESTART:
-            await self.delayed_events_handler.restart(requester, delay_id)
+            await self.delayed_events_handler.restart(request, delay_id)
         elif enum_action == _UpdateDelayedEventAction.SEND:
-            await self.delayed_events_handler.send(requester, delay_id)
+            await self.delayed_events_handler.send(request, delay_id)
+        return 200, {}
+
+
+class CancelDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> tuple[int, JsonDict]:
+        await self.delayed_events_handler.cancel(request, delay_id)
+        return 200, {}
+
+
+class RestartDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> tuple[int, JsonDict]:
+        await self.delayed_events_handler.restart(request, delay_id)
+        return 200, {}
+
+
+class SendDelayedEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
+        releases=(),
+    )
+    CATEGORY = "Delayed event management requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.delayed_events_handler = hs.get_delayed_events_handler()
+
+    async def on_POST(
+        self, request: SynapseRequest, delay_id: str
+    ) -> tuple[int, JsonDict]:
+        await self.delayed_events_handler.send(request, delay_id)
         return 200, {}
 
 
@@ -108,4 +159,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     # The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None: UpdateDelayedEventServlet(hs).register(http_server) + CancelDelayedEventServlet(hs).register(http_server) + RestartDelayedEventServlet(hs).register(http_server) + SendDelayedEventServlet(hs).register(http_server) DelayedEventsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index b11ed86db2b..7f72be46f57 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -13,18 +13,26 @@ # import logging -from typing import NewType +from typing import TYPE_CHECKING, NewType import attr from synapse.api.errors import NotFoundError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction, StoreError +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + StoreError, +) from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomID from synapse.util import stringutils from synapse.util.json import json_encoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -55,6 +63,27 @@ class DelayedEventDetails(EventDetails): class DelayedEventsStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Set delayed events to be uniquely identifiable by their delay_id. + # In practice, delay_ids are already unique because they are generated + # from cryptographically strong random strings. + # Therefore, adding this constraint is not expected to ever fail, + # despite the current pkey technically allowing non-unique delay_ids. + self.db_pool.updates.register_background_index_update( + update_name="delayed_events_idx", + index_name="delayed_events_idx", + table="delayed_events", + columns=("delay_id",), + unique=True, + ) + async def get_delayed_events_stream_pos(self) -> int: """ Gets the stream position of the background process to watch for state events @@ -134,9 +163,7 @@ def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp: async def restart_delayed_event( self, - *, delay_id: str, - user_localpart: str, current_ts: Timestamp, ) -> Timestamp: """ @@ -145,7 +172,6 @@ async def restart_delayed_event( Args: delay_id: The ID of the delayed event to restart. - user_localpart: The localpart of the delayed event's owner. current_ts: The current time, which will be used to calculate the new send time. Returns: The send time of the next delayed event to be sent, @@ -163,13 +189,11 @@ def restart_delayed_event_txn( """ UPDATE delayed_events SET send_ts = ? + delay - WHERE delay_id = ? AND user_localpart = ? - AND NOT is_processed + WHERE delay_id = ? AND NOT is_processed """, ( current_ts, delay_id, - user_localpart, ), ) if txn.rowcount == 0: @@ -319,21 +343,15 @@ def process_timeout_delayed_events_txn( async def process_target_delayed_event( self, - *, delay_id: str, - user_localpart: str, ) -> tuple[ - EventDetails, + DelayedEventDetails, Timestamp | None, ]: """ Marks for processing the matching delayed event, regardless of its timeout time, as long as it has not already been marked as such. - Args: - delay_id: The ID of the delayed event to restart. - user_localpart: The localpart of the delayed event's owner. 
- Returns: The details of the matching delayed event, and the send time of the next delayed event to be sent, if any. @@ -344,39 +362,38 @@ async def process_target_delayed_event( def process_target_delayed_event_txn( txn: LoggingTransaction, ) -> tuple[ - EventDetails, + DelayedEventDetails, Timestamp | None, ]: txn.execute( """ UPDATE delayed_events SET is_processed = TRUE - WHERE delay_id = ? AND user_localpart = ? - AND NOT is_processed + WHERE delay_id = ? AND NOT is_processed RETURNING room_id, event_type, state_key, origin_server_ts, content, - device_id + device_id, + user_localpart """, - ( - delay_id, - user_localpart, - ), + (delay_id,), ) row = txn.fetchone() if row is None: raise NotFoundError("Delayed event not found") - event = EventDetails( + event = DelayedEventDetails( RoomID.from_string(row[0]), EventType(row[1]), StateKey(row[2]) if row[2] is not None else None, Timestamp(row[3]) if row[3] is not None else None, db_to_json(row[4]), DeviceID(row[5]) if row[5] is not None else None, + DelayID(delay_id), + UserLocalpart(row[6]), ) return event, self._get_next_delayed_event_send_ts_txn(txn) @@ -385,19 +402,10 @@ def process_target_delayed_event_txn( "process_target_delayed_event", process_target_delayed_event_txn ) - async def cancel_delayed_event( - self, - *, - delay_id: str, - user_localpart: str, - ) -> Timestamp | None: + async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None: """ Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. - Args: - delay_id: The ID of the delayed event to restart. - user_localpart: The localpart of the delayed event's owner. - Returns: The send time of the next delayed event to be sent, if any. Raises: @@ -413,7 +421,6 @@ def cancel_delayed_event_txn( table="delayed_events", keyvalues={ "delay_id": delay_id, - "user_localpart": user_localpart, "is_processed": False, }, ) @@ -473,11 +480,7 @@ def cancel_delayed_state_events_txn( "cancel_delayed_state_events", cancel_delayed_state_events_txn ) - async def delete_processed_delayed_event( - self, - delay_id: DelayID, - user_localpart: UserLocalpart, - ) -> None: + async def delete_processed_delayed_event(self, delay_id: DelayID) -> None: """ Delete the matching delayed event, as long as it has been marked as processed. @@ -488,7 +491,6 @@ async def delete_processed_delayed_event( table="delayed_events", keyvalues={ "delay_id": delay_id, - "user_localpart": user_localpart, "is_processed": True, }, desc="delete_processed_delayed_event", @@ -554,7 +556,7 @@ def _generate_delay_id() -> DelayID: # We use the following format for delay IDs: # syd_ - # They are scoped to user localparts, so it is possible for - # the same ID to exist for multiple users. + # They are not scoped to user localparts, but the random string + # is expected to be sufficiently random to be globally unique. 
     return DelayID(f"syd_{stringutils.random_string(20)}")
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 3c3b13437ef..c4c4d7bcc4a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -19,7 +19,7 @@
 #
 #
 
-SCHEMA_VERSION = 92  # remember to update the list below when updating
+SCHEMA_VERSION = 93  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -168,11 +168,15 @@
 
 Changes in SCHEMA_VERSION = 92
     - Cleaned up a trigger that was added in #18260 and then reverted.
+
+Changes in SCHEMA_VERSION = 93
+    - MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
 """
 
 
 SCHEMA_COMPAT_VERSION = (
     # Transitive links are no longer written to `event_auth_chain_links`
+    # TODO: On the next compat bump, update the primary key of `delayed_events`
     84
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/93/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/93/01_add_delayed_events.sql
new file mode 100644
index 00000000000..c7f3c00612d
--- /dev/null
+++ b/synapse/storage/schema/main/delta/93/01_add_delayed_events.sql
@@ -0,0 +1,15 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2025 Element Creations, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+ +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (9301, 'delayed_events_idx', '{}'); \ No newline at end of file diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index c67ffc76683..cc983ea1016 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -28,6 +28,7 @@ from synapse.util.clock import Clock from tests import unittest +from tests.server import FakeChannel from tests.unittest import HomeserverTestCase PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events" @@ -127,6 +128,10 @@ def test_delayed_state_events_are_sent_on_timeout(self) -> None: ) self.assertEqual(setter_expected, content.get(setter_key), content) + def test_get_delayed_events_auth(self) -> None: + channel = self.make_request("GET", PATH_PREFIX) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result) + @unittest.override_config( {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} ) @@ -154,7 +159,6 @@ def test_update_delayed_event_without_id(self) -> None: channel = self.make_request( "POST", f"{PATH_PREFIX}/", - access_token=self.user1_access_token, ) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) @@ -162,7 +166,6 @@ def test_update_delayed_event_without_body(self) -> None: channel = self.make_request( "POST", f"{PATH_PREFIX}/abc", - access_token=self.user1_access_token, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( @@ -175,7 +178,6 @@ def test_update_delayed_event_without_action(self) -> None: "POST", f"{PATH_PREFIX}/abc", {}, - self.user1_access_token, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( @@ -188,7 +190,6 @@ def test_update_delayed_event_with_invalid_action(self) -> None: "POST", f"{PATH_PREFIX}/abc", {"action": "oops"}, - self.user1_access_token, ) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( @@ -196,17 +197,21 @@ def test_update_delayed_event_with_invalid_action(self) -> None: channel.json_body["errcode"], ) - @parameterized.expand(["cancel", "restart", "send"]) - def test_update_delayed_event_without_match(self, action: str) -> None: - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/abc", - {"action": action}, - self.user1_access_token, + @parameterized.expand( + ( + (action, action_in_path) + for action in ("cancel", "restart", "send") + for action_in_path in (True, False) ) + ) + def test_update_delayed_event_without_match( + self, action: str, action_in_path: bool + ) -> None: + channel = self._update_delayed_event("abc", action, action_in_path) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) - def test_cancel_delayed_state_event(self) -> None: + @parameterized.expand((True, False)) + def test_cancel_delayed_state_event(self, action_in_path: bool) -> None: state_key = "to_never_send" setter_key = "setter" @@ -221,7 +226,7 @@ def test_cancel_delayed_state_event(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") - self.assertIsNotNone(delay_id) + assert delay_id is not None self.reactor.advance(1) events = self._get_delayed_events() @@ -236,12 +241,7 @@ def test_cancel_delayed_state_event(self) -> None: expect_code=HTTPStatus.NOT_FOUND, ) - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/{delay_id}", - {"action": "cancel"}, - self.user1_access_token, - ) + channel = 
self._update_delayed_event(delay_id, "cancel", action_in_path) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertListEqual([], self._get_delayed_events()) @@ -254,10 +254,11 @@ def test_cancel_delayed_state_event(self) -> None: expect_code=HTTPStatus.NOT_FOUND, ) + @parameterized.expand((True, False)) @unittest.override_config( {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} ) - def test_cancel_delayed_event_ratelimit(self) -> None: + def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None: delay_ids = [] for _ in range(2): channel = self.make_request( @@ -268,38 +269,17 @@ def test_cancel_delayed_event_ratelimit(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") - self.assertIsNotNone(delay_id) + assert delay_id is not None delay_ids.append(delay_id) - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/{delay_ids.pop(0)}", - {"action": "cancel"}, - self.user1_access_token, - ) + channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - args = ( - "POST", - f"{PATH_PREFIX}/{delay_ids.pop(0)}", - {"action": "cancel"}, - self.user1_access_token, - ) - channel = self.make_request(*args) + channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) - # Add the current user to the ratelimit overrides, allowing them no ratelimiting. - self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user( - self.user1_user_id, 0, 0 - ) - ) - - # Test that the request isn't ratelimited anymore. - channel = self.make_request(*args) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - - def test_send_delayed_state_event(self) -> None: + @parameterized.expand((True, False)) + def test_send_delayed_state_event(self, action_in_path: bool) -> None: state_key = "to_send_on_request" setter_key = "setter" @@ -314,7 +294,7 @@ def test_send_delayed_state_event(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") - self.assertIsNotNone(delay_id) + assert delay_id is not None self.reactor.advance(1) events = self._get_delayed_events() @@ -329,12 +309,7 @@ def test_send_delayed_state_event(self) -> None: expect_code=HTTPStatus.NOT_FOUND, ) - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/{delay_id}", - {"action": "send"}, - self.user1_access_token, - ) + channel = self._update_delayed_event(delay_id, "send", action_in_path) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertListEqual([], self._get_delayed_events()) content = self.helper.get_state( @@ -345,8 +320,9 @@ def test_send_delayed_state_event(self) -> None: ) self.assertEqual(setter_expected, content.get(setter_key), content) - @unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}}) - def test_send_delayed_event_ratelimit(self) -> None: + @parameterized.expand((True, False)) + @unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}}) + def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None: delay_ids = [] for _ in range(2): channel = self.make_request( @@ -357,38 +333,17 @@ def test_send_delayed_event_ratelimit(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") - 
self.assertIsNotNone(delay_id) + assert delay_id is not None delay_ids.append(delay_id) - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/{delay_ids.pop(0)}", - {"action": "send"}, - self.user1_access_token, - ) + channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - args = ( - "POST", - f"{PATH_PREFIX}/{delay_ids.pop(0)}", - {"action": "send"}, - self.user1_access_token, - ) - channel = self.make_request(*args) + channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) - # Add the current user to the ratelimit overrides, allowing them no ratelimiting. - self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user( - self.user1_user_id, 0, 0 - ) - ) - - # Test that the request isn't ratelimited anymore. - channel = self.make_request(*args) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - - def test_restart_delayed_state_event(self) -> None: + @parameterized.expand((True, False)) + def test_restart_delayed_state_event(self, action_in_path: bool) -> None: state_key = "to_send_on_restarted_timeout" setter_key = "setter" @@ -403,7 +358,7 @@ def test_restart_delayed_state_event(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") - self.assertIsNotNone(delay_id) + assert delay_id is not None self.reactor.advance(1) events = self._get_delayed_events() @@ -418,12 +373,7 @@ def test_restart_delayed_state_event(self) -> None: expect_code=HTTPStatus.NOT_FOUND, ) - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/{delay_id}", - {"action": "restart"}, - self.user1_access_token, - ) + channel = self._update_delayed_event(delay_id, "restart", action_in_path) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.reactor.advance(1) @@ -449,10 +399,11 @@ def test_restart_delayed_state_event(self) -> None: ) self.assertEqual(setter_expected, content.get(setter_key), content) + @parameterized.expand((True, False)) @unittest.override_config( {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} ) - def test_restart_delayed_event_ratelimit(self) -> None: + def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None: delay_ids = [] for _ in range(2): channel = self.make_request( @@ -463,37 +414,19 @@ def test_restart_delayed_event_ratelimit(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) delay_id = channel.json_body.get("delay_id") - self.assertIsNotNone(delay_id) + assert delay_id is not None delay_ids.append(delay_id) - channel = self.make_request( - "POST", - f"{PATH_PREFIX}/{delay_ids.pop(0)}", - {"action": "restart"}, - self.user1_access_token, + channel = self._update_delayed_event( + delay_ids.pop(0), "restart", action_in_path ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - args = ( - "POST", - f"{PATH_PREFIX}/{delay_ids.pop(0)}", - {"action": "restart"}, - self.user1_access_token, + channel = self._update_delayed_event( + delay_ids.pop(0), "restart", action_in_path ) - channel = self.make_request(*args) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) - # Add the current user to the ratelimit overrides, allowing them no ratelimiting. 
- self.get_success( - self.hs.get_datastores().main.set_ratelimit_for_user( - self.user1_user_id, 0, 0 - ) - ) - - # Test that the request isn't ratelimited anymore. - channel = self.make_request(*args) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - def test_delayed_state_is_not_cancelled_by_new_state_from_same_user( self, ) -> None: @@ -598,6 +531,17 @@ def _get_delayed_event_content(self, event: JsonDict) -> JsonDict: return content + def _update_delayed_event( + self, delay_id: str, action: str, action_in_path: bool + ) -> FakeChannel: + path = f"{PATH_PREFIX}/{delay_id}" + body = {} + if action_in_path: + path += f"/{action}" + else: + body["action"] = action + return self.make_request("POST", path, body) + def _get_path_for_delayed_state( room_id: str, event_type: str, state_key: str, delay_ms: int