diff --git a/changelog.d/19041.feature b/changelog.d/19041.feature new file mode 100644 index 00000000000..e04b0189f8c --- /dev/null +++ b/changelog.d/19041.feature @@ -0,0 +1 @@ +Add companion endpoint for MSC4360: Sliding Sync Threads Extension. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 49943cf0c34..788f87602f0 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -20,17 +20,22 @@ import logging import re -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple from synapse.api.constants import Direction -from synapse.handlers.relations import ThreadsListInclude +from synapse.api.errors import SynapseError +from synapse.events.utils import SerializeEventConfig +from synapse.handlers.relations import BundledAggregations, ThreadsListInclude from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.databases.main.relations import ThreadsNextBatch -from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict +from synapse.storage.databases.main.relations import ThreadsNextBatch, ThreadUpdateInfo +from synapse.streams.config import ( + PaginationConfig, + extract_stream_token_from_pagination_token, +) +from synapse.types import JsonDict, RoomStreamToken, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer @@ -133,6 +138,187 @@ async def on_GET( return 200, result +class ThreadUpdatesServlet(RestServlet): + """ + Companion endpoint to the Sliding Sync threads extension (MSC4360). + Allows clients to bulk fetch thread updates across all joined rooms. + """ + + PATTERNS = client_patterns( + "/io.element.msc4360/thread_updates$", + unstable=True, + releases=(), + ) + CATEGORY = "Client API requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.relations_handler = hs.get_relations_handler() + self.event_serializer = hs.get_event_client_serializer() + + async def _serialize_thread_updates( + self, + thread_updates: Sequence[ThreadUpdateInfo], + bundled_aggregations: Dict[str, BundledAggregations], + time_now: int, + serialize_options: SerializeEventConfig, + ) -> Dict[str, Dict[str, JsonDict]]: + """ + Serialize thread updates into the response format. + + Args: + thread_updates: List of thread update info from storage + bundled_aggregations: Map of event_id to bundled aggregations + time_now: Current time in milliseconds + serialize_options: Serialization configuration + + Returns: + Nested dict mapping room_id -> thread_root_id -> thread update dict + """ + chunk: Dict[str, Dict[str, JsonDict]] = {} + + for update in thread_updates: + room_id = update.room_id + thread_id = update.thread_id + + if room_id not in chunk: + chunk[room_id] = {} + + update_dict: JsonDict = {} + + # Serialize thread root if present + if update.thread_root_event is not None: + bundle_aggs_map = ( + {thread_id: bundled_aggregations[thread_id]} + if thread_id in bundled_aggregations + else None + ) + serialized_events = await self.event_serializer.serialize_events( + [update.thread_root_event], + time_now, + config=serialize_options, + bundle_aggregations=bundle_aggs_map, + ) + if serialized_events: + update_dict["thread_root"] = serialized_events[0] + + # Add per-thread prev_batch if present + if update.prev_batch is not None: + update_dict["prev_batch"] = await update.prev_batch.to_string( + self.store + ) + + chunk[room_id][thread_id] = update_dict + + return chunk + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + # Parse parameters + dir_str = parse_string(request, "dir", default="b") + if dir_str != "b": + raise SynapseError( + 400, + "The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.", + ) + + limit = parse_integer(request, "limit", default=100) + if limit <= 0: + raise SynapseError(400, "The 'limit' parameter must be positive.") + + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + # Parse pagination tokens + from_token: Optional[RoomStreamToken] = None + to_token: Optional[RoomStreamToken] = None + + if from_token_str: + try: + stream_token_str = extract_stream_token_from_pagination_token( + from_token_str + ) + stream_token = await StreamToken.from_string( + self.store, stream_token_str + ) + from_token = stream_token.room_key + except Exception: + raise SynapseError(400, "'from' parameter is invalid") + + if to_token_str: + try: + stream_token_str = extract_stream_token_from_pagination_token( + to_token_str + ) + stream_token = await StreamToken.from_string( + self.store, stream_token_str + ) + to_token = stream_token.room_key + except Exception: + raise SynapseError(400, "'to' parameter is invalid") + + # Fetch thread updates from storage + # For backward pagination: + # - 'from' (upper bound, exclusive) maps to 'to_token' (inclusive with <=) + # Since next_batch is (last_returned - 1), <= excludes the last returned item + # - 'to' (lower bound, exclusive) maps to 'from_token' (exclusive with >) + thread_updates, next_token = await self.store.get_thread_updates_for_user( + user_id=requester.user.to_string(), + from_token=to_token, + to_token=from_token, + limit=limit, + include_thread_roots=True, + ) + + # Serialize response + chunk: Dict[str, Dict[str, JsonDict]] = {} + + if thread_updates: + # Get bundled aggregations for all thread roots + thread_root_events = [ + update.thread_root_event + for update in thread_updates + if update.thread_root_event is not None + ] + + bundled_aggregations = {} + if thread_root_events: + bundled_aggregations = ( + await self.relations_handler.get_bundled_aggregations( + thread_root_events, requester.user.to_string() + ) + ) + + # Set up serialization + time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig( + requester=requester, + ) + + # Serialize all thread updates + chunk = await self._serialize_thread_updates( + thread_updates=thread_updates, + bundled_aggregations=bundled_aggregations, + time_now=time_now, + serialize_options=serialize_options, + ) + + # Build response + response: JsonDict = {"chunk": chunk} + + # Add next_batch token for pagination + if next_token is not None: + response["next_batch"] = await next_token.to_string(self.store) + + return 200, response + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) ThreadsServlet(hs).register(http_server) + if hs.config.experimental.msc4360_enabled: + ThreadUpdatesServlet(hs).register(http_server) diff --git a/synapse/streams/config.py b/synapse/streams/config.py index ced17234778..53751e7624e 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -36,6 +36,32 @@ MAX_LIMIT = 1000 +def extract_stream_token_from_pagination_token(token_str: str) -> str: + """ + Extract the StreamToken portion from a pagination token string. + + Handles both: + - StreamToken format: "s123_456_..." + - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /) + + This allows clients using sliding sync to use their pos tokens + with endpoints like /relations and /messages. + + Args: + token_str: The token string to parse + + Returns: + The StreamToken portion of the token + """ + if "/" in token_str: + # SlidingSyncStreamToken format: "connection_position/stream_token" + # Split and return just the stream_token part + parts = token_str.split("/", 1) + if len(parts) == 2: + return parts[1] + return token_str + + @attr.s(slots=True, auto_attribs=True) class PaginationConfig: """A configuration object which stores pagination parameters.""" @@ -58,32 +84,14 @@ async def from_request( from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") - # Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format - def extract_stream_token(token_str: str) -> str: - """ - Extract the StreamToken portion from a token string. - - Handles both: - - StreamToken format: "s123_456_..." - - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /) - - This allows clients using sliding sync to use their pos tokens - with endpoints like /relations and /messages. - """ - if "/" in token_str: - # SlidingSyncStreamToken format: "connection_position/stream_token" - # Split and return just the stream_token part - parts = token_str.split("/", 1) - if len(parts) == 2: - return parts[1] - return token_str - try: from_tok = None if from_tok_str == "END": from_tok = None # For backwards compat. elif from_tok_str: - stream_token_str = extract_stream_token(from_tok_str) + stream_token_str = extract_stream_token_from_pagination_token( + from_tok_str + ) from_tok = await StreamToken.from_string(store, stream_token_str) except Exception: raise SynapseError(400, "'from' parameter is invalid") @@ -91,7 +99,9 @@ def extract_stream_token(token_str: str) -> str: try: to_tok = None if to_tok_str: - stream_token_str = extract_stream_token(to_tok_str) + stream_token_str = extract_stream_token_from_pagination_token( + to_tok_str + ) to_tok = await StreamToken.from_string(store, stream_token_str) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/tests/rest/client/sliding_sync/test_extension_threads.py b/tests/rest/client/sliding_sync/test_extension_threads.py index 3855b036721..eb2d2b4a3e8 100644 --- a/tests/rest/client/sliding_sync/test_extension_threads.py +++ b/tests/rest/client/sliding_sync/test_extension_threads.py @@ -834,3 +834,224 @@ def test_thread_in_timeline_included_with_include_roots(self) -> None: ) # Verify the thread root event is present self.assertIn("thread_root", thread_updates[thread_root_id]) + + def test_thread_updates_initial_sync(self) -> None: + """ + Test that prev_batch from the threads extension response can be used + with the /thread_updates endpoint to get additional thread updates during + initial sync. This verifies: + 1. The from parameter boundary is exclusive (no duplicates) + 2. Using prev_batch as 'from' provides complete coverage (no gaps) + 3. Works correctly with different numbers of threads + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create 5 thread roots + thread_ids = [] + for i in range(5): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + thread_ids.append(thread_root_id) + + # Add reply to each thread + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Do initial sync with threads extension enabled and limit=2 + sync_body = { + "extensions": { + EXT_NAME: { + "enabled": True, + "limit": 2, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Should get 2 thread updates + thread_updates = response_body["extensions"][EXT_NAME]["updates"][room_id] + self.assertEqual(len(thread_updates), 2) + first_sync_threads = set(thread_updates.keys()) + + # Get the top-level prev_batch token from the extension + self.assertIn("prev_batch", response_body["extensions"][EXT_NAME]) + prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"] + + # Use prev_batch with /thread_updates endpoint to get remaining updates + # Note: prev_batch should be used as 'from' parameter (upper bound for backward pagination) + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200) + + # Should get the remaining 3 thread updates + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + self.assertEqual(len(chunk[room_id]), 3) + + thread_updates_response_threads = set(chunk[room_id].keys()) + + # Verify no overlap - the from parameter boundary should be exclusive + self.assertEqual( + len(first_sync_threads & thread_updates_response_threads), + 0, + "from parameter boundary should be exclusive - no thread should appear in both responses", + ) + + # Verify no gaps - all threads should be accounted for + all_threads = set(thread_ids) + combined_threads = first_sync_threads | thread_updates_response_threads + self.assertEqual( + combined_threads, + all_threads, + "Combined responses should include all thread updates with no gaps", + ) + + def test_thread_updates_incremental_sync(self) -> None: + """ + Test the intended usage pattern from MSC4360: using prev_batch as 'from' + and a previous sync pos as 'to' with /thread_updates to fill gaps between + syncs. This verifies that using both bounds together provides complete + coverage with no gaps or duplicates. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create 3 threads initially + initial_thread_ids = [] + for i in range(3): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + initial_thread_ids.append(thread_root_id) + + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # First sync + sync_body = { + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, pos1 = self.do_sync(sync_body, tok=user1_tok) + + # Should get 3 thread updates + first_sync_threads = set( + response_body["extensions"][EXT_NAME]["updates"][room_id].keys() + ) + self.assertEqual(len(first_sync_threads), 3) + + # Create 3 more threads after the first sync + new_thread_ids = [] + for i in range(3, 6): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + new_thread_ids.append(thread_root_id) + + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Second sync with limit=1 to get only some of the new threads + sync_body_with_limit = { + "extensions": { + EXT_NAME: { + "enabled": True, + "limit": 1, + } + }, + } + response_body, pos2 = self.do_sync( + sync_body_with_limit, tok=user1_tok, since=pos1 + ) + + # Should get 1 thread update + second_sync_threads = set( + response_body["extensions"][EXT_NAME]["updates"][room_id].keys() + ) + self.assertEqual(len(second_sync_threads), 1) + + # Get prev_batch from the extension + self.assertIn("prev_batch", response_body["extensions"][EXT_NAME]) + prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"] + + # Now use /thread_updates with from=prev_batch and to=pos1 + # This should get the 2 remaining new threads (created after pos1, not returned in second sync) + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}&to={pos1}", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200) + + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + thread_updates_threads = set(chunk[room_id].keys()) + + # Should get exactly 2 threads + self.assertEqual(len(thread_updates_threads), 2) + + # Verify no overlap with second sync + self.assertEqual( + len(second_sync_threads & thread_updates_threads), + 0, + "No thread should appear in both second sync and thread_updates responses", + ) + + # Verify no overlap with first sync (to=pos1 should exclude those) + self.assertEqual( + len(first_sync_threads & thread_updates_threads), + 0, + "Threads from first sync should not appear in thread_updates (to=pos1 excludes them)", + ) + + # Verify no gaps - all new threads should be accounted for + all_new_threads = set(new_thread_ids) + combined_new_threads = second_sync_threads | thread_updates_threads + self.assertEqual( + combined_new_threads, + all_new_threads, + "Combined responses should include all new thread updates with no gaps", + ) diff --git a/tests/rest/client/test_thread_updates.py b/tests/rest/client/test_thread_updates.py new file mode 100644 index 00000000000..011f12d28ec --- /dev/null +++ b/tests/rest/client/test_thread_updates.py @@ -0,0 +1,597 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +import logging + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.api.constants import RelationTypes +from synapse.rest.client import login, relations, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.clock import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class ThreadUpdatesTestCase(unittest.HomeserverTestCase): + """ + Test the /thread_updates companion endpoint (MSC4360). + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + relations.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4360_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + def test_no_updates_for_new_user(self) -> None: + """ + Test that a user with no thread updates gets an empty response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Request thread updates + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert empty chunk and no next_batch + self.assertEqual(channel.json_body["chunk"], {}) + self.assertNotIn("next_batch", channel.json_body) + + def test_single_thread_update(self) -> None: + """ + Test that a single thread with one reply appears in the response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # Add reply to thread + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert thread is present + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + self.assertIn(thread_root_id, chunk[room_id]) + + # Assert thread root is included + thread_update = chunk[room_id][thread_root_id] + self.assertIn("thread_root", thread_update) + self.assertEqual(thread_update["thread_root"]["event_id"], thread_root_id) + + # Assert prev_batch is NOT present (only 1 update - the reply) + self.assertNotIn("prev_batch", thread_update) + + def test_multiple_threads_single_room(self) -> None: + """ + Test that multiple threads in the same room are grouped correctly. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create two threads + thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[ + "event_id" + ] + + # Add replies to both threads + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread1_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 2", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread2_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert both threads are in the same room + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + self.assertEqual(len(chunk), 1, "Should only have one room") + self.assertEqual(len(chunk[room_id]), 2, "Should have two threads") + self.assertIn(thread1_root_id, chunk[room_id]) + self.assertIn(thread2_root_id, chunk[room_id]) + + def test_threads_across_multiple_rooms(self) -> None: + """ + Test that threads from different rooms are grouped by room_id. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in both rooms + thread_a_root_id = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[ + "event_id" + ] + thread_b_root_id = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[ + "event_id" + ] + + # Add replies + self.helper.send_event( + room_a_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to A", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_a_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_b_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to B", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_b_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert both rooms are present with their threads + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, "Should have two rooms") + self.assertIn(room_a_id, chunk) + self.assertIn(room_b_id, chunk) + self.assertIn(thread_a_root_id, chunk[room_a_id]) + self.assertIn(thread_b_root_id, chunk[room_b_id]) + + def test_pagination_with_from_token(self) -> None: + """ + Test that pagination works using the next_batch token. + This verifies that multiple calls to /thread_updates return all thread + updates with no duplicates and no gaps. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create many threads (more than default limit) + thread_ids = [] + for i in range(5): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + thread_ids.append(thread_root_id) + + # Add reply + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request first page with small limit + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Should have 2 threads and a next_batch token + first_page_threads = set(channel.json_body["chunk"][room_id].keys()) + self.assertEqual(len(first_page_threads), 2) + self.assertIn("next_batch", channel.json_body) + + next_batch = channel.json_body["next_batch"] + + # Request second page + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + second_page_threads = set(channel.json_body["chunk"][room_id].keys()) + self.assertEqual(len(second_page_threads), 2) + + # Verify no overlap + self.assertEqual( + len(first_page_threads & second_page_threads), + 0, + "Pages should not have overlapping threads", + ) + + # Request third page to get the remaining thread + self.assertIn("next_batch", channel.json_body) + next_batch_2 = channel.json_body["next_batch"] + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + third_page_threads = set(channel.json_body["chunk"][room_id].keys()) + self.assertEqual(len(third_page_threads), 1) + + # Verify no overlap between any pages + self.assertEqual(len(first_page_threads & third_page_threads), 0) + self.assertEqual(len(second_page_threads & third_page_threads), 0) + + # Verify no gaps - all threads should be accounted for across all pages + all_threads = set(thread_ids) + combined_threads = first_page_threads | second_page_threads | third_page_threads + self.assertEqual( + combined_threads, + all_threads, + "Combined pages should include all thread updates with no gaps", + ) + + def test_invalid_dir_parameter(self) -> None: + """ + Test that forward pagination (dir=f) is rejected with an error. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Request with forward direction should fail + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 400) + + def test_invalid_limit_parameter(self) -> None: + """ + Test that invalid limit values are rejected. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Zero limit should fail + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 400) + + # Negative limit should fail + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 400) + + def test_invalid_pagination_tokens(self) -> None: + """ + Test that invalid from/to tokens are rejected with appropriate errors. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Invalid from token + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 400) + + # Invalid to token + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 400) + + def test_to_token_filtering(self) -> None: + """ + Test that the to_token parameter correctly limits pagination to updates + newer than the to_token (since we paginate backwards from newest to oldest). + This also verifies the to_token boundary is exclusive - updates at exactly + the to_token position should not be included (as they were already returned + in a previous response that synced up to that position). + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create two thread roots + thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[ + "event_id" + ] + + # Send replies to both threads + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread1_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 2", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread2_root_id, + }, + }, + tok=user1_tok, + ) + + # Request with limit=1 to get only the latest thread update + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200) + self.assertIn("next_batch", channel.json_body) + + # next_batch points to before the update we just received + next_batch = channel.json_body["next_batch"] + first_response_threads = set(channel.json_body["chunk"][room_id].keys()) + + # Request again with to=next_batch (lower bound for backward pagination) and no + # limit. + # This should get only the same thread updates as before, not the additional + # update. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200) + + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + # Should have exactly one thread update + self.assertEqual(len(chunk[room_id]), 1) + + second_response_threads = set(chunk[room_id].keys()) + + # Verify no overlap - the from parameter boundary should be exclusive + self.assertEqual( + first_response_threads, + second_response_threads, + "to parameter boundary should be exclusive - both responses should be identical", + ) + + def test_bundled_aggregations_on_thread_roots(self) -> None: + """ + Test that thread root events include bundled aggregations with latest thread event. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[ + "event_id" + ] + + # Send replies to create bundled aggregation data + for i in range(2): + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply {i + 1}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200) + + # Check that thread root has bundled aggregations with latest event + chunk = channel.json_body["chunk"] + thread_update = chunk[room_id][thread_root_id] + thread_root_event = thread_update["thread_root"] + + # Should have unsigned data with latest thread event content + self.assertIn("unsigned", thread_root_event) + self.assertIn("m.relations", thread_root_event["unsigned"]) + relations = thread_root_event["unsigned"]["m.relations"] + self.assertIn(RelationTypes.THREAD, relations) + + # Check latest event is present in bundled aggregations + thread_summary = relations[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event = thread_summary["latest_event"] + self.assertEqual(latest_event["content"]["body"], "Reply 2") + + def test_only_joined_rooms(self) -> None: + """ + Test that thread updates only include rooms where the user is currently joined. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create two rooms, user1 joins both + room1_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room2_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room2_id, user1_id, tok=user1_tok) + + # Create threads in both rooms + thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user2_tok)[ + "event_id" + ] + + # Add replies to both threads + self.helper.send_event( + room1_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread1_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room2_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 2", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread2_root_id, + }, + }, + tok=user2_tok, + ) + + # User1 leaves room2 + self.helper.leave(room2_id, user1_id, tok=user1_tok) + + # Request thread updates for user1 - should only get room1 + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200) + + chunk = channel.json_body["chunk"] + # Should only have room1, not room2 + self.assertIn(room1_id, chunk) + self.assertNotIn(room2_id, chunk) + self.assertIn(thread1_root_id, chunk[room1_id])