Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/19041.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add companion endpoint for MSC4360: Sliding Sync Threads Extension.
196 changes: 191 additions & 5 deletions synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@

import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple

from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.relations import BundledAggregations, ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
from synapse.storage.databases.main.relations import ThreadsNextBatch, ThreadUpdateInfo
from synapse.streams.config import (
PaginationConfig,
extract_stream_token_from_pagination_token,
)
from synapse.types import JsonDict, RoomStreamToken, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -133,6 +138,187 @@ async def on_GET(
return 200, result


class ThreadUpdatesServlet(RestServlet):
    """
    Companion endpoint to the Sliding Sync threads extension (MSC4360).
    Allows clients to bulk fetch thread updates across all joined rooms.

    GET /_matrix/client/unstable/io.element.msc4360/thread_updates

    Query parameters:
        dir: pagination direction; only "b" (backward) is supported.
        limit: maximum number of thread updates to return (default 100).
        from: exclusive upper bound pagination token (as returned in
            `next_batch`, or a sliding sync `pos` token).
        to: exclusive lower bound pagination token.

    Response:
        {"chunk": {room_id: {thread_root_id: {...}}}, "next_batch": token?}
    """

    PATTERNS = client_patterns(
        "/io.element.msc4360/thread_updates$",
        unstable=True,
        releases=(),
    )
    CATEGORY = "Client API requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.clock = hs.get_clock()
        self.auth = hs.get_auth()
        self.store = hs.get_datastores().main
        self.relations_handler = hs.get_relations_handler()
        self.event_serializer = hs.get_event_client_serializer()

    async def _parse_room_stream_token(
        self, token_str: Optional[str], param_name: str
    ) -> Optional[RoomStreamToken]:
        """
        Parse an optional pagination token query parameter into a RoomStreamToken.

        Accepts both plain StreamToken strings and sliding sync
        "connection_position/stream_token" strings.

        Args:
            token_str: The raw query parameter value, or None if absent.
            param_name: The query parameter name, used in error messages
                ("from" or "to").

        Returns:
            The room portion of the parsed token, or None if no token was given.

        Raises:
            SynapseError: 400 if the token cannot be parsed.
        """
        if not token_str:
            return None
        try:
            stream_token_str = extract_stream_token_from_pagination_token(token_str)
            stream_token = await StreamToken.from_string(self.store, stream_token_str)
            return stream_token.room_key
        except Exception:
            # Keep the error message identical to other pagination endpoints.
            raise SynapseError(400, "'%s' parameter is invalid" % (param_name,))

    async def _serialize_thread_updates(
        self,
        thread_updates: Sequence[ThreadUpdateInfo],
        bundled_aggregations: Dict[str, BundledAggregations],
        time_now: int,
        serialize_options: SerializeEventConfig,
    ) -> Dict[str, Dict[str, JsonDict]]:
        """
        Serialize thread updates into the response format.

        Args:
            thread_updates: List of thread update info from storage
            bundled_aggregations: Map of event_id to bundled aggregations
            time_now: Current time in milliseconds
            serialize_options: Serialization configuration

        Returns:
            Nested dict mapping room_id -> thread_root_id -> thread update dict
        """
        chunk: Dict[str, Dict[str, JsonDict]] = {}

        for update in thread_updates:
            room_id = update.room_id
            thread_id = update.thread_id

            update_dict: JsonDict = {}

            # Serialize the thread root event if present, attaching its
            # bundled aggregations (keyed by the thread root's event ID).
            if update.thread_root_event is not None:
                bundle_aggs_map = (
                    {thread_id: bundled_aggregations[thread_id]}
                    if thread_id in bundled_aggregations
                    else None
                )
                serialized_events = await self.event_serializer.serialize_events(
                    [update.thread_root_event],
                    time_now,
                    config=serialize_options,
                    bundle_aggregations=bundle_aggs_map,
                )
                if serialized_events:
                    update_dict["thread_root"] = serialized_events[0]

            # Add per-thread prev_batch token if present, so clients can
            # paginate backwards within the individual thread.
            if update.prev_batch is not None:
                update_dict["prev_batch"] = await update.prev_batch.to_string(
                    self.store
                )

            chunk.setdefault(room_id, {})[thread_id] = update_dict

        return chunk

    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request)

        # Only backward pagination is supported by this endpoint.
        dir_str = parse_string(request, "dir", default="b")
        if dir_str != "b":
            raise SynapseError(
                400,
                "The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.",
            )

        limit = parse_integer(request, "limit", default=100)
        if limit <= 0:
            raise SynapseError(400, "The 'limit' parameter must be positive.")

        from_token = await self._parse_room_stream_token(
            parse_string(request, "from"), "from"
        )
        to_token = await self._parse_room_stream_token(
            parse_string(request, "to"), "to"
        )

        # Fetch thread updates from storage
        # For backward pagination:
        # - 'from' (upper bound, exclusive) maps to 'to_token' (inclusive with <=)
        #   Since next_batch is (last_returned - 1), <= excludes the last returned item
        # - 'to' (lower bound, exclusive) maps to 'from_token' (exclusive with >)
        thread_updates, next_token = await self.store.get_thread_updates_for_user(
            user_id=requester.user.to_string(),
            from_token=to_token,
            to_token=from_token,
            limit=limit,
            include_thread_roots=True,
        )

        # Serialize response
        chunk: Dict[str, Dict[str, JsonDict]] = {}

        if thread_updates:
            # Get bundled aggregations for all thread roots
            thread_root_events = [
                update.thread_root_event
                for update in thread_updates
                if update.thread_root_event is not None
            ]

            bundled_aggregations = {}
            if thread_root_events:
                bundled_aggregations = (
                    await self.relations_handler.get_bundled_aggregations(
                        thread_root_events, requester.user.to_string()
                    )
                )

            # Set up serialization
            time_now = self.clock.time_msec()
            serialize_options = SerializeEventConfig(
                requester=requester,
            )

            # Serialize all thread updates
            chunk = await self._serialize_thread_updates(
                thread_updates=thread_updates,
                bundled_aggregations=bundled_aggregations,
                time_now=time_now,
                serialize_options=serialize_options,
            )

        # Build response
        response: JsonDict = {"chunk": chunk}

        # Add next_batch token for pagination; omitted when there are no
        # further updates to fetch.
        if next_token is not None:
            response["next_batch"] = await next_token.to_string(self.store)

        return 200, response


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    """Register the relations servlets, gating MSC4360 behind its config flag."""
    for servlet in (RelationPaginationServlet(hs), ThreadsServlet(hs)):
        servlet.register(http_server)
    if hs.config.experimental.msc4360_enabled:
        ThreadUpdatesServlet(hs).register(http_server)
