diff --git a/changelog.d/18581.feature b/changelog.d/18581.feature
new file mode 100644
index 00000000000..f74a80ea3ed
--- /dev/null
+++ b/changelog.d/18581.feature
@@ -0,0 +1 @@
+Add support for handling device list updates on workers by configuring one or more `device_lists` stream writers.
diff --git a/docker/complement/conf/start_for_complement.sh b/docker/complement/conf/start_for_complement.sh
index a5e06396e26..11f46618426 100755
--- a/docker/complement/conf/start_for_complement.sh
+++ b/docker/complement/conf/start_for_complement.sh
@@ -65,6 +65,7 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then
     client_reader, \
     appservice, \
     pusher, \
+    device_lists:2, \
     stream_writers=account_data+presence+receipts+to_device+typing"
 fi
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index 102a88fad14..c928f18e5a2 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -306,6 +306,15 @@
         "shared_extra_conf": {},
         "worker_extra_conf": "",
     },
+    "device_lists": {
+        "app": "synapse.app.generic_worker",
+        "listener_resources": ["client", "replication"],
+        "endpoint_patterns": [
+            "^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$"
+        ],
+        "shared_extra_conf": {},
+        "worker_extra_conf": "",
+    },
     "typing": {
         "app": "synapse.app.generic_worker",
         "listener_resources": ["client", "replication"],
@@ -412,16 +421,17 @@ def add_worker_roles_to_shared_config(
     # streams
     instance_map = shared_config.setdefault("instance_map", {})
 
-    # This is a list of the stream_writers that there can be only one of. Events can be
-    # sharded, and therefore doesn't belong here.
-    singular_stream_writers = [
+    # The set of streams that have stream writers; a stream may be sharded across several workers.
+    stream_writers = {
         "account_data",
+        "events",
+        "device_lists",
         "presence",
         "receipts",
         "to_device",
         "typing",
         "push_rules",
-    ]
+    }
 
     # Worker-type specific sharding config. Now a single worker can fulfill multiple
     # roles, check each.
@@ -434,25 +444,13 @@ def add_worker_roles_to_shared_config(
     if "event_persister" in worker_types_set:
         # Event persisters write to the events stream, so we need to update
         # the list of event stream writers
-        shared_config.setdefault("stream_writers", {}).setdefault("events", []).append(
-            worker_name
-        )
+        worker_types_set.add("events")
 
-    # Map of stream writer instance names to host/ports combos
-    if os.environ.get("SYNAPSE_USE_UNIX_SOCKET", False):
-        instance_map[worker_name] = {
-            "path": f"/run/worker.{worker_port}",
-        }
-    else:
-        instance_map[worker_name] = {
-            "host": "localhost",
-            "port": worker_port,
-        }
     # Update the list of stream writers. It's convenient that the name of the worker
     # type is the same as the stream to write. Iterate over the whole list in case there
     # is more than one.
     for worker in worker_types_set:
-        if worker in singular_stream_writers:
+        if worker in stream_writers:
             shared_config.setdefault("stream_writers", {}).setdefault(
                 worker, []
             ).append(worker_name)
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 257ea4a1a25..3e2858a007a 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -4291,6 +4291,8 @@ This setting has the following sub-options:
 
 * `push_rules` (string): Name of a worker assigned to the `push_rules` stream.
 
+* `device_lists` (string): Name of a worker assigned to the `device_lists` stream.
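For reference, a minimal `homeserver.yaml` sketch of the option documented above (the worker names and ports are illustrative, not taken from this PR):

```yaml
# Illustrative only: shard the device_lists stream across two workers.
stream_writers:
  device_lists:
    - device_lists1
    - device_lists2

# Each stream writer must be reachable over replication.
instance_map:
  device_lists1:
    host: localhost
    port: 8034
  device_lists2:
    host: localhost
    port: 8035
```

The first entry in the list acts as the "main" device list writer for the operations that need a single instance (see the `synapse/handlers/device.py` changes below).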
+ Example configuration: ```yaml stream_writers: diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index e4dcd36c305..4973ef626f0 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -5323,6 +5323,9 @@ properties: push_rules: type: string description: Name of a worker assigned to the `push_rules` stream. + device_lists: + type: string + description: Name of a worker assigned to the `device_lists` stream. default: {} examples: - events: worker1 diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 2dfeb47c2ee..94cdeb8c949 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -158,12 +158,12 @@ class WriterLocations: can only be a single instance. account_data: The instances that write to the account data streams. Currently can only be a single instance. - receipts: The instances that write to the receipts stream. Currently - can only be a single instance. + receipts: The instances that write to the receipts stream. presence: The instances that write to the presence stream. Currently can only be a single instance. push_rules: The instances that write to the push stream. Currently can only be a single instance. + device_lists: The instances that write to the device list stream. """ events: List[str] = attr.ib( @@ -194,6 +194,10 @@ class WriterLocations: default=["master"], converter=_instance_to_list_converter, ) + device_lists: List[str] = attr.ib( + default=["master"], + converter=_instance_to_list_converter, + ) @attr.s(auto_attribs=True) @@ -415,6 +419,11 @@ def read_config( "Must only specify one instance to handle `push` messages." ) + if len(self.writers.device_lists) == 0: + raise ConfigError( + "Must specify at least one instance to handle `device_lists` messages." + ) + self.events_shard_config = RoutableShardedWorkerHandlingConfig( self.writers.events ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index f3bbdb5a05b..0cfc49bced9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -631,7 +631,8 @@ async def _get_device_list_summary( # Fetch the users who have modified their device list since then. users_with_changed_device_lists = await self.store.get_all_devices_changed( - from_key, to_key=new_key + MultiWriterStreamToken(stream=from_key), + to_key=MultiWriterStreamToken(stream=new_key), ) # Filter out any users the application service is not interested in diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 42e53d920ab..8d4d84bed10 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -24,7 +24,6 @@ from synapse.api.constants import Membership from synapse.api.errors import SynapseError -from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -84,10 +83,6 @@ async def deactivate_account( Returns: True if identity server supports removing threepids, otherwise False. """ - - # This can only be called on the main process. 
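The new `device_lists` field reuses the same attrs converter as the neighbouring writer options, so it accepts either a single worker name or a list of names. A rough sketch of the converter semantics this relies on (the real helper is defined earlier in `synapse/config/workers.py` and is not shown in this hunk):

```python
from typing import List, Union

def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
    # Normalise a `str | list[str]` config value to a list of instance names,
    # so `device_lists: worker1` and `device_lists: [worker1]` are equivalent.
    if isinstance(obj, str):
        return [obj]
    return obj

assert _instance_to_list_converter("worker1") == ["worker1"]
assert _instance_to_list_converter(["worker1", "worker2"]) == ["worker1", "worker2"]
```

Combined with the `default=["master"]` above, this is why the new `len(...) == 0` check is the only validation needed: an unset option still yields a one-element list.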
- assert isinstance(self._device_handler, DeviceHandler) - # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8f9bf92fda3..97a9ffe84e6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -20,6 +20,7 @@ # # import logging +import random from threading import Lock from typing import ( TYPE_CHECKING, @@ -31,6 +32,7 @@ Optional, Set, Tuple, + cast, ) from synapse.api import errors @@ -48,6 +50,13 @@ run_as_background_process, wrap_as_background_process, ) +from synapse.replication.http.devices import ( + ReplicationDeviceHandleRoomUnPartialStated, + ReplicationHandleNewDeviceUpdateRestServlet, + ReplicationMultiUserDevicesResyncRestServlet, + ReplicationNotifyDeviceUpdateRestServlet, + ReplicationNotifyUserSignatureUpdateRestServlet, +) from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.databases.main.state_deltas import StateDelta @@ -75,6 +84,8 @@ ) if TYPE_CHECKING: + from synapse.app.generic_worker import GenericWorkerStore + from synapse.app.homeserver import DataStore from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -84,18 +95,38 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 +def _check_device_name_length(name: Optional[str]) -> None: + """ + Checks whether a device name is longer than the maximum allowed length. + + Args: + name: The name of the device. + + Raises: + SynapseError: if the device name is too long. + """ + if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" % (MAX_DEVICE_DISPLAY_NAME_LEN,), + errcode=Codes.TOO_LARGE, + ) + + class DeviceWorkerHandler: device_list_updater: "DeviceListWorkerUpdater" + store: "GenericWorkerStore" def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs - self.store = hs.get_datastores().main + self.store = cast("GenericWorkerStore", hs.get_datastores().main) self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self._appservice_handler = hs.get_application_service_handler() self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() + self._account_data_handler = hs.get_account_data_handler() self._event_sources = hs.get_event_sources() self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -104,12 +135,233 @@ def __init__(self, hs: "HomeServer"): ) self._task_scheduler = hs.get_task_scheduler() + self._dont_notify_new_devices_for = ( + hs.config.registration.dont_notify_new_devices_for + ) + self.device_list_updater = DeviceListWorkerUpdater(hs) self._task_scheduler.register_action( self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME ) + self._device_list_writers = hs.config.worker.writers.device_lists + + # There are a few things we only handle on a single writer because of + # linearizers in the DeviceListUpdater + self._main_device_list_writer = hs.config.worker.writers.device_lists[0] + + self._notify_device_update_client = ( + ReplicationNotifyDeviceUpdateRestServlet.make_client(hs) + ) + self._notify_user_signature_update_client = ( + ReplicationNotifyUserSignatureUpdateRestServlet.make_client(hs) + ) + self._handle_new_device_update_client = ( + ReplicationHandleNewDeviceUpdateRestServlet.make_client(hs) + ) + 
self._handle_room_un_partial_stated_client = ( + ReplicationDeviceHandleRoomUnPartialStated.make_client(hs) + ) + + hs.get_federation_registry().register_instances_for_edu( + EduTypes.DEVICE_LIST_UPDATE, + [self._main_device_list_writer], + ) + + async def check_device_registered( + self, + user_id: str, + device_id: Optional[str], + initial_device_display_name: Optional[str] = None, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id: @user:id + device_id: device id supplied by client + initial_device_display_name: device display name from client + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from the SSO IdP. + Returns: + device id (generated if none was supplied) + """ + + _check_device_name_length(initial_device_display_name) + + # Check if we should send out device lists updates for this new device. + notify = user_id not in self._dont_notify_new_devices_for + + if device_id is not None: + new_device = await self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + if new_device: + if notify: + await self.notify_device_update(user_id, [device_id]) + return device_id + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + new_device_id = stringutils.random_string(10).upper() + new_device = await self.store.store_device( + user_id=user_id, + device_id=new_device_id, + initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + if new_device: + if notify: + await self.notify_device_update(user_id, [new_device_id]) + return new_device_id + attempts += 1 + + raise errors.StoreError(500, "Couldn't generate a device ID.") + + @trace + async def delete_all_devices_for_user( + self, user_id: str, except_device_id: Optional[str] = None + ) -> None: + """Delete all of the user's devices + + Args: + user_id: The user to remove all devices from + except_device_id: optional device id which should not be deleted + """ + device_map = await self.store.get_devices_by_user(user_id) + device_ids = list(device_map) + if except_device_id is not None: + device_ids = [d for d in device_ids if d != except_device_id] + await self.delete_devices(user_id, device_ids) + + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + """Delete several devices + + Args: + user_id: The user to delete devices from. + device_ids: The list of device IDs to delete + """ + to_device_stream_id = self._event_sources.get_current_token().to_device_key + + try: + await self.store.delete_devices(user_id, device_ids) + except errors.StoreError as e: + if e.code == 404: + # no match + set_tag("error", True) + set_tag("reason", "User doesn't have that device id.") + else: + raise + + # Delete data specific to each device. Not optimised as it is not + # considered as part of a critical path. 
+        for device_id in device_ids:
+            await self._auth_handler.delete_access_tokens_for_user(
+                user_id, device_id=device_id
+            )
+            await self.store.delete_e2e_keys_by_device(
+                user_id=user_id, device_id=device_id
+            )
+
+            if self.hs.config.experimental.msc3890_enabled:
+                # Remove any local notification settings for this device in accordance
+                # with MSC3890.
+                await self._account_data_handler.remove_account_data_for_user(
+                    user_id,
+                    f"org.matrix.msc3890.local_notification_settings.{device_id}",
+                )
+
+            # Delete device messages asynchronously and in batches using the task scheduler.
+            # We specify an upper stream id to avoid deleting undelivered messages
+            # if a user re-uses a device ID.
+            await self._task_scheduler.schedule_task(
+                DELETE_DEVICE_MSGS_TASK_NAME,
+                resource_id=device_id,
+                params={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "up_to_stream_id": to_device_stream_id,
+                },
+            )
+
+        # Pushers are deleted after `delete_access_tokens_for_user` is called so that
+        # modules using `on_logged_out` hook can use them if needed.
+        await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
+
+        await self.notify_device_update(user_id, device_ids)
+
+    async def upsert_device(
+        self, user_id: str, device_id: str, display_name: Optional[str] = None
+    ) -> bool:
+        """Create or update a device
+
+        Args:
+            user_id: The user to update devices of.
+            device_id: The device to update.
+            display_name: The new display name for this device.
+
+        Returns:
+            True if the device was created, False if it was updated.
+
+        """
+
+        # Reject a new displayname which is too long.
+        _check_device_name_length(display_name)
+
+        created = await self.store.store_device(
+            user_id,
+            device_id,
+            initial_device_display_name=display_name,
+        )
+
+        if not created:
+            await self.store.update_device(
+                user_id,
+                device_id,
+                new_display_name=display_name,
+            )
+
+        await self.notify_device_update(user_id, [device_id])
+        return created
+
+    async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
+        """Update the given device
+
+        Args:
+            user_id: The user to update devices of.
+            device_id: The device to update.
+            content: body of update request
+        """
+
+        # Reject a new displayname which is too long.
+        new_display_name = content.get("display_name")
+
+        _check_device_name_length(new_display_name)
+
+        try:
+            await self.store.update_device(
+                user_id, device_id, new_display_name=new_display_name
+            )
+            await self.notify_device_update(user_id, [device_id])
+        except errors.StoreError as e:
+            if e.code == 404:
+                raise errors.NotFoundError()
+            else:
+                raise
+
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
@@ -139,12 +391,108 @@ async def get_dehydrated_device(
         """Retrieve the information for a dehydrated device.
 
         Args:
-            user_id: the user whose dehydrated device we are looking for
-        Returns:
-            a tuple whose first item is the device ID, and the second item is
-            the dehydrated device information
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
+    async def store_dehydrated_device(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        device_data: JsonDict,
+        initial_device_display_name: Optional[str] = None,
+        keys_for_device: Optional[JsonDict] = None,
+    ) -> str:
+        """Store a dehydrated device for a user, optionally storing the keys associated with
+        it as well.
If the user had a previous dehydrated device, it is removed. + + Args: + user_id: the user that we are storing the device for + device_id: device id supplied by client + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + keys_for_device: keys for the dehydrated device + Returns: + device id of the dehydrated device + """ + device_id = await self.check_device_registered( + user_id, + device_id, + initial_device_display_name, + ) + + time_now = self.clock.time_msec() + + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data, time_now, keys_for_device + ) + + if old_device_id is not None: + await self.delete_devices(user_id, [old_device_id]) + + return device_id + + async def rehydrate_device( + self, user_id: str, access_token: str, device_id: str + ) -> dict: + """Process a rehydration request from the user. + + Args: + user_id: the user who is rehydrating the device + access_token: the access token used for the request + device_id: the ID of the device that will be rehydrated + Returns: + a dict containing {"success": True} + """ + success = await self.store.remove_dehydrated_device(user_id, device_id) + + if not success: + raise errors.NotFoundError() + + # If the dehydrated device was successfully deleted (the device ID + # matched the stored dehydrated device), then modify the access + # token and refresh token to use the dehydrated device's ID and + # copy the old device display name to the dehydrated device, + # and destroy the old device ID + old_device_id = await self.store.set_device_for_access_token( + access_token, device_id + ) + await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id) + old_device = await self.store.get_device(user_id, old_device_id) + if old_device is None: + raise errors.NotFoundError() + await self.store.update_device(user_id, device_id, old_device["display_name"]) + # can't call self.delete_device because that will clobber the + # access token so call the storage layer directly + await self.store.delete_devices(user_id, [old_device_id]) + await self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=old_device_id + ) + + # tell everyone that the old device is gone and that the dehydrated + # device has a new display name + await self.notify_device_update(user_id, [old_device_id, device_id]) + + return {"success": True} + + async def delete_dehydrated_device(self, user_id: str, device_id: str) -> None: + """ + Delete a stored dehydrated device. + + Args: + user_id: the user_id to delete the device from + device_id: id of the dehydrated device to delete """ - return await self.store.get_dehydrated_device(user_id) + success = await self.store.remove_dehydrated_device(user_id, device_id) + + if not success: + raise errors.NotFoundError() + + await self.delete_devices(user_id, [device_id]) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: @@ -484,10 +832,52 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None: gone from partial to full state. """ - # TODO(faster_joins): worker mode support - # https://github.com/matrix-org/synapse/issues/12994 - logger.error( - "Trying handling device list state for partial join: not supported on workers." 
+        await self._handle_room_un_partial_stated_client(
+            instance_name=random.choice(self._device_list_writers),
+            room_id=room_id,
+        )
+
+    @trace
+    @measure_func("notify_device_update")
+    async def notify_device_update(
+        self, user_id: str, device_ids: StrCollection
+    ) -> None:
+        """Notify that a user's device(s) has changed. Pokes the notifier, and
+        remote servers if the user is local.
+
+        Args:
+            user_id: The Matrix ID of the user whose device list has been updated.
+            device_ids: The device IDs that have changed.
+        """
+        await self._notify_device_update_client(
+            instance_name=random.choice(self._device_list_writers),
+            user_id=user_id,
+            device_ids=device_ids,
+        )
+
+    async def notify_user_signature_update(
+        self,
+        from_user_id: str,
+        user_ids: List[str],
+    ) -> None:
+        """Notify a device writer that a user has made new signatures of other users.
+
+        Args:
+            from_user_id: The Matrix ID of the user whose signatures have been updated.
+            user_ids: The Matrix IDs of the users that have changed.
+        """
+        await self._notify_user_signature_update_client(
+            instance_name=random.choice(self._device_list_writers),
+            from_user_id=from_user_id,
+            user_ids=user_ids,
+        )
+
+    async def handle_new_device_update(self) -> None:
+        """Wake up a device writer to send local device list changes as federation outbound pokes."""
+        # This is only sent to the first device writer to avoid cross-worker
+        # locks in _handle_new_device_update_async.
+        await self._handle_new_device_update_client(
+            instance_name=self._device_list_writers[0],
         )
 
 
 DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
@@ -521,27 +911,25 @@ async def _delete_device_messages(
 
 
 class DeviceHandler(DeviceWorkerHandler):
-    device_list_updater: "DeviceListUpdater"
+    store: "DataStore"  # type: ignore[assignment]
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.federation_sender = hs.get_federation_sender()
-        self._account_data_handler = hs.get_account_data_handler()
-        self._storage_controllers = hs.get_storage_controllers()
-        self.db_pool = hs.get_datastores().main.db_pool
-
-        self._dont_notify_new_devices_for = (
-            hs.config.registration.dont_notify_new_devices_for
-        )
-
-        self.device_list_updater = DeviceListUpdater(hs, self)
+        # We only need to poke the federation sender explicitly if it's on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the device lists stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
 
-        federation_registry = hs.get_federation_registry()
+        self._storage_controllers = hs.get_storage_controllers()
 
-        federation_registry.register_edu_handler(
-            EduTypes.DEVICE_LIST_UPDATE,
-            self.device_list_updater.incoming_device_list_update,
+        # There are a few things that are only handled on the main device list
+        # writer to avoid cross-worker locks
+        self._is_main_device_list_writer = (
+            hs.get_instance_name() == self._main_device_list_writer
         )
 
         # Whether `_handle_new_device_update_async` is currently processing.
@@ -551,15 +939,21 @@ def __init__(self, hs: "HomeServer"):
         # processing.
         self._handle_new_device_update_new_data = False
 
-        # On start up check if there are any updates pending.
-        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+        if self._is_main_device_list_writer:
+            # On start up check if there are any updates pending.
+ hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) + self.device_list_updater = DeviceListUpdater(hs, self) + hs.get_federation_registry().register_edu_handler( + EduTypes.DEVICE_LIST_UPDATE, + self.device_list_updater.incoming_device_list_update, + ) self._delete_stale_devices_after = hs.config.server.delete_stale_devices_after - # Ideally we would run this on a worker and condition this on the - # "run_background_tasks_on" setting, but this would mean making the notification - # of device list changes over federation work on workers, which is nontrivial. - if self._delete_stale_devices_after is not None: + if ( + hs.config.worker.run_background_tasks + and self._delete_stale_devices_after is not None + ): self.clock.looping_call( run_as_background_process, DELETE_STALE_DEVICES_INTERVAL_MS, @@ -567,229 +961,18 @@ def __init__(self, hs: "HomeServer"): self._delete_stale_devices, ) - def _check_device_name_length(self, name: Optional[str]) -> None: - """ - Checks whether a device name is longer than the maximum allowed length. - - Args: - name: The name of the device. - - Raises: - SynapseError: if the device name is too long. - """ - if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN: - raise SynapseError( - 400, - "Device display name is too long (max %i)" - % (MAX_DEVICE_DISPLAY_NAME_LEN,), - errcode=Codes.TOO_LARGE, - ) - - async def check_device_registered( - self, - user_id: str, - device_id: Optional[str], - initial_device_display_name: Optional[str] = None, - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, - ) -> str: - """ - If the given device has not been registered, register it with the - supplied display name. - - If no device_id is supplied, we make one up. - - Args: - user_id: @user:id - device_id: device id supplied by client - initial_device_display_name: device display name from client - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from the SSO IdP. - Returns: - device id (generated if none was supplied) - """ - - self._check_device_name_length(initial_device_display_name) - - # Check if we should send out device lists updates for this new device. - notify = user_id not in self._dont_notify_new_devices_for - - if device_id is not None: - new_device = await self.store.store_device( - user_id=user_id, - device_id=device_id, - initial_device_display_name=initial_device_display_name, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - if new_device: - if notify: - await self.notify_device_update(user_id, [device_id]) - return device_id - - # if the device id is not specified, we'll autogen one, but loop a few - # times in case of a clash. - attempts = 0 - while attempts < 5: - new_device_id = stringutils.random_string(10).upper() - new_device = await self.store.store_device( - user_id=user_id, - device_id=new_device_id, - initial_device_display_name=initial_device_display_name, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - if new_device: - if notify: - await self.notify_device_update(user_id, [new_device_id]) - return new_device_id - attempts += 1 - - raise errors.StoreError(500, "Couldn't generate a device ID.") - async def _delete_stale_devices(self) -> None: """Background task that deletes devices which haven't been accessed for more than a configured time period. """ - # We should only be running this job if the config option is defined. 
- assert self._delete_stale_devices_after is not None - now_ms = self.clock.time_msec() - since_ms = now_ms - self._delete_stale_devices_after - devices = await self.store.get_local_devices_not_accessed_since(since_ms) - - for user_id, user_devices in devices.items(): - await self.delete_devices(user_id, user_devices) - - @trace - async def delete_all_devices_for_user( - self, user_id: str, except_device_id: Optional[str] = None - ) -> None: - """Delete all of the user's devices - - Args: - user_id: The user to remove all devices from - except_device_id: optional device id which should not be deleted - """ - device_map = await self.store.get_devices_by_user(user_id) - device_ids = list(device_map) - if except_device_id is not None: - device_ids = [d for d in device_ids if d != except_device_id] - await self.delete_devices(user_id, device_ids) - - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: - """Delete several devices - - Args: - user_id: The user to delete devices from. - device_ids: The list of device IDs to delete - """ - to_device_stream_id = self._event_sources.get_current_token().to_device_key - - try: - await self.store.delete_devices(user_id, device_ids) - except errors.StoreError as e: - if e.code == 404: - # no match - set_tag("error", True) - set_tag("reason", "User doesn't have that device id.") - else: - raise - - # Delete data specific to each device. Not optimised as it is not - # considered as part of a critical path. - for device_id in device_ids: - await self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) - await self.store.delete_e2e_keys_by_device( - user_id=user_id, device_id=device_id - ) - - if self.hs.config.experimental.msc3890_enabled: - # Remove any local notification settings for this device in accordance - # with MSC3890. - await self._account_data_handler.remove_account_data_for_user( - user_id, - f"org.matrix.msc3890.local_notification_settings.{device_id}", - ) - - # Delete device messages asynchronously and in batches using the task scheduler - # We specify an upper stream id to avoid deleting non delivered messages - # if an user re-uses a device ID. - await self._task_scheduler.schedule_task( - DELETE_DEVICE_MSGS_TASK_NAME, - resource_id=device_id, - params={ - "user_id": user_id, - "device_id": device_id, - "up_to_stream_id": to_device_stream_id, - }, - ) - - # Pushers are deleted after `delete_access_tokens_for_user` is called so that - # modules using `on_logged_out` hook can use them if needed. - await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) - - await self.notify_device_update(user_id, device_ids) - - async def upsert_device( - self, user_id: str, device_id: str, display_name: Optional[str] = None - ) -> bool: - """Create or update a device - - Args: - user_id: The user to update devices of. - device_id: The device to update. - display_name: The new display name for this device. - - Returns: - True if the device was created, False if it was updated. - - """ - - # Reject a new displayname which is too long. 
- self._check_device_name_length(display_name) - - created = await self.store.store_device( - user_id, - device_id, - initial_device_display_name=display_name, - ) - - if not created: - await self.store.update_device( - user_id, - device_id, - new_display_name=display_name, - ) - - await self.notify_device_update(user_id, [device_id]) - return created - - async def update_device(self, user_id: str, device_id: str, content: dict) -> None: - """Update the given device - - Args: - user_id: The user to update devices of. - device_id: The device to update. - content: body of update request - """ - - # Reject a new displayname which is too long. - new_display_name = content.get("display_name") - - self._check_device_name_length(new_display_name) - - try: - await self.store.update_device( - user_id, device_id, new_display_name=new_display_name - ) - await self.notify_device_update(user_id, [device_id]) - except errors.StoreError as e: - if e.code == 404: - raise errors.NotFoundError() - else: - raise + # We should only be running this job if the config option is defined. + assert self._delete_stale_devices_after is not None + now_ms = self.clock.time_msec() + since_ms = now_ms - self._delete_stale_devices_after + devices = await self.store.get_local_devices_not_accessed_since(since_ms) + + for user_id, user_devices in devices.items(): + await self.delete_devices(user_id, user_devices) @trace @measure_func("notify_device_update") @@ -832,7 +1015,7 @@ async def notify_device_update( # We may need to do some processing asynchronously for local user IDs. if self.hs.is_mine_id(user_id): - self._handle_new_device_update_async() + await self.handle_new_device_update() async def notify_user_signature_update( self, from_user_id: str, user_ids: List[str] @@ -852,101 +1035,12 @@ async def notify_user_signature_update( StreamKeyType.DEVICE_LIST, position, users=[from_user_id] ) - async def store_dehydrated_device( - self, - user_id: str, - device_id: Optional[str], - device_data: JsonDict, - initial_device_display_name: Optional[str] = None, - keys_for_device: Optional[JsonDict] = None, - ) -> str: - """Store a dehydrated device for a user, optionally storing the keys associated with - it as well. If the user had a previous dehydrated device, it is removed. - - Args: - user_id: the user that we are storing the device for - device_id: device id supplied by client - device_data: the dehydrated device information - initial_device_display_name: The display name to use for the device - keys_for_device: keys for the dehydrated device - Returns: - device id of the dehydrated device - """ - device_id = await self.check_device_registered( - user_id, - device_id, - initial_device_display_name, - ) - - time_now = self.clock.time_msec() - - old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data, time_now, keys_for_device - ) - - if old_device_id is not None: - await self.delete_devices(user_id, [old_device_id]) - - return device_id - - async def rehydrate_device( - self, user_id: str, access_token: str, device_id: str - ) -> dict: - """Process a rehydration request from the user. 
- - Args: - user_id: the user who is rehydrating the device - access_token: the access token used for the request - device_id: the ID of the device that will be rehydrated - Returns: - a dict containing {"success": True} - """ - success = await self.store.remove_dehydrated_device(user_id, device_id) - - if not success: - raise errors.NotFoundError() - - # If the dehydrated device was successfully deleted (the device ID - # matched the stored dehydrated device), then modify the access - # token and refresh token to use the dehydrated device's ID and - # copy the old device display name to the dehydrated device, - # and destroy the old device ID - old_device_id = await self.store.set_device_for_access_token( - access_token, device_id - ) - await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id) - old_device = await self.store.get_device(user_id, old_device_id) - if old_device is None: - raise errors.NotFoundError() - await self.store.update_device(user_id, device_id, old_device["display_name"]) - # can't call self.delete_device because that will clobber the - # access token so call the storage layer directly - await self.store.delete_devices(user_id, [old_device_id]) - await self.store.delete_e2e_keys_by_device( - user_id=user_id, device_id=old_device_id - ) - - # tell everyone that the old device is gone and that the dehydrated - # device has a new display name - await self.notify_device_update(user_id, [old_device_id, device_id]) - - return {"success": True} - - async def delete_dehydrated_device(self, user_id: str, device_id: str) -> None: - """ - Delete a stored dehydrated device. - - Args: - user_id: the user_id to delete the device from - device_id: id of the dehydrated device to delete - """ - success = await self.store.remove_dehydrated_device(user_id, device_id) - - if not success: - raise errors.NotFoundError() + async def handle_new_device_update(self) -> None: + if not self._is_main_device_list_writer: + return await super().handle_new_device_update() - await self.delete_devices(user_id, [device_id]) - await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + self._handle_new_device_update_async() + return @wrap_as_background_process("_handle_new_device_update_async") async def _handle_new_device_update_async(self) -> None: @@ -956,6 +1050,8 @@ async def _handle_new_device_update_async(self) -> None: This happens in the background so as not to block the original request that generated the device update. """ + assert self._is_main_device_list_writer + if self._handle_new_device_update_is_processing: self._handle_new_device_update_new_data = True return @@ -973,7 +1069,7 @@ async def _handle_new_device_update_async(self) -> None: while True: self._handle_new_device_update_new_data = False - max_stream_id = self.store.get_device_stream_token() + max_stream_id = self.store.get_device_stream_token().stream rows = await self.store.get_uncoverted_outbound_room_pokes( stream_id, room_id ) @@ -1038,7 +1134,7 @@ async def _handle_new_device_update_async(self) -> None: # Notify replication that we've updated the device list stream. self.notifier.notify_replication() - if hosts: + if hosts and self.federation_sender: logger.info( "Sending device list update notif for %r to: %r", user_id, @@ -1158,9 +1254,10 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None: # Notify things that device lists need to be sent out. 
self.notifier.notify_replication() - await self.federation_sender.send_device_messages( - potentially_changed_hosts, immediate=False - ) + if self.federation_sender: + await self.federation_sender.send_device_messages( + potentially_changed_hosts, immediate=False + ) def _update_device_from_client_ips( @@ -1180,16 +1277,16 @@ class DeviceListWorkerUpdater: "Handles incoming device list updates from federation and contacts the main process over replication" def __init__(self, hs: "HomeServer"): - from synapse.replication.http.devices import ( - ReplicationMultiUserDevicesResyncRestServlet, - ) - + self.store = hs.get_datastores().main + self._notifier = hs.get_notifier() + self._main_device_list_writer = hs.config.worker.writers.device_lists[0] self._multi_user_device_resync_client = ( ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) ) async def multi_user_device_resync( - self, user_ids: List[str], mark_failed_as_stale: bool = True + self, + user_ids: List[str], ) -> Dict[str, Optional[JsonMapping]]: """ Like `user_device_resync` but operates on multiple users **from the same origin** @@ -1198,25 +1295,97 @@ async def multi_user_device_resync( Returns: Dict from User ID to the same Dict as `user_device_resync`. """ - # mark_failed_as_stale is not sent. Ensure this doesn't break expectations. - assert mark_failed_as_stale if not user_ids: # Shortcut empty requests return {} - return await self._multi_user_device_resync_client(user_ids=user_ids) + return await self._multi_user_device_resync_client( + instance_name=self._main_device_list_writer, + user_ids=user_ids, + ) + + async def process_cross_signing_key_update( + self, + user_id: str, + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], + ) -> List[str]: + """Process the given new master and self-signing key for the given remote user. + + Args: + user_id: The ID of the user these keys are for. + master_key: The dict of the cross-signing master key as returned by the + remote server. + self_signing_key: The dict of the cross-signing self-signing key as returned + by the remote server. + + Return: + The device IDs for the given keys. + """ + device_ids = [] + + current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id]) + current_keys = current_keys_map.get(user_id) or {} + + if master_key and master_key != current_keys.get("master"): + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key and self_signing_key != current_keys.get("self_signing"): + await self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + device_ids.append(verify_key.version) + + return device_ids + + async def handle_room_un_partial_stated(self, room_id: str) -> None: + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. 
+ """ + + pending_updates = ( + await self.store.get_pending_remote_device_list_updates_for_room(room_id) + ) + + for user_id, device_id in pending_updates: + logger.info( + "Got pending device list update in room %s: %s / %s", + room_id, + user_id, + device_id, + ) + position = await self.store.add_device_change_to_streams( + user_id, + [device_id], + room_ids=[room_id], + ) + + if not position: + # This should only happen if there are no updates, which + # shouldn't happen when we've passed in a non-empty set of + # device IDs. + continue + + self._notifier.on_new_event( + StreamKeyType.DEVICE_LIST, position, rooms=[room_id] + ) class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): - self.store = hs.get_datastores().main + super().__init__(hs) + self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler - self._notifier = hs.get_notifier() self._remote_edu_linearizer = Linearizer(name="remote_device_list") self._resync_linearizer = Linearizer(name="remote_device_resync") @@ -1640,74 +1809,3 @@ async def _user_device_resync_returning_failed( self._seen_updates[user_id] = {stream_id} return result, False - - async def process_cross_signing_key_update( - self, - user_id: str, - master_key: Optional[JsonDict], - self_signing_key: Optional[JsonDict], - ) -> List[str]: - """Process the given new master and self-signing key for the given remote user. - - Args: - user_id: The ID of the user these keys are for. - master_key: The dict of the cross-signing master key as returned by the - remote server. - self_signing_key: The dict of the cross-signing self-signing key as returned - by the remote server. - - Return: - The device IDs for the given keys. - """ - device_ids = [] - - current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id]) - current_keys = current_keys_map.get(user_id) or {} - - if master_key and master_key != current_keys.get("master"): - await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) - _, verify_key = get_verify_key_from_cross_signing_key(master_key) - # verify_key is a VerifyKey from signedjson, which uses - # .version to denote the portion of the key ID after the - # algorithm and colon, which is the device ID - device_ids.append(verify_key.version) - if self_signing_key and self_signing_key != current_keys.get("self_signing"): - await self.store.set_e2e_cross_signing_key( - user_id, "self_signing", self_signing_key - ) - _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) - device_ids.append(verify_key.version) - - return device_ids - - async def handle_room_un_partial_stated(self, room_id: str) -> None: - """Handles sending appropriate device list updates in a room that has - gone from partial to full state. - """ - - pending_updates = ( - await self.store.get_pending_remote_device_list_updates_for_room(room_id) - ) - - for user_id, device_id in pending_updates: - logger.info( - "Got pending device list update in room %s: %s / %s", - room_id, - user_id, - device_id, - ) - position = await self.store.add_device_change_to_streams( - user_id, - [device_id], - room_ids=[room_id], - ) - - if not position: - # This should only happen if there are no updates, which - # shouldn't happen when we've passed in a non-empty set of - # device IDs. 
- continue - - self.device_handler.notifier.on_new_event( - StreamKeyType.DEVICE_LIST, position, rooms=[room_id] - ) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index e56bdb40720..b43cbd9c154 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -33,9 +33,6 @@ log_kv, set_tag, ) -from synapse.replication.http.devices import ( - ReplicationMultiUserDevicesResyncRestServlet, -) from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -56,9 +53,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine + self.device_handler = hs.get_device_handler() if hs.config.experimental.msc3814_enabled: self.event_sources = hs.get_event_sources() - self.device_handler = hs.get_device_handler() # We only need to poke the federation sender explicitly if its on the # same instance. Other federation sender instances will get notified by @@ -80,18 +77,6 @@ def __init__(self, hs: "HomeServer"): hs.config.worker.writers.to_device, ) - # The handler to call when we think a user's device list might be out of - # sync. We do all device list resyncing on the master instance, so if - # we're on a worker we hit the device resync replication API. - if hs.config.worker.worker_app is None: - self._multi_user_device_resync = ( - hs.get_device_handler().device_list_updater.multi_user_device_resync - ) - else: - self._multi_user_device_resync = ( - ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) - ) - # a rate limiter for room key requests. The keys are # (sending_user_id, sending_device_id). self._ratelimiter = Ratelimiter( @@ -213,7 +198,10 @@ async def _check_for_unknown_devices( await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,)) # Immediately attempt a resync in the background - run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id]) + run_in_background( + self.device_handler.device_list_updater.multi_user_device_resync, + user_ids=[sender_user_id], + ) async def send_device_message( self, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 6171aaf29fd..c6487ec4164 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -35,7 +35,6 @@ from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.types import ( JsonDict, JsonMapping, @@ -76,8 +75,10 @@ def __init__(self, hs: "HomeServer"): federation_registry = hs.get_federation_registry() - is_master = hs.config.worker.worker_app is None - if is_master: + # Only the first writer in the list should handle EDUs for signing key + # updates, so that we can use an in-memory linearizer instead of worker locks. 
+ edu_writer = hs.config.worker.writers.device_lists[0] + if hs.get_instance_name() == edu_writer: edu_updater = SigningKeyEduUpdater(hs) # Only register this edu handler on master as it requires writing @@ -92,11 +93,14 @@ def __init__(self, hs: "HomeServer"): EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, edu_updater.incoming_signing_key_update, ) - - self.device_key_uploader = self.upload_device_keys_for_user else: - self.device_key_uploader = ( - ReplicationUploadKeysForUserRestServlet.make_client(hs) + federation_registry.register_instances_for_edu( + EduTypes.SIGNING_KEY_UPDATE, + [edu_writer], + ) + federation_registry.register_instances_for_edu( + EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, + [edu_writer], ) # doesn't really work as part of the generic query API, because the @@ -847,7 +851,7 @@ async def upload_keys_for_user( # TODO: Validate the JSON to make sure it has the right keys. device_keys = keys.get("device_keys", None) if device_keys: - await self.device_key_uploader( + await self.upload_device_keys_for_user( user_id=user_id, device_id=device_id, keys={"device_keys": device_keys}, @@ -904,9 +908,6 @@ async def upload_device_keys_for_user( device_keys: the `device_keys` of an /keys/upload request. """ - # This can only be called from the main process. - assert isinstance(self.device_handler, DeviceHandler) - time_now = self.clock.time_msec() device_keys = keys["device_keys"] @@ -998,9 +999,6 @@ async def upload_signing_keys_for_user( user_id: the user uploading the keys keys: the signing keys """ - # This can only be called from the main process. - assert isinstance(self.device_handler, DeviceHandler) - # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys if "master_key" in keys: @@ -1091,9 +1089,6 @@ async def upload_signatures_for_device_keys( Raises: SynapseError: if the signatures dict is not valid. """ - # This can only be called from the main process. - assert isinstance(self.device_handler, DeviceHandler) - failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1467,9 +1462,6 @@ async def _retrieve_cross_signing_keys_for_remote_user( A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ - # This can only be called from the main process. 
- assert isinstance(self.device_handler, DeviceHandler) - try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 729b676163d..af0a1bc271f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -718,7 +718,7 @@ async def do_invite_join( await self.store.store_partial_state_room( room_id=room_id, servers=ret.servers_in_room, - device_lists_stream_id=self.store.get_device_stream_token(), + device_lists_stream_id=self.store.get_device_stream_token().stream, joined_via=origin, ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 1e738f484f9..5cec2b01e55 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -77,9 +77,6 @@ trace, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.http.devices import ( - ReplicationMultiUserDevicesResyncRestServlet, -) from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) @@ -180,12 +177,7 @@ def __init__(self, hs: "HomeServer"): self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) - if hs.config.worker.worker_app: - self._multi_user_device_resync = ( - ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) - ) - else: - self._device_list_updater = hs.get_device_handler().device_list_updater + self._device_list_updater = hs.get_device_handler().device_list_updater # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. @@ -1544,12 +1536,7 @@ async def _resync_device(self, sender: str) -> None: await self._store.mark_remote_users_device_caches_as_stale((sender,)) # Immediately attempt a resync in the background - if self._config.worker.worker_app: - await self._multi_user_device_resync(user_ids=[sender]) - else: - await self._device_list_updater.multi_user_device_resync( - user_ids=[sender] - ) + await self._device_list_updater.multi_user_device_resync(user_ids=[sender]) except Exception: logger.exception("Failed to resync device for %s", sender) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 970013ef20c..f238c598f99 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -44,7 +44,6 @@ ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved -from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -859,9 +858,6 @@ class and RegisterDeviceReplicationServlet. refresh_token = None refresh_token_id = None - # This can only run on the main process. 
- assert isinstance(self.device_handler, DeviceHandler) - registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 94301add9ed..54116a9b724 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -36,17 +35,7 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - - # We don't need the device handler if password changing is disabled. - # This allows us to instantiate the SetPasswordHandler on the workers - # that have admin APIs for MAS - if self._auth_handler.can_change_password(): - # This can only be instantiated on the main process. - device_handler = hs.get_device_handler() - assert isinstance(device_handler, DeviceHandler) - self._device_handler: Optional[DeviceHandler] = device_handler - else: - self._device_handler = None + self._device_handler = hs.get_device_handler() async def set_password( self, @@ -58,9 +47,6 @@ async def set_password( if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) - # We should have this available only if password changing is enabled. - assert self._device_handler is not None - try: await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 7c5cf91aba5..b4d9b3ead4b 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -46,7 +46,6 @@ from synapse.api.constants import LoginType, ProfileFields from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement -from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1181,8 +1180,6 @@ async def revoke_sessions_for_provider_session_id( ) -> None: """Revoke any devices and in-flight logins tied to a provider session. - Can only be called from the main process. - Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1191,11 +1188,6 @@ async def revoke_sessions_for_provider_session_id( sessions belonging to other users and log an error. """ - # It is expected that this is the main process. 
-        assert isinstance(self._device_handler, DeviceHandler), (
-            "revoking SSO sessions can only be called on the main process"
-        )
-
         # Invalidate any running user-mapping sessions
         to_delete = []
         for session_id, session in self._username_mapping_sessions.items():
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 656fe323f33..5cfd15737f0 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -66,7 +66,6 @@
     ON_LOGGED_OUT_CALLBACK,
     AuthHandler,
 )
-from synapse.handlers.device import DeviceHandler
 from synapse.handlers.push_rules import RuleSpec, check_actions
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
@@ -922,8 +921,6 @@ def invalidate_access_token(
     ) -> Generator["defer.Deferred[Any]", Any, None]:
         """Invalidate an access token for a user
 
-        Can only be called from the main process.
-
         Added in Synapse v0.25.0.
 
         Args:
@@ -936,10 +933,6 @@ def invalidate_access_token(
         Raises:
             synapse.api.errors.AuthError: the access token is invalid
         """
-        assert isinstance(self._device_handler, DeviceHandler), (
-            "invalidate_access_token can only be called on the main process"
-        )
-
         # see if the access token corresponds to a device
         user_info = yield defer.ensureDeferred(
             self._auth.get_user_by_access_token(access_token)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index d5000517143..555444fa3de 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -59,10 +59,10 @@ def register_servlets(self, hs: "HomeServer") -> None:
         account_data.register_servlets(hs, self)
         push.register_servlets(hs, self)
         state.register_servlets(hs, self)
+        devices.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
             login.register_servlets(hs, self)
             register.register_servlets(hs, self)
-            devices.register_servlets(hs, self)
             delayed_events.register_servlets(hs, self)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 08cf9eff979..dce05f25b83 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -34,6 +34,92 @@
 logger = logging.getLogger(__name__)
 
 
+class ReplicationNotifyDeviceUpdateRestServlet(ReplicationEndpoint):
+    """Notify a device writer that a user's device list has changed.
+
+    Request format:
+
+        POST /_synapse/replication/notify_device_update/:user_id
+
+        {
+            "device_ids": ["JLAFKJWSCS", "OBEBOSKTBE"]
+        }
+    """
+
+    NAME = "notify_device_update"
+    PATH_ARGS = ("user_id",)
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.device_handler = hs.get_device_handler()
+        self.store = hs.get_datastores().main
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(  # type: ignore[override]
+        user_id: str, device_ids: List[str]
+    ) -> JsonDict:
+        return {"device_ids": device_ids}
+
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, content: JsonDict, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        device_ids = content["device_ids"]
+
+        span = active_span()
+        if span:
+            span.set_tag("user_id", user_id)
+            span.set_tag("device_ids", f"{device_ids!r}")
+
+        await self.device_handler.notify_device_update(user_id, device_ids)
+
+        return 200, {}
+
+
+class ReplicationNotifyUserSignatureUpdateRestServlet(ReplicationEndpoint):
+    """Notify a device writer that a user has made new signatures of other users.
+ + Request format: + + POST /_synapse/replication/notify_user_signature_update/:from_user_id + + { + "user_ids": ["@alice:example.org", "@bob:example.org", ...] + } + """ + + NAME = "notify_user_signature_update" + PATH_ARGS = ("from_user_id",) + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(from_user_id: str, user_ids: List[str]) -> JsonDict: # type: ignore[override] + return {"user_ids": user_ids} + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict, from_user_id: str + ) -> Tuple[int, JsonDict]: + user_ids = content["user_ids"] + + span = active_span() + if span: + span.set_tag("from_user_id", from_user_id) + span.set_tag("user_ids", f"{user_ids!r}") + + await self.device_handler.notify_user_signature_update(from_user_id, user_ids) + + return 200, {} + + class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): """Ask master to resync the device list for multiple users from the same remote server by contacting their server. @@ -73,11 +159,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - from synapse.handlers.device import DeviceHandler - - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self.device_list_updater = handler.device_list_updater + self.device_list_updater = hs.get_device_handler().device_list_updater self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -104,31 +186,7 @@ async def _handle_request( # type: ignore[override] class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): - """Ask master to upload keys for the user and send them out over federation to - update other servers. - - For now, only the master is permitted to handle key upload requests; - any worker can handle key query requests (since they're read-only). - - Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on - the main process to accomplish this. - - Request format for this endpoint (borrowed and expanded from KeyUploadServlet): - - POST /_synapse/replication/upload_keys_for_user - - { - "user_id": "<user_id>", - "device_id": "<device_id>", - "keys": { - ....this part can be found in KeyUploadServlet in rest/client/keys.py.... - or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload - } - } - - Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet - - """ + """Unused endpoint, kept for backwards compatibility during rollout.""" NAME = "upload_keys_for_user" PATH_ARGS = () CACHE = False @@ -165,6 +223,71 @@ async def _handle_request( # type: ignore[override] return 200, results +class ReplicationHandleNewDeviceUpdateRestServlet(ReplicationEndpoint): + """Wake up a device writer to send local device list changes as federation outbound pokes.
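These notify servlets are what non-writer instances invoke in place of calling the device handler directly. A sketch of the calling side, assuming the standard `ReplicationEndpoint.make_client` wiring; the proxy class and its name are hypothetical:

```python
import random
from typing import List

from synapse.replication.http.devices import (
    ReplicationNotifyDeviceUpdateRestServlet,
)


class DeviceUpdateProxy:
    """Hypothetical worker-side helper that forwards device list pokes."""

    def __init__(self, hs):
        self._writers = hs.config.worker.writers.device_lists
        self._notify = ReplicationNotifyDeviceUpdateRestServlet.make_client(hs)

    async def notify_device_update(self, user_id: str, device_ids: List[str]) -> None:
        # Any configured device_lists writer can service the poke; pick one
        # at random, as is done for other sharded stream writers.
        await self._notify(
            instance_name=random.choice(self._writers),
            user_id=user_id,
            device_ids=device_ids,
        )
```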
+ + Request format: + + POST /_synapse/replication/handle_new_device_update + + {} + """ + + NAME = "handle_new_device_update" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.device_handler = hs.get_device_handler() + + @staticmethod + async def _serialize_payload() -> JsonDict: # type: ignore[override] + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict + ) -> Tuple[int, JsonDict]: + await self.device_handler.handle_new_device_update() + return 200, {} + + +class ReplicationDeviceHandleRoomUnPartialStated(ReplicationEndpoint): + """Handles sending appropriate device list updates in a room that has + gone from partial to full state. + + Request format: + + POST /_synapse/replication/device_handle_room_un_partial_stated/:room_id + + {} + """ + + NAME = "device_handle_room_un_partial_stated" + PATH_ARGS = ("room_id",) + CACHE = True + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.device_handler = hs.get_device_handler() + + @staticmethod + async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict, room_id: str + ) -> Tuple[int, JsonDict]: + await self.device_handler.handle_room_un_partial_stated(room_id) + return 200, {} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationNotifyDeviceUpdateRestServlet(hs).register(http_server) + ReplicationNotifyUserSignatureUpdateRestServlet(hs).register(http_server) ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationHandleNewDeviceUpdateRestServlet(hs).register(http_server) ReplicationUploadKeysForUserRestServlet(hs).register(http_server) + ReplicationDeviceHandleRoomUnPartialStated(hs).register(http_server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 0bd5478cd35..608418de35d 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -115,7 +115,7 @@ async def on_rdata( all_room_ids: Set[str] = set() if stream_name == DeviceListsStream.NAME: if any(not row.is_signature and not row.hosts_calculated for row in rows): - prev_token = self.store.get_device_stream_token() + prev_token = self.store.get_device_stream_token().stream all_room_ids = await self.store.get_all_device_list_changes( prev_token, token ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1fafbb48c3e..e434bed3e58 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -72,6 +72,7 @@ ToDeviceStream, TypingStream, ) +from synapse.replication.tcp.streams._base import DeviceListsStream if TYPE_CHECKING: from synapse.server import HomeServer @@ -185,6 +186,12 @@ def __init__(self, hs: "HomeServer"): continue + if isinstance(stream, DeviceListsStream): + if hs.get_instance_name() in hs.config.worker.writers.device_lists: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. 
if hs.config.worker.worker_app is not None: continue diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index e55cdc0470e..32df4b244c6 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,7 +51,6 @@ from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, DeviceRestServlet, - DevicesGetRestServlet, DevicesRestServlet, ) from synapse.rest.admin.event_reports import ( @@ -375,4 +374,5 @@ def register_servlets_for_msc3861_delegation( UserRestServletV2(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server) - DevicesGetRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 09baf8ce219..c488bce58e1 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -51,9 +50,7 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self.device_handler = handler + self.device_handler = hs.get_device_handler() self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -113,7 +110,7 @@ async def on_PUT( return HTTPStatus.OK, {} -class DevicesGetRestServlet(RestServlet): +class DevicesRestServlet(RestServlet): """ Retrieve the given user's devices @@ -158,19 +155,6 @@ async def on_GET( return HTTPStatus.OK, {"devices": devices, "total": len(devices)} - -class DevicesRestServlet(DevicesGetRestServlet): - """ - Retrieve the given user's devices - """ - - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") - - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - assert isinstance(self.device_worker_handler, DeviceHandler) - self.device_handler = self.device_worker_handler - async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: @@ -194,7 +178,7 @@ async def on_POST( if not isinstance(device_id, str): raise SynapseError(HTTPStatus.BAD_REQUEST, "device_id must be a string") - await self.device_handler.check_device_registered( + await self.device_worker_handler.check_device_registered( user_id=user_id, device_id=device_id ) @@ -211,9 +195,7 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self.device_handler = handler + self.device_handler = hs.get_device_handler() self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 0b075cc2f2a..0d7c205576a 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -27,7 +27,6 @@ from synapse._pydantic_compat import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError -from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -91,7 +90,6 @@ def __init__(self, hs:
"HomeServer"): self.hs = hs self.auth = hs.get_auth() handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) self.device_handler = handler self.auth_handler = hs.get_auth_handler() @@ -179,14 +177,6 @@ class DeleteBody(RequestBodyModel): async def on_DELETE( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: - # Can only be run on main process, as changes to device lists must - # happen on main. - if not self._is_main_process: - error_message = "DELETE on /devices/ must be routed to main process" - logger.error(error_message) - raise SynapseError(500, error_message) - assert isinstance(self.device_handler, DeviceHandler) - requester = await self.auth.get_user_by_req(request) try: @@ -231,14 +221,6 @@ class PutBody(RequestBodyModel): async def on_PUT( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: - # Can only be run on main process, as changes to device lists must - # happen on main. - if not self._is_main_process: - error_message = "PUT on /devices/ must be routed to main process" - logger.error(error_message) - raise SynapseError(500, error_message) - assert isinstance(self.device_handler, DeviceHandler) - requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_and_validate_json_object_from_request(request, self.PutBody) @@ -317,7 +299,6 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -377,7 +358,6 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) self.device_handler = handler class PostBody(RequestBodyModel): @@ -517,7 +497,6 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = handler @@ -595,18 +574,14 @@ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if ( - hs.config.worker.worker_app is None - and not hs.config.experimental.msc3861.enabled - ): + if not hs.config.experimental.msc3861.enabled: DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) - if hs.config.worker.worker_app is None: - if hs.config.experimental.msc2697_enabled: - DehydratedDeviceServlet(hs).register(http_server) - ClaimDehydratedDeviceServlet(hs).register(http_server) - if hs.config.experimental.msc3814_enabled: - DehydratedDeviceV2Servlet(hs).register(http_server) - DehydratedDeviceEventsServlet(hs).register(http_server) + if hs.config.experimental.msc2697_enabled: + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc3814_enabled: + DehydratedDeviceV2Servlet(hs).register(http_server) + DehydratedDeviceEventsServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7025662fdc0..09749b840fc 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -504,6 +504,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: OneTimeKeyServlet(hs).register(http_server) if 
hs.config.experimental.msc3983_appservice_otk_claims: UnstableOneTimeKeyServlet(hs).register(http_server) - if hs.config.worker.worker_app is None: - SigningKeyUploadServlet(hs).register(http_server) - SignaturesUploadServlet(hs).register(http_server) + SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index e6b4a34d512..206865e9891 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -22,7 +22,6 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -42,9 +41,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self._device_handler = handler + self._device_handler = hs.get_device_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req( @@ -71,9 +68,7 @@ def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self._device_handler = handler + self._device_handler = hs.get_device_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req( diff --git a/synapse/server.py b/synapse/server.py index fd16abb9ead..2c1af1eddd0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -587,11 +587,11 @@ def get_macaroon_generator(self) -> MacaroonGenerator: @cache_in_self def get_device_handler(self) -> DeviceWorkerHandler: - if self.config.worker.worker_app: - return DeviceWorkerHandler(self) - else: + if self.get_instance_name() in self.config.worker.writers.device_lists: return DeviceHandler(self) + return DeviceWorkerHandler(self) + @cache_in_self def get_device_message_handler(self) -> DeviceMessageHandler: return DeviceMessageHandler(self) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 6191f22cd6a..b8b640bd1a2 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -27,7 +27,6 @@ Dict, Iterable, List, - Literal, Mapping, Optional, Set, @@ -61,12 +60,12 @@ from synapse.types import ( JsonDict, JsonMapping, + MultiWriterStreamToken, StrCollection, get_verify_key_from_cross_signing_key, ) from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.lrucache import LruCache from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -86,6 +85,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): + _device_list_id_gen: MultiWriterIdGenerator + _instance_name: str + def __init__( self, database: DatabasePool, @@ -115,7 +117,11 @@ def __init__( ), ], sequence_name="device_lists_sequence", - writers=["master"], + writers=hs.config.worker.writers.device_lists, + ) + + self._is_device_list_writer = ( + self._instance_name in hs.config.worker.writers.device_lists ) device_list_max = 
self._device_list_id_gen.get_current_token() @@ -240,8 +246,8 @@ def device_lists_in_rooms_have_changed( for room_id in room_ids: self._device_list_room_stream_cache.entity_has_changed(room_id, token) - def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + def get_device_stream_token(self) -> MultiWriterStreamToken: + return MultiWriterStreamToken.from_generator(self._device_list_id_gen) def get_device_stream_id_generator(self) -> MultiWriterIdGenerator: return self._device_list_id_gen @@ -282,6 +288,145 @@ def count_devices_by_users_txn( "count_devices_by_users", count_devices_by_users_txn, user_ids ) + async def store_device( + self, + user_id: str, + device_id: str, + initial_device_display_name: Optional[str], + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> bool: + """Ensure the given device is known; add it to the store if not + + Args: + user_id: id of user associated with the device + device_id: id of device + initial_device_display_name: initial displayname of the device. + Ignored if device exists. + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from a OIDC login. + + Returns: + Whether the device was inserted or an existing device existed with that ID. + + Raises: + StoreError: if the device is already in use + """ + try: + inserted = await self.db_pool.simple_upsert( + "devices", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={}, + insertion_values={ + "display_name": initial_device_display_name, + "hidden": False, + }, + desc="store_device", + ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) + + if not inserted: + # if the device already exists, check if it's a real device, or + # if the device ID is reserved by something else + hidden = await self.db_pool.simple_select_one_onecol( + "devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="hidden", + ) + if hidden: + raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + + if auth_provider_id and auth_provider_session_id: + await self.db_pool.simple_insert( + "device_auth_providers", + values={ + "user_id": user_id, + "device_id": device_id, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="store_device_auth_provider", + ) + + return inserted + except StoreError: + raise + except Exception as e: + logger.error( + "store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, + device_id, + type(user_id).__name__, + user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, + e, + ) + raise StoreError(500, "Problem storing device.") + + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: + """Deletes several devices. 
+ + Args: + user_id: The ID of the user which owns the devices + device_ids: The IDs of the devices to delete + """ + + def _delete_devices_txn(txn: LoggingTransaction, device_ids: List[str]) -> None: + self.db_pool.simple_delete_many_txn( + txn, + table="devices", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id, "hidden": False}, + ) + + self.db_pool.simple_delete_many_txn( + txn, + table="device_auth_providers", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream_bulk( + txn, self.get_device, [(user_id, device_id) for device_id in device_ids] + ) + + for batch in batch_iter(device_ids, 100): + await self.db_pool.runInteraction( + "delete_devices", _delete_devices_txn, batch + ) + + async def update_device( + self, user_id: str, device_id: str, new_display_name: Optional[str] = None + ) -> None: + """Update a device. Only updates the device if it is not marked as + hidden. + + Args: + user_id: The ID of the user which owns the device + device_id: The ID of the device to update + new_display_name: new displayname for device; None to leave unchanged + Raises: + StoreError: if the device is not found + """ + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return None + await self.db_pool.simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + updatevalues=updates, + desc="update_device", + ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) + @cached() async def get_device( self, user_id: str, device_id: str @@ -375,7 +520,11 @@ async def get_device_updates_by_remote( - The list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self.get_device_stream_token() + # Here, we don't use the individual instances positions, as we only + # record the last stream position we've sent to a destination. This + # means we have to wait for all the writers to catch up before sending + # device list updates, which is fine. + now_stream_id = self.get_device_stream_token().stream if from_stream_id == now_stream_id: return now_stream_id, [] @@ -752,6 +901,8 @@ async def add_user_signature_change_to_streams( Returns: The new stream ID. """ + if not self._is_device_list_writer: + raise Exception("Can only be called on device list writers") async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( @@ -874,8 +1025,8 @@ async def get_cached_devices_for_user( @cancellable async def get_all_devices_changed( self, - from_key: int, - to_key: int, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, ) -> Set[str]: """Get all users whose devices have changed in the given range. @@ -890,7 +1041,9 @@ async def get_all_devices_changed( (exclusive) until `to_key` (inclusive). """ - result = self._device_list_stream_cache.get_all_entities_changed(from_key) + result = self._device_list_stream_cache.get_all_entities_changed( + from_key.stream + ) if result.hit: # We know which users might have changed devices. @@ -906,24 +1059,34 @@ async def get_all_devices_changed( # If the cache didn't tell us anything, we just need to query the full # range. sql = """ - SELECT DISTINCT user_id FROM device_lists_stream + SELECT user_id, stream_id, instance_name + FROM device_lists_stream WHERE ? < stream_id AND stream_id <= ? 
""" rows = await self.db_pool.execute( "get_all_devices_changed", sql, - from_key, - to_key, + from_key.stream, + to_key.get_max_stream_pos(), ) - return {u for (u,) in rows} + return { + user_id + for (user_id, stream_id, instance_name) in rows + if MultiWriterStreamToken.is_stream_position_in_range( + low=from_key, + high=to_key, + instance_name=instance_name, + pos=stream_id, + ) + } @cancellable async def get_users_whose_devices_changed( self, - from_key: int, + from_key: MultiWriterStreamToken, user_ids: Collection[str], - to_key: Optional[int] = None, + to_key: Optional[MultiWriterStreamToken] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. @@ -943,7 +1106,7 @@ async def get_users_whose_devices_changed( # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. user_ids_to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key + user_ids, from_key.stream ) # If an empty set was returned, there's nothing to do. @@ -951,11 +1114,16 @@ async def get_users_whose_devices_changed( return set() if to_key is None: - to_key = self._device_list_id_gen.get_current_token() + to_key = self.get_device_stream_token() - def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: + def _get_users_whose_devices_changed_txn( + txn: LoggingTransaction, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, + ) -> Set[str]: sql = """ - SELECT DISTINCT user_id FROM device_lists_stream + SELECT user_id, stream_id, instance_name + FROM device_lists_stream WHERE ? < stream_id AND stream_id <= ? AND %s """ @@ -966,17 +1134,32 @@ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql % (clause,), [from_key, to_key] + args) - changes.update(user_id for (user_id,) in txn) + txn.execute( + sql % (clause,), + [from_key.stream, to_key.get_max_stream_pos()] + args, + ) + changes.update( + user_id + for (user_id, stream_id, instance_name) in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=from_key, + high=to_key, + instance_name=instance_name, + pos=stream_id, + ) + ) return changes return await self.db_pool.runInteraction( - "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn + "get_users_whose_devices_changed", + _get_users_whose_devices_changed_txn, + from_key, + to_key, ) async def get_users_whose_signatures_changed( - self, user_id: str, from_key: int + self, user_id: str, from_key: MultiWriterStreamToken ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. @@ -989,18 +1172,31 @@ async def get_users_whose_signatures_changed( A set of user IDs with updated signatures. """ - if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): - sql = """ - SELECT DISTINCT user_ids FROM user_signature_stream - WHERE from_user_id = ? AND stream_id > ? - """ - rows = await self.db_pool.execute( - "get_users_whose_signatures_changed", sql, user_id, from_key - ) - return {user for row in rows for user in db_to_json(row[0])} - else: + if not self._user_signature_stream_cache.has_entity_changed( + user_id, from_key.stream + ): return set() + sql = """ + SELECT user_ids, stream_id, instance_name + FROM user_signature_stream + WHERE from_user_id = ? AND stream_id > ? 
+ """ + rows = await self.db_pool.execute( + "get_users_whose_signatures_changed", sql, user_id, from_key.stream + ) + return { + user + for (user_ids, stream_id, instance_name) in rows + if MultiWriterStreamToken.is_stream_position_in_range( + low=from_key, + high=None, + instance_name=instance_name, + pos=stream_id, + ) + for user in db_to_json(user_ids) + } + async def get_all_device_list_changes_for_remotes( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: @@ -1254,9 +1450,7 @@ def _store_dehydrated_device_txn( if keys: device_keys = keys.get("device_keys", None) if device_keys: - # Type ignore - this function is defined on EndToEndKeyStore which we do - # have access to due to hs.get_datastore() "magic" - self._set_e2e_device_keys_txn( # type: ignore[attr-defined] + self._set_e2e_device_keys_txn( txn, user_id, device_id, time, device_keys ) @@ -1486,7 +1680,10 @@ async def _get_min_device_lists_changes_in_room(self) -> int: @cancellable async def get_device_list_changes_in_rooms( - self, room_ids: Collection[str], from_id: int, to_id: int + self, + room_ids: Collection[str], + from_token: MultiWriterStreamToken, + to_token: MultiWriterStreamToken, ) -> Optional[Set[str]]: """Return the set of users whose devices have changed in the given rooms since the given stream ID. @@ -1499,41 +1696,50 @@ async def get_device_list_changes_in_rooms( min_stream_id = await self._get_min_device_lists_changes_in_room() - if min_stream_id > from_id: + # XXX: is that right? + if min_stream_id > from_token.stream: return None changed_room_ids = self._device_list_room_stream_cache.get_entities_changed( - room_ids, from_id + room_ids, from_token.stream ) if not changed_room_ids: return set() sql = """ - SELECT DISTINCT user_id FROM device_lists_changes_in_room + SELECT user_id, stream_id, instance_name + FROM device_lists_changes_in_room WHERE {clause} AND stream_id > ? AND stream_id <= ? """ def _get_device_list_changes_in_rooms_txn( txn: LoggingTransaction, - clause: str, - args: List[Any], + chunk: list[str], ) -> Set[str]: - txn.execute(sql.format(clause=clause), args) - return {user_id for (user_id,) in txn} - - changes = set() - for chunk in batch_iter(changed_room_ids, 1000): clause, args = make_in_list_sql_clause( self.database_engine, "room_id", chunk ) - args.append(from_id) - args.append(to_id) + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) + + txn.execute(sql.format(clause=clause), args) + return { + user_id + for (user_id, stream_id, instance_name) in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=from_token, + high=to_token, + instance_name=instance_name, + pos=stream_id, + ) + } + changes = set() + for chunk in batch_iter(changed_room_ids, 1000): changes |= await self.db_pool.runInteraction( "get_device_list_changes_in_rooms", _get_device_list_changes_in_rooms_txn, - clause, - args, + chunk, ) return changes @@ -1601,380 +1807,56 @@ async def get_destinations_for_device(self, stream_id: int) -> StrCollection: desc="get_destinations_for_device", ) + async def update_remote_device_list_cache_entry( + self, user_id: str, device_id: str, content: JsonDict, stream_id: str + ) -> None: + """Updates a single device in the cache of a remote user's devicelist. 
-class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self._instance_name = hs.get_instance_name() + Note: assumes that we are the only thread that can be updating this user's + device list. - self.db_pool.updates.register_background_index_update( - "device_lists_stream_idx", - index_name="device_lists_stream_user_id", - table="device_lists_stream", - columns=["user_id", "device_id"], + Args: + user_id: User to update device list for + device_id: ID of device being updated + content: new data on this device + stream_id: the version of the device list + """ + await self.db_pool.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, + device_id, + content, + stream_id, ) - # create a unique index on device_lists_remote_cache - self.db_pool.updates.register_background_index_update( - "device_lists_remote_cache_unique_idx", - index_name="device_lists_remote_cache_unique_id", - table="device_lists_remote_cache", - columns=["user_id", "device_id"], - unique=True, - ) + def _update_remote_device_list_cache_entry_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + content: JsonDict, + stream_id: str, + ) -> None: + """Delete, update or insert a cache entry for this (user, device) pair.""" + if content.get("deleted"): + self.db_pool.simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + else: + self.db_pool.simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"content": json_encoder.encode(content)}, + ) - # And one on device_lists_remote_extremeties - self.db_pool.updates.register_background_index_update( - "device_lists_remote_extremeties_unique_idx", - index_name="device_lists_remote_extremeties_unique_idx", - table="device_lists_remote_extremeties", - columns=["user_id"], - unique=True, - ) - - # once they complete, we can remove the old non-unique indexes. - self.db_pool.updates.register_background_update_handler( - DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, - self._drop_device_list_streams_non_unique_indexes, - ) - - # clear out duplicate device list outbound pokes - self.db_pool.updates.register_background_update_handler( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, - self._remove_duplicate_outbound_pokes, - ) - - self.db_pool.updates.register_background_index_update( - "device_lists_changes_in_room_by_room_index", - index_name="device_lists_changes_in_room_by_room_idx", - table="device_lists_changes_in_room", - columns=["room_id", "stream_id"], - ) - - async def _drop_device_list_streams_non_unique_indexes( - self, progress: JsonDict, batch_size: int - ) -> int: - def f(conn: LoggingDatabaseConnection) -> None: - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") - txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") - txn.close() - - await self.db_pool.runWithConnection(f) - await self.db_pool.updates._end_background_update( - DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES - ) - return 1 - - async def _remove_duplicate_outbound_pokes( - self, progress: JsonDict, batch_size: int - ) -> int: - # for some reason, we have accumulated duplicate entries in - # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less - # efficient.
- # - # For each duplicate, we delete all the existing rows and put one back. - - last_row = progress.get( - "last_row", - {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, - ) - - def _txn(txn: LoggingTransaction) -> int: - clause, args = make_tuple_comparison_clause( - [ - ("stream_id", last_row["stream_id"]), - ("destination", last_row["destination"]), - ("user_id", last_row["user_id"]), - ("device_id", last_row["device_id"]), - ] - ) - sql = f""" - SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts - FROM device_lists_outbound_pokes - WHERE {clause} - GROUP BY stream_id, destination, user_id, device_id - HAVING count(*) > 1 - ORDER BY stream_id, destination, user_id, device_id - LIMIT ? - """ - txn.execute(sql, args + [batch_size]) - rows = txn.fetchall() - - stream_id, destination, user_id, device_id = None, None, None, None - for stream_id, destination, user_id, device_id, _ in rows: - self.db_pool.simple_delete_txn( - txn, - "device_lists_outbound_pokes", - { - "stream_id": stream_id, - "destination": destination, - "user_id": user_id, - "device_id": device_id, - }, - ) - - self.db_pool.simple_insert_txn( - txn, - "device_lists_outbound_pokes", - { - "stream_id": stream_id, - "instance_name": self._instance_name, - "destination": destination, - "user_id": user_id, - "device_id": device_id, - "sent": False, - }, - ) - - if rows: - self.db_pool.updates._background_update_progress_txn( - txn, - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, - { - "last_row": { - "stream_id": stream_id, - "destination": destination, - "user_id": user_id, - "device_id": device_id, - } - }, - ) - - return len(rows) - - rows = await self.db_pool.runInteraction( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn - ) - - if not rows: - await self.db_pool.updates._end_background_update( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES - ) - - return rows - - -class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - # Map of (user_id, device_id) -> bool. If there is an entry that implies - # the device exists. - self.device_id_exists_cache: LruCache[Tuple[str, str], Literal[True]] = ( - LruCache(cache_name="device_id_exists", max_size=10000) - ) - - async def store_device( - self, - user_id: str, - device_id: str, - initial_device_display_name: Optional[str], - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, - ) -> bool: - """Ensure the given device is known; add it to the store if not - - Args: - user_id: id of user associated with the device - device_id: id of device - initial_device_display_name: initial displayname of the device. - Ignored if device exists. - auth_provider_id: The SSO IdP the user used, if any. - auth_provider_session_id: The session ID (sid) got from a OIDC login. - - Returns: - Whether the device was inserted or an existing device existed with that ID. 
- - Raises: - StoreError: if the device is already in use - """ - key = (user_id, device_id) - if self.device_id_exists_cache.get(key, None): - return False - - try: - inserted = await self.db_pool.simple_upsert( - "devices", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={}, - insertion_values={ - "display_name": initial_device_display_name, - "hidden": False, - }, - desc="store_device", - ) - await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) - - if not inserted: - # if the device already exists, check if it's a real device, or - # if the device ID is reserved by something else - hidden = await self.db_pool.simple_select_one_onecol( - "devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="hidden", - ) - if hidden: - raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - - if auth_provider_id and auth_provider_session_id: - await self.db_pool.simple_insert( - "device_auth_providers", - values={ - "user_id": user_id, - "device_id": device_id, - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - desc="store_device_auth_provider", - ) - - self.device_id_exists_cache.set(key, True) - return inserted - except StoreError: - raise - except Exception as e: - logger.error( - "store_device with device_id=%s(%r) user_id=%s(%r)" - " display_name=%s(%r) failed: %s", - type(device_id).__name__, - device_id, - type(user_id).__name__, - user_id, - type(initial_device_display_name).__name__, - initial_device_display_name, - e, - ) - raise StoreError(500, "Problem storing device.") - - async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: - """Deletes several devices. - - Args: - user_id: The ID of the user which owns the devices - device_ids: The IDs of the devices to delete - """ - - def _delete_devices_txn(txn: LoggingTransaction, device_ids: List[str]) -> None: - self.db_pool.simple_delete_many_txn( - txn, - table="devices", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id, "hidden": False}, - ) - - self.db_pool.simple_delete_many_txn( - txn, - table="device_auth_providers", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - self._invalidate_cache_and_stream_bulk( - txn, self.get_device, [(user_id, device_id) for device_id in device_ids] - ) - - for batch in batch_iter(device_ids, 100): - await self.db_pool.runInteraction( - "delete_devices", _delete_devices_txn, batch - ) - - for device_id in device_ids: - self.device_id_exists_cache.invalidate((user_id, device_id)) - - async def update_device( - self, user_id: str, device_id: str, new_display_name: Optional[str] = None - ) -> None: - """Update a device. Only updates the device if it is not marked as - hidden. 
- - Args: - user_id: The ID of the user which owns the device - device_id: The ID of the device to update - new_display_name: new displayname for device; None to leave unchanged - Raises: - StoreError: if the device is not found - """ - updates = {} - if new_display_name is not None: - updates["display_name"] = new_display_name - if not updates: - return None - await self.db_pool.simple_update_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - updatevalues=updates, - desc="update_device", - ) - await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) - - async def update_remote_device_list_cache_entry( - self, user_id: str, device_id: str, content: JsonDict, stream_id: str - ) -> None: - """Updates a single device in the cache of a remote user's devicelist. - - Note: assumes that we are the only thread that can be updating this user's - device list. - - Args: - user_id: User to update device list for - device_id: ID of decivice being updated - content: new data on this device - stream_id: the version of the device list - """ - await self.db_pool.runInteraction( - "update_remote_device_list_cache_entry", - self._update_remote_device_list_cache_entry_txn, - user_id, - device_id, - content, - stream_id, - ) - - def _update_remote_device_list_cache_entry_txn( - self, - txn: LoggingTransaction, - user_id: str, - device_id: str, - content: JsonDict, - stream_id: str, - ) -> None: - """Delete, update or insert a cache entry for this (user, device) pair.""" - if content.get("deleted"): - self.db_pool.simple_delete_txn( - txn, - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - - txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) - else: - self.db_pool.simple_upsert_txn( - txn, - table="device_lists_remote_cache", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json_encoder.encode(content)}, - ) - - txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) - txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) + txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) self.db_pool.simple_upsert_txn( @@ -2055,14 +1937,21 @@ async def add_device_change_to_streams( The maximum stream ID of device list updates that were added to the database, or None if no updates were added. 
""" + if not self._is_device_list_writer: + raise Exception("Can only be called on device list writers") + if not device_ids: return None context = get_active_span_text_map() def add_device_changes_txn( - txn: LoggingTransaction, stream_ids: List[int] - ) -> None: + txn: LoggingTransaction, + ) -> int: + stream_ids = self._device_list_id_gen.get_next_mult_txn( + txn, len(device_ids) + ) + self._add_device_change_to_stream_txn( txn, user_id, @@ -2079,16 +1968,12 @@ def add_device_changes_txn( context, ) - async with self._device_list_id_gen.get_next_mult( - len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_change_to_stream", - add_device_changes_txn, - stream_ids, - ) + return stream_ids[-1] - return stream_ids[-1] + return await self.db_pool.runInteraction( + "add_device_change_to_stream", + add_device_changes_txn, + ) def _add_device_change_to_stream_txn( self, @@ -2287,6 +2172,8 @@ async def get_uncoverted_outbound_room_pokes( A list of user ID, device ID, room ID, stream ID and optional opentracing context, in order of ascending (stream ID, room ID). """ + if not self._is_device_list_writer: + raise Exception("Can only be called on device list writers") sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context @@ -2340,6 +2227,9 @@ async def add_device_list_outbound_pokes( """Queue the device update to be sent to the given set of hosts, calculated from the room ID. """ + if not self._is_device_list_writer: + raise Exception("Can only be called on device list writers") + if not hosts: return @@ -2368,6 +2258,8 @@ async def add_remote_device_list_to_pending( """Add a device list update to the table tracking remote device list updates during partial joins. """ + if not self._is_device_list_writer: + raise Exception("Can only be called on device list writers") async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.simple_upsert( @@ -2390,6 +2282,11 @@ async def get_pending_remote_device_list_updates_for_room( the room. """ + # The device list stream is a multi-writer stream, but when we partially + # join a room, we only record the minimum stream ID. This means that we + # may be returning a device update that was already sent through + # federation here in case of concurrent writes. 
This is absolutely fine: + sending a device update multiple times through federation is safe. min_device_stream_id = await self.db_pool.simple_select_one_onecol( table="partial_state_rooms", keyvalues={ @@ -2460,3 +2357,176 @@ async def set_device_change_last_converted_pos( }, desc="set_device_change_last_converted_pos", ) + + +class DeviceBackgroundUpdateStore(SQLBaseStore): + _instance_name: str + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + self.db_pool.updates.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + + # create a unique index on device_lists_remote_cache + self.db_pool.updates.register_background_index_update( + "device_lists_remote_cache_unique_idx", + index_name="device_lists_remote_cache_unique_id", + table="device_lists_remote_cache", + columns=["user_id", "device_id"], + unique=True, + ) + + # And one on device_lists_remote_extremeties + self.db_pool.updates.register_background_index_update( + "device_lists_remote_extremeties_unique_idx", + index_name="device_lists_remote_extremeties_unique_idx", + table="device_lists_remote_extremeties", + columns=["user_id"], + unique=True, + ) + + # once they complete, we can remove the old non-unique indexes. + self.db_pool.updates.register_background_update_handler( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, + self._drop_device_list_streams_non_unique_indexes, + ) + + # clear out duplicate device list outbound pokes + self.db_pool.updates.register_background_update_handler( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, + self._remove_duplicate_outbound_pokes, + ) + + self.db_pool.updates.register_background_index_update( + "device_lists_changes_in_room_by_room_index", + index_name="device_lists_changes_in_room_by_room_idx", + table="device_lists_changes_in_room", + columns=["room_id", "stream_id"], + ) + + async def _drop_device_list_streams_non_unique_indexes( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") + txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") + txn.close() + + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) + return 1 + + async def _remove_duplicate_outbound_pokes( + self, progress: JsonDict, batch_size: int + ) -> int: + # for some reason, we have accumulated duplicate entries in + # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + # efficient. + # + # For each duplicate, we delete all the existing rows and put one back.
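Looking back at the `add_device_change_to_streams` rewrite earlier in this file: stream IDs are now allocated with `get_next_mult_txn` inside the transaction that persists the rows, rather than by wrapping the whole interaction in an `async with id_gen.get_next_mult(...)` block. A condensed sketch of the two shapes; the function names, arguments, and `persist_txn` callback are illustrative:

```python
# Before: IDs reserved around the whole interaction, and only marked
# finished once the context manager exits.
async def add_changes_old(db_pool, id_gen, n, persist_txn):
    async with id_gen.get_next_mult(n) as stream_ids:
        await db_pool.runInteraction("add_changes", persist_txn, stream_ids)
    return stream_ids[-1]


# After: IDs come from the transaction itself, tying their lifetime to
# that transaction's commit or rollback.
async def add_changes_new(db_pool, id_gen, n, persist_txn):
    def txn_fn(txn):
        stream_ids = id_gen.get_next_mult_txn(txn, n)
        persist_txn(txn, stream_ids)
        return stream_ids[-1]

    return await db_pool.runInteraction("add_changes", txn_fn)
```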
+ + last_row = progress.get( + "last_row", + {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, + ) + + def _txn(txn: LoggingTransaction) -> int: + clause, args = make_tuple_comparison_clause( + [ + ("stream_id", last_row["stream_id"]), + ("destination", last_row["destination"]), + ("user_id", last_row["user_id"]), + ("device_id", last_row["device_id"]), + ] + ) + sql = f""" + SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts + FROM device_lists_outbound_pokes + WHERE {clause} + GROUP BY stream_id, destination, user_id, device_id + HAVING count(*) > 1 + ORDER BY stream_id, destination, user_id, device_id + LIMIT ? + """ + txn.execute(sql, args + [batch_size]) + rows = txn.fetchall() + + stream_id, destination, user_id, device_id = None, None, None, None + for stream_id, destination, user_id, device_id, _ in rows: + self.db_pool.simple_delete_txn( + txn, + "device_lists_outbound_pokes", + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + }, + ) + + self.db_pool.simple_insert_txn( + txn, + "device_lists_outbound_pokes", + { + "stream_id": stream_id, + "instance_name": self._instance_name, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + "sent": False, + }, + ) + + if rows: + self.db_pool.updates._background_update_progress_txn( + txn, + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, + { + "last_row": { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + } + }, + ) + + return len(rows) + + rows = await self.db_pool.runInteraction( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn + ) + + if not rows: + await self.db_pool.updates._end_background_update( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES + ) + + return rows + + +class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + _instance_name: str + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 341e7014d69..cfbe815620d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -59,7 +59,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import JsonDict, JsonMapping +from synapse.types import JsonDict, JsonMapping, MultiWriterStreamToken from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable @@ -120,6 +120,20 @@ def __init__( self.hs.config.federation.allow_device_name_lookup_over_federation ) + self._cross_signing_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="e2e_cross_signing_keys", + instance_name=self._instance_name, + tables=[ + ("e2e_cross_signing_keys", "instance_name", "stream_id"), + ], + sequence_name="e2e_cross_signing_keys_sequence", + # No one reads the stream positions, so we're allowed to have an empty list of writers + writers=[], + ) + def process_replication_rows( self, stream_name: str, @@ -145,7 +159,12 @@ async def get_e2e_device_keys_for_federation_query( Returns: (stream_id, devices) """ - now_stream_id = 
self.get_device_stream_token() + # Here, we don't use the individual instances positions, as we *need* to + # give out the stream_id as an integer in the federation API. + # This means that we'll potentially return the same data twice with a + # different stream_id, and invalidate cache more often than necessary, + # which is fine overall. + now_stream_id = self.get_device_stream_token().stream # We need to be careful with the caching here, as we need to always # return *all* persisted devices, however there may be a lag between a @@ -164,8 +183,10 @@ async def get_e2e_device_keys_for_federation_query( # have to check for potential invalidations after the # `now_stream_id`. sql = """ - SELECT user_id FROM device_lists_stream + SELECT 1 + FROM device_lists_stream WHERE stream_id >= ? AND user_id = ? + LIMIT 1 """ rows = await self.db_pool.execute( "get_e2e_device_keys_for_federation_query_check", @@ -1117,7 +1138,7 @@ def _get_all_user_signature_changes_for_remotes_txn( ) @abc.abstractmethod - def get_device_stream_token(self) -> int: + def get_device_stream_token(self) -> MultiWriterStreamToken: """Get the current stream id from the _device_list_id_gen""" ... @@ -1540,27 +1561,44 @@ def impl(txn: LoggingTransaction) -> Optional[int]: impl, ) + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: + def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: + log_kv( + { + "message": "Deleting keys for device", + "device_id": device_id, + "user_id": user_id, + } + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) -class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self._cross_signing_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="e2e_cross_signing_keys", - instance_name=self._instance_name, - tables=[ - ("e2e_cross_signing_keys", "instance_name", "stream_id"), - ], - sequence_name="e2e_cross_signing_keys_sequence", - writers=["master"], + await self.db_pool.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) async def set_e2e_device_keys( @@ -1632,46 +1670,6 @@ def _set_e2e_device_keys_txn( log_kv({"message": "Device keys stored."}) return True - async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: - def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: - log_kv( - { - "message": "Deleting keys for device", - "device_id": device_id, - "user_id": user_id, - } - ) - self.db_pool.simple_delete_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self.db_pool.simple_delete_txn( - txn, - 
table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - self.db_pool.simple_delete_txn( - txn, - table="dehydrated_devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="e2e_fallback_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) - ) - - await self.db_pool.runInteraction( - "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn - ) - def _set_e2e_cross_signing_key_txn( self, txn: LoggingTransaction, @@ -1794,3 +1792,13 @@ async def store_e2e_cross_signing_signatures( ], desc="add_e2e_signing_key", ) + + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 99643315107..5794eb13e1e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -36,7 +36,6 @@ ) import attr -from immutabledict import immutabledict from synapse.api.constants import EduTypes from synapse.replication.tcp.streams import ReceiptsStream @@ -166,25 +165,7 @@ def __init__( def get_max_receipt_stream_id(self) -> MultiWriterStreamToken: """Get the current max stream ID for receipts stream""" - min_pos = self._receipts_id_gen.get_current_token() - - positions = {} - if isinstance(self._receipts_id_gen, MultiWriterIdGenerator): - # The `min_pos` is the minimum position that we know all instances - # have finished persisting to, so we only care about instances whose - # positions are ahead of that. (Instance positions can be behind the - # min position as there are times we can work out that the minimum - # position is ahead of the naive minimum across all current - # positions. See MultiWriterIdGenerator for details) - positions = { - i: p - for i, p in self._receipts_id_gen.get_positions().items() - if p > min_pos - } - - return MultiWriterStreamToken( - stream=min_pos, instance_map=immutabledict(positions) - ) + return MultiWriterStreamToken.from_generator(self._receipts_id_gen) def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: return self._receipts_id_gen.get_current_token_for_writer(instance_name) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 40c551bcb4b..7cc34757345 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1936,6 +1936,58 @@ def _replace_refresh_token_txn(txn: LoggingTransaction) -> None: "replace_refresh_token", _replace_refresh_token_txn ) + async def set_device_for_refresh_token( + self, user_id: str, old_device_id: str, device_id: str + ) -> None: + """Moves refresh tokens from old device to current device + + Args: + user_id: The user of the devices. + old_device_id: The old device. + device_id: The new device ID. 
+ Returns: + None + """ + + await self.db_pool.simple_update( + "refresh_tokens", + keyvalues={"user_id": user_id, "device_id": old_device_id}, + updatevalues={"device_id": device_id}, + desc="set_device_for_refresh_token", + ) + + def _set_device_for_access_token_txn( + self, txn: LoggingTransaction, token: str, device_id: str + ) -> str: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, "access_tokens", {"token": token}, "device_id" + ) + + self.db_pool.simple_update_txn( + txn, "access_tokens", {"token": token}, {"device_id": device_id} + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) + + return old_device_id + + async def set_device_for_access_token(self, token: str, device_id: str) -> str: + """Sets the device ID associated with an access token. + + Args: + token: The access token to modify. + device_id: The new device ID. + Returns: + The old device ID associated with the access token. + """ + + return await self.db_pool.runInteraction( + "set_device_for_access_token", + self._set_device_for_access_token_txn, + token, + device_id, + ) + async def add_login_token_to_user( self, user_id: str, @@ -2470,58 +2522,6 @@ async def add_refresh_token_to_user( return next_id - async def set_device_for_refresh_token( - self, user_id: str, old_device_id: str, device_id: str - ) -> None: - """Moves refresh tokens from old device to current device - - Args: - user_id: The user of the devices. - old_device_id: The old device. - device_id: The new device ID. - Returns: - None - """ - - await self.db_pool.simple_update( - "refresh_tokens", - keyvalues={"user_id": user_id, "device_id": old_device_id}, - updatevalues={"device_id": device_id}, - desc="set_device_for_refresh_token", - ) - - def _set_device_for_access_token_txn( - self, txn: LoggingTransaction, token: str, device_id: str - ) -> str: - old_device_id = self.db_pool.simple_select_one_onecol_txn( - txn, "access_tokens", {"token": token}, "device_id" - ) - - self.db_pool.simple_update_txn( - txn, "access_tokens", {"token": token}, {"device_id": device_id} - ) - - self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) - - return old_device_id - - async def set_device_for_access_token(self, token: str, device_id: str) -> str: - """Sets the device ID associated with an access token. - - Args: - token: The access token to modify. - device_id: The new device ID. - Returns: - The old device ID associated with the access token. 
- """ - - return await self.db_pool.runInteraction( - "set_device_for_access_token", - self._set_device_for_access_token_txn, - token, - device_id, - ) - async def register_user( self, user_id: str, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 29a001ff929..5edac56ec3c 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -324,7 +324,7 @@ def _get_recent_references_for_event_txn( account_data_key=0, push_rules_key=0, to_device_key=0, - device_list_key=0, + device_list_key=MultiWriterStreamToken(stream=0), groups_key=0, un_partial_stated_rooms_key=0, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 3fda49f31f1..3ed3137f435 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -61,7 +61,6 @@ ) import attr -from immutabledict import immutabledict from typing_extensions import assert_never from twisted.internet import defer @@ -654,23 +653,7 @@ def get_room_max_token(self) -> RoomStreamToken: component. """ - min_pos = self._stream_id_gen.get_current_token() - - positions = {} - if isinstance(self._stream_id_gen, MultiWriterIdGenerator): - # The `min_pos` is the minimum position that we know all instances - # have finished persisting to, so we only care about instances whose - # positions are ahead of that. (Instance positions can be behind the - # min position as there are times we can work out that the minimum - # position is ahead of the naive minimum across all current - # positions. See MultiWriterIdGenerator for details) - positions = { - i: p - for i, p in self._stream_id_gen.get_positions().items() - if p > min_pos - } - - return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions)) + return RoomStreamToken.from_generator(self._stream_id_gen) def get_events_stream_id_generator(self) -> MultiWriterIdGenerator: return self._stream_id_gen diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 856f646795c..4534068e7c9 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -203,7 +203,7 @@ async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: account_data_key=0, push_rules_key=0, to_device_key=0, - device_list_key=0, + device_list_key=MultiWriterStreamToken(stream=0), groups_key=0, un_partial_stated_rooms_key=0, ) diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 5549f3c9f8d..d09fd30e814 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -75,6 +75,7 @@ from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) @@ -570,6 +571,25 @@ def bound_stream_token(self, max_stream: int) -> "Self": ), ) + @classmethod + def from_generator(cls, generator: "MultiWriterIdGenerator") -> Self: + """Get the current token out of a MultiWriterIdGenerator""" + + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. 
See MultiWriterIdGenerator for details) + min_pos = generator.get_current_token() + positions = { + instance: position + for instance, position in generator.get_positions().items() + if position > min_pos + } + + return cls(stream=min_pos, instance_map=immutabledict(positions)) + @attr.s(frozen=True, slots=True, order=False) class RoomStreamToken(AbstractMultiWriterStreamToken): @@ -980,7 +1000,9 @@ class StreamToken: account_data_key: int push_rules_key: int to_device_key: int - device_list_key: int + device_list_key: MultiWriterStreamToken = attr.ib( + validator=attr.validators.instance_of(MultiWriterStreamToken) + ) # Note that the groups key is no longer used and may have bogus values. groups_key: int un_partial_stated_rooms_key: int @@ -1021,7 +1043,9 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": account_data_key=int(account_data_key), push_rules_key=int(push_rules_key), to_device_key=int(to_device_key), - device_list_key=int(device_list_key), + device_list_key=await MultiWriterStreamToken.parse( + store, device_list_key + ), groups_key=int(groups_key), un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), ) @@ -1040,7 +1064,7 @@ async def to_string(self, store: "DataStore") -> str: str(self.account_data_key), str(self.push_rules_key), str(self.to_device_key), - str(self.device_list_key), + await self.device_list_key.to_string(store), # Note that the groups key is no longer used, but it is still # serialized so that there will not be confusion in the future # if additional tokens are added. @@ -1069,6 +1093,12 @@ def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken": StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value) ) return new_token + elif key == StreamKeyType.DEVICE_LIST: + new_token = self.copy_and_replace( + StreamKeyType.DEVICE_LIST, + self.device_list_key.copy_and_advance(new_value), + ) + return new_token new_token = self.copy_and_replace(key, new_value) new_id = new_token.get_field(key) @@ -1087,7 +1117,11 @@ def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: ... @overload def get_field( - self, key: Literal[StreamKeyType.RECEIPT] + self, + key: Literal[ + StreamKeyType.RECEIPT, + StreamKeyType.DEVICE_LIST, + ], ) -> MultiWriterStreamToken: ... 
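The `from_generator` classmethod above centralises token construction that was previously duplicated in the receipts and stream stores (both deleted copies appear earlier in this diff). A minimal sketch of the filtering it performs, using a hypothetical stub in place of a real `MultiWriterIdGenerator` (the positions below are invented for illustration):

```python
from immutabledict import immutabledict


class StubIdGenerator:
    """Hypothetical stand-in exposing only the two methods from_generator uses."""

    def get_current_token(self) -> int:
        # The minimum position that all writers are known to have persisted to.
        return 7

    def get_positions(self) -> dict:
        # Current per-writer positions; "worker1" sits at the minimum.
        return {"worker1": 7, "worker2": 9}


gen = StubIdGenerator()
min_pos = gen.get_current_token()

# Writers at or behind min_pos carry no extra information, so only
# positions strictly ahead of it are kept in the token's instance map.
positions = {
    instance: position
    for instance, position in gen.get_positions().items()
    if position > min_pos
}
assert positions == {"worker2": 9}

# The real classmethod then returns
# cls(stream=min_pos, instance_map=immutabledict(positions)); freezing the
# map keeps the resulting token immutable and safe to share between requests.
frozen = immutabledict(positions)
```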
@overload @@ -1095,7 +1129,6 @@ def get_field( self, key: Literal[ StreamKeyType.ACCOUNT_DATA, - StreamKeyType.DEVICE_LIST, StreamKeyType.PRESENCE, StreamKeyType.PUSH_RULES, StreamKeyType.TO_DEVICE, @@ -1161,7 +1194,16 @@ def __str__(self) -> str: StreamToken.START = StreamToken( - RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0 + room_key=RoomStreamToken(stream=0), + presence_key=0, + typing_key=0, + receipt_key=MultiWriterStreamToken(stream=0), + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=MultiWriterStreamToken(stream=0), + groups_key=0, + un_partial_stated_rooms_key=0, ) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cd906bbbc78..9931b96a6ac 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -30,7 +30,7 @@ from synapse.api.presence import UserPresenceState from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU from synapse.federation.units import Transaction -from synapse.handlers.device import DeviceHandler +from synapse.handlers.device import DeviceHandler, DeviceListUpdater from synapse.rest import admin from synapse.rest.client import login from synapse.server import HomeServer @@ -554,6 +554,8 @@ def test_dont_send_device_updates_for_remote_users(self) -> None: "devices": [{"device_id": "D1"}], } + assert isinstance(self.device_handler.device_list_updater, DeviceListUpdater) + self.get_success( self.device_handler.device_list_updater.incoming_device_list_update( "host2",
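One consequence of the new multi-writer token, noted in the `get_e2e_device_keys_for_federation_query` comment near the top of this section, is that the federation API still hands out `stream_id` as a plain integer, so the token is flattened to its `stream` component there. A toy model of the trade-off that comment describes, with invented positions, and with the generator's current token approximated as the plain minimum:

```python
# Two hypothetical device_lists writers have persisted up to different positions.
positions = {"worker1": 7, "worker2": 9}

# Flattening to an integer means handing out the minimum safe position...
now_stream_id = min(positions.values())  # 7

# ...so changes 8 and 9, already persisted by worker2, are not covered by
# this stream_id. A remote server that queries again from the same token
# may receive them a second time, and the `stream_id >= ?` invalidation
# check in the SQL earlier in this diff fires more often than strictly
# needed. As the comment says, returning duplicates is safe; only a little
# efficiency is lost.
duplicated = {s for s in range(1, 10) if s > now_stream_id}
assert duplicated == {8, 9}
```

The per-writer detail is preserved everywhere else (sync tokens, replication), which is why only this one federation-facing path pays the duplication cost.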