Add experimental support for MSC4360: Sliding Sync Threads Extension #19005
base: develop
Changes from all commits
@@ -0,0 +1 @@
Add experimental support for MSC4360: Sliding Sync Threads Extension.
@@ -109,8 +109,6 @@ async def get_relations(
    ) -> JsonDict:
        """Get related events of a event, ordered by topological ordering.

        TODO Accept a PaginationConfig instead of individual pagination parameters.
Review comment: This has been done already, the comment just wasn't updated.
        Args:
            requester: The user requesting the relations.
            event_id: Fetch events that relate to this event ID.
@@ -24,12 +24,13 @@
    Optional,
    Sequence,
    Set,
    Tuple,
    cast,
)

from typing_extensions import TypeAlias, assert_never

from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.api.constants import AccountDataTypes, EduTypes, RelationTypes
from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom

@@ -61,6 +62,7 @@
_ThreadUnsubscription: TypeAlias = (
    SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate

if TYPE_CHECKING:
    from synapse.server import HomeServer
@@ -76,7 +78,9 @@ def __init__(self, hs: "HomeServer"):
        self.event_sources = hs.get_event_sources()
        self.device_handler = hs.get_device_handler()
        self.push_rules_handler = hs.get_push_rules_handler()
        self.relations_handler = hs.get_relations_handler()
        self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
        self._enable_threads_ext = hs.config.experimental.msc4360_enabled

    @trace
    async def get_extensions_response(
@@ -177,20 +181,32 @@ async def get_extensions_response(
                from_token=from_token,
            )

        threads_coro = None
        if sync_config.extensions.threads is not None and self._enable_threads_ext:
            threads_coro = self.get_threads_extension_response(
                sync_config=sync_config,
                threads_request=sync_config.extensions.threads,
                actual_room_response_map=actual_room_response_map,
                to_token=to_token,
                from_token=from_token,
            )

        (
            to_device_response,
            e2ee_response,
            account_data_response,
            receipts_response,
            typing_response,
            thread_subs_response,
            threads_response,
        ) = await gather_optional_coroutines(
            to_device_coro,
            e2ee_coro,
            account_data_coro,
            receipts_coro,
            typing_coro,
            thread_subs_coro,
            threads_coro,
        )

        return SlidingSyncResult.Extensions(
@@ -200,6 +216,7 @@ async def get_extensions_response(
            receipts=receipts_response,
            typing=typing_response,
            thread_subscriptions=thread_subs_response,
            threads=threads_response,
        )

    def find_relevant_room_ids_for_extension(
@@ -970,3 +987,113 @@ async def get_thread_subscriptions_extension_response(
            unsubscribed=unsubscribed_threads,
            prev_batch=prev_batch,
        )
    async def get_threads_extension_response(
        self,
        sync_config: SlidingSyncConfig,
        threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
        actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
        to_token: StreamToken,
        from_token: Optional[SlidingSyncStreamToken],
    ) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]:
        """Handle Threads extension (MSC4360)

        Args:
            sync_config: Sync configuration.
            threads_request: The threads extension from the request.
            actual_room_response_map: A map of room ID to room results in the
                sliding sync response. Used to determine which threads already have
                events in the room timeline.
            to_token: The point in the stream to sync up to.
            from_token: The point in the stream to sync from.

        Returns:
            the response (None if empty or threads extension is disabled)
        """
        if not threads_request.enabled:
            return None

        # Fetch thread updates globally across all joined rooms.
        # The database layer returns a StreamToken (exclusive) for prev_batch if there
        # are more results.
        (
            all_thread_updates,
            prev_batch_token,
        ) = await self.store.get_thread_updates_for_user(
            user_id=sync_config.user.to_string(),
            from_token=from_token.stream_token.room_key if from_token else None,
            to_token=to_token.room_key,
            limit=threads_request.limit,
            include_thread_roots=threads_request.include_roots,
        )

        if len(all_thread_updates) == 0:
            return None

        # Identify which threads already have events in the room timelines.
        # If include_roots=False, we'll omit these threads from the extension response
        # since the client already sees the thread activity in the timeline.
        # If include_roots=True, we include all threads regardless, because the client
        # wants the thread root events.
        threads_in_timeline: Set[Tuple[str, str]] = set()  # (room_id, thread_id)
        if not threads_request.include_roots:
            for room_id, room_result in actual_room_response_map.items():
                if room_result.timeline_events:
                    for event in room_result.timeline_events:
                        # Check if this event is part of a thread
                        relates_to = event.content.get("m.relates_to")
Review comment: Probably should add these to (applies to fields below as well)
                        if not isinstance(relates_to, dict):
                            continue

                        rel_type = relates_to.get("rel_type")

                        # If this is a thread reply, track the thread
                        if rel_type == RelationTypes.THREAD:
                            thread_id = relates_to.get("event_id")
Review comment on lines +1043 to +1052: We don't have a helper for this kind of thing? Feels like something we'd be doing elsewhere already and could be wrapped up as a utility
                            if thread_id:
                                threads_in_timeline.add((room_id, thread_id))
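A rough sketch of what such a shared utility could look like, in reply to the comment above; the helper name `thread_root_id_for_event` and where it would live are hypothetical, not part of this PR:

```python
from typing import Optional

from synapse.api.constants import RelationTypes
from synapse.events import EventBase


def thread_root_id_for_event(event: EventBase) -> Optional[str]:
    """Return the thread root event ID if `event` is a threaded reply, else None.

    Mirrors the inline check above: an `m.relates_to` dict with
    `rel_type == m.thread` whose `event_id` points at the thread root.
    """
    relates_to = event.content.get("m.relates_to")
    if not isinstance(relates_to, dict):
        return None
    if relates_to.get("rel_type") != RelationTypes.THREAD:
        return None
    thread_id = relates_to.get("event_id")
    return thread_id if isinstance(thread_id, str) else None
```

With a helper like this, the loop above would only need to collect `(room_id, thread_root_id_for_event(event))` pairs for events where the helper returns a non-None ID.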
||
# Collect thread root events and get bundled aggregations. | ||
# Only fetch bundled aggregations if we have thread root events to attach them to. | ||
thread_root_events = [ | ||
update.thread_root_event | ||
for update in all_thread_updates | ||
if update.thread_root_event | ||
] | ||
aggregations_map = {} | ||
if thread_root_events: | ||
aggregations_map = await self.relations_handler.get_bundled_aggregations( | ||
thread_root_events, | ||
sync_config.user.to_string(), | ||
) | ||
|
||
thread_updates: Dict[str, Dict[str, _ThreadUpdate]] = {} | ||
for update in all_thread_updates: | ||
# Skip this thread if it already has events in the room timeline | ||
# (unless include_roots=True, in which case we always include it) | ||
if (update.room_id, update.thread_id) in threads_in_timeline: | ||
continue | ||
Review comment on lines +1072 to +1075: Can we structure things to avoid fetching bundled aggregations for
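One possible restructuring in the direction the comment suggests, sketched here rather than taken from the PR (`filtered_updates` is a hypothetical local name): apply the timeline filter first, then fetch bundled aggregations only for the thread roots that survive it.

```python
# Apply the timeline filter before doing any aggregation work.
filtered_updates = [
    update
    for update in all_thread_updates
    if (update.room_id, update.thread_id) not in threads_in_timeline
]

# Only fetch bundled aggregations for roots we will actually return.
thread_root_events = [
    update.thread_root_event for update in filtered_updates if update.thread_root_event
]
aggregations_map = {}
if thread_root_events:
    aggregations_map = await self.relations_handler.get_bundled_aggregations(
        thread_root_events,
        sync_config.user.to_string(),
    )
```

The later loop would then iterate over `filtered_updates` instead of re-checking `threads_in_timeline`.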
            # Only look up bundled aggregations if we have a thread root event
            bundled_aggs = (
                aggregations_map.get(update.thread_id)
                if update.thread_root_event
                else None
            )

            thread_updates.setdefault(update.room_id, {})[update.thread_id] = (
                _ThreadUpdate(
                    thread_root=update.thread_root_event,
                    prev_batch=update.prev_batch,
                    bundled_aggregations=bundled_aggs,
                )
            )
        # If after filtering we have no thread updates, return None to omit the extension
        if not thread_updates:
            return None

        return SlidingSyncResult.Extensions.ThreadsExtension(
            updates=thread_updates,
            prev_batch=prev_batch_token,
        )
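For orientation, a sketch of a sliding sync request body that turns this extension on. The field names (`enabled`, `limit`, `include_roots`) come from the handler above; the extension key is assumed to match the unstable-prefixed key used by the response serializer below, and MSC4360 may still change the wire format.

```python
# Hypothetical request body for the simplified sliding sync endpoint
# (POST /_matrix/client/unstable/org.matrix.simplified_msc3575/sync).
sync_body = {
    "lists": {},
    "extensions": {
        "io.element.msc4360.threads": {
            "enabled": True,
            "limit": 10,            # cap on the number of thread updates returned
            "include_roots": True,  # also return serialized thread root events
        }
    },
}
```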
@@ -31,6 +31,7 @@
from synapse.api.presence import UserPresenceState
from synapse.api.ratelimiting import Ratelimiter
from synapse.events.utils import (
    EventClientSerializer,
    SerializeEventConfig,
    format_event_for_client_v2_without_room_id,
    format_event_raw,
@@ -56,6 +57,7 @@
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util.caches.lrucache import LruCache
@@ -648,6 +650,7 @@ class SlidingSyncRestServlet(RestServlet):
        - receipts (MSC3960)
        - account data (MSC3959)
        - thread subscriptions (MSC4308)
        - threads (MSC4360)

    Request query parameters:
        timeout: How long to wait for new events in milliseconds.
@@ -851,7 +854,10 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
            logger.info("Client has disconnected; not serializing response.")
            return 200, {}

        response_content = await self.encode_response(requester, sliding_sync_results)
        time_now = self.clock.time_msec()
        response_content = await self.encode_response(
            requester, sliding_sync_results, time_now
        )
Review comment on lines +857 to +860: Good call on using the same (consistency)
        return 200, response_content
@@ -860,6 +866,7 @@ async def encode_response(
        self,
        requester: Requester,
        sliding_sync_result: SlidingSyncResult,
        time_now: int,
    ) -> JsonDict:
        response: JsonDict = defaultdict(dict)
@@ -868,10 +875,10 @@ async def encode_response(
        if serialized_lists:
            response["lists"] = serialized_lists
        response["rooms"] = await self.encode_rooms(
            requester, sliding_sync_result.rooms
            requester, sliding_sync_result.rooms, time_now
        )
        response["extensions"] = await self.encode_extensions(
            requester, sliding_sync_result.extensions
            requester, sliding_sync_result.extensions, time_now
        )

        return response
@@ -903,9 +910,8 @@ async def encode_rooms(
        self,
        requester: Requester,
        rooms: Dict[str, SlidingSyncResult.RoomResult],
        time_now: int,
    ) -> JsonDict:
        time_now = self.clock.time_msec()

        serialize_options = SerializeEventConfig(
            event_format=format_event_for_client_v2_without_room_id,
            requester=requester,
@@ -1021,7 +1027,10 @@ async def encode_rooms(

    @trace_with_opname("sliding_sync.encode_extensions")
    async def encode_extensions(
        self, requester: Requester, extensions: SlidingSyncResult.Extensions
        self,
        requester: Requester,
        extensions: SlidingSyncResult.Extensions,
        time_now: int,
    ) -> JsonDict:
        serialized_extensions: JsonDict = {}
@@ -1091,6 +1100,17 @@ async def encode_extensions(
                _serialise_thread_subscriptions(extensions.thread_subscriptions)
            )

        # excludes both None and falsy `threads`
        if extensions.threads:
            serialized_extensions[
                "io.element.msc4360.threads"
            ] = await _serialise_threads(
                self.event_serializer,
                time_now,
                extensions.threads,
                self.store,
            )

        return serialized_extensions
|
||
|
||
|
@@ -1127,6 +1147,72 @@ def _serialise_thread_subscriptions( | |
return out | ||
|
||
|
||
async def _serialise_threads(
    event_serializer: EventClientSerializer,
    time_now: int,
    threads: SlidingSyncResult.Extensions.ThreadsExtension,
    store: "DataStore",
) -> JsonDict:
    """
    Serialize the threads extension response for sliding sync.

    Args:
        event_serializer: The event serializer to use for serializing thread root events.
        time_now: The current time in milliseconds, used for event serialization.
        threads: The threads extension data containing thread updates and pagination tokens.
        store: The datastore, needed for serializing stream tokens.

    Returns:
        A JSON-serializable dict containing:
        - "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
          Each thread update may contain:
          - "thread_root": The serialized thread root event (if include_roots was True),
            with bundled aggregations including the latest_event in unsigned.m.relations.m.thread.
          - "prev_batch": A pagination token for fetching older events in the thread.
        - "prev_batch": A pagination token for fetching older thread updates (if available).
    """
    out: JsonDict = {}

    if threads.updates:
        updates_dict: JsonDict = {}
        for room_id, thread_updates in threads.updates.items():
            room_updates: JsonDict = {}
            for thread_root_id, update in thread_updates.items():
                # Serialize the update
                update_dict: JsonDict = {}
                # Serialize the thread_root event if present
                if update.thread_root is not None:
                    # Create a mapping of event_id to bundled_aggregations
                    bundle_aggs_map = (
                        {thread_root_id: update.bundled_aggregations}
                        if update.bundled_aggregations
                        else None
                    )
                    serialized_events = await event_serializer.serialize_events(
                        [update.thread_root],
                        time_now,
                        bundle_aggregations=bundle_aggs_map,
                    )
                    if serialized_events:
                        update_dict["thread_root"] = serialized_events[0]
                # Add prev_batch if present
                if update.prev_batch is not None:
                    update_dict["prev_batch"] = await update.prev_batch.to_string(store)

                room_updates[thread_root_id] = update_dict

            updates_dict[room_id] = room_updates

        out["updates"] = updates_dict

    if threads.prev_batch:
        out["prev_batch"] = await threads.prev_batch.to_string(store)

    return out
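To make the docstring above concrete, a hedged example of the serialized extension as it would appear under `extensions` in the sliding sync response; all room/event IDs, token strings, and content values are invented placeholders.

```python
# Illustrative output of _serialise_threads; the shape follows the docstring above.
example_threads_extension = {
    "updates": {
        "!room:example.org": {
            "$thread_root": {
                # Present only when the request asked for thread roots.
                "thread_root": {
                    "event_id": "$thread_root",
                    "type": "m.room.message",
                    "content": {"msgtype": "m.text", "body": "thread root"},
                    "unsigned": {
                        "m.relations": {
                            "m.thread": {
                                # Bundled aggregation carrying the latest thread event.
                                "latest_event": {"event_id": "$latest_in_thread"},
                                "count": 3,
                            }
                        }
                    },
                },
                # Token for paginating older events within this thread.
                "prev_batch": "t100-200",
            }
        }
    },
    # Only present if there are older thread updates to paginate back through.
    "prev_batch": "t50-100",
}
```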

def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    SyncRestServlet(hs).register(http_server)
Review comment: There are also comments on the MSC that may change how we approach things here -> matrix-org/matrix-spec-proposals#4360 (review)