54 changes: 32 additions & 22 deletions synapse/streams/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@
MAX_LIMIT = 1000


def extract_stream_token_from_pagination_token(token_str: str) -> str:
    """
    Extract the StreamToken portion from a pagination token string.

    Handles both:
    - StreamToken format: "s123_456_..." (returned unchanged)
    - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after
      the first "/")

    This allows clients using sliding sync to use their pos tokens
    with endpoints like /relations and /messages.

    Args:
        token_str: The token string to parse

    Returns:
        The StreamToken portion of the token
    """
    # partition splits on the first "/" only; if no "/" is present the
    # separator is empty and the original string is returned unchanged.
    # (Equivalent to split("/", 1) but without an unreachable length check.)
    _, sep, stream_token = token_str.partition("/")
    return stream_token if sep else token_str


@attr.s(slots=True, auto_attribs=True)
class PaginationConfig:
"""A configuration object which stores pagination parameters."""
Expand All @@ -58,40 +84,24 @@ async def from_request(
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")

# Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format
def extract_stream_token(token_str: str) -> str:
"""
Extract the StreamToken portion from a token string.

Handles both:
- StreamToken format: "s123_456_..."
- SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /)

This allows clients using sliding sync to use their pos tokens
with endpoints like /relations and /messages.
"""
if "/" in token_str:
# SlidingSyncStreamToken format: "connection_position/stream_token"
# Split and return just the stream_token part
parts = token_str.split("/", 1)
if len(parts) == 2:
return parts[1]
return token_str

try:
from_tok = None
if from_tok_str == "END":
from_tok = None # For backwards compat.
elif from_tok_str:
stream_token_str = extract_stream_token(from_tok_str)
stream_token_str = extract_stream_token_from_pagination_token(
from_tok_str
)
from_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")

try:
to_tok = None
if to_tok_str:
stream_token_str = extract_stream_token(to_tok_str)
stream_token_str = extract_stream_token_from_pagination_token(
to_tok_str
)
to_tok = await StreamToken.from_string(store, stream_token_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")
Expand Down
Loading
Loading