diff --git a/changelog.d/18695.feature b/changelog.d/18695.feature new file mode 100644 index 00000000000..1481a27f237 --- /dev/null +++ b/changelog.d/18695.feature @@ -0,0 +1 @@ +Add experimental support for [MSC4308: Thread Subscriptions extension to Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4308) when [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) and [MSC4186: Simplified Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4186) are enabled. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1631f39e3d..d086deab3f4 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -590,5 +590,5 @@ def read_config( self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False) # MSC4306: Thread Subscriptions - # (and MSC4308: sliding sync extension for thread subscriptions) + # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 174d02ab6bb..c4905e63ddf 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -135,7 +135,7 @@ async def on_GET( if not self.allow_access: raise FederationDeniedError(origin) - limit = parse_integer_from_args(query, "limit", 0) + limit: Optional[int] = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( query, "include_all_networks", default=False diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 071a271ab7b..255a041d0eb 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -211,7 +211,7 @@ async def current_sync_for_user( Args: sync_config: Sync configuration - to_token: The point in the stream to sync up to. + to_token: The latest point in the stream to sync up to. from_token: The point in the stream to sync from. Token of the end of the previous batch. May be `None` if this is the initial sync request. 
""" diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 077887ec321..25ee954b7fd 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -27,7 +27,7 @@ cast, ) -from typing_extensions import assert_never +from typing_extensions import TypeAlias, assert_never from synapse.api.constants import AccountDataTypes, EduTypes from synapse.handlers.receipts import ReceiptEventSource @@ -40,6 +40,7 @@ SlidingSyncStreamToken, StrCollection, StreamToken, + ThreadSubscriptionsToken, ) from synapse.types.handlers.sliding_sync import ( HaveSentRoomFlag, @@ -54,6 +55,13 @@ gather_optional_coroutines, ) +_ThreadSubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription +) +_ThreadUnsubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription +) + if TYPE_CHECKING: from synapse.server import HomeServer @@ -68,6 +76,7 @@ def __init__(self, hs: "HomeServer"): self.event_sources = hs.get_event_sources() self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() + self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled @trace async def get_extensions_response( @@ -93,7 +102,7 @@ async def get_extensions_response( actual_room_ids: The actual room IDs in the the Sliding Sync response. actual_room_response_map: A map of room ID to room results in the the Sliding Sync response. - to_token: The point in the stream to sync up to. + to_token: The latest point in the stream to sync up to. from_token: The point in the stream to sync from. """ @@ -156,18 +165,32 @@ async def get_extensions_response( from_token=from_token, ) + thread_subs_coro = None + if ( + sync_config.extensions.thread_subscriptions is not None + and self._enable_thread_subscriptions + ): + thread_subs_coro = self.get_thread_subscriptions_extension_response( + sync_config=sync_config, + thread_subscriptions_request=sync_config.extensions.thread_subscriptions, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, account_data_response, receipts_response, typing_response, + thread_subs_response, ) = await gather_optional_coroutines( to_device_coro, e2ee_coro, account_data_coro, receipts_coro, typing_coro, + thread_subs_coro, ) return SlidingSyncResult.Extensions( @@ -176,6 +199,7 @@ async def get_extensions_response( account_data=account_data_response, receipts=receipts_response, typing=typing_response, + thread_subscriptions=thread_subs_response, ) def find_relevant_room_ids_for_extension( @@ -877,3 +901,72 @@ async def get_typing_extension_response( return SlidingSyncResult.Extensions.TypingExtension( room_id_to_typing_map=room_id_to_typing_map, ) + + async def get_thread_subscriptions_extension_response( + self, + sync_config: SlidingSyncConfig, + thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension, + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]: + """Handle Thread Subscriptions extension (MSC4308) + + Args: + sync_config: Sync configuration + thread_subscriptions_request: The thread_subscriptions extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. 
+ + Returns: + the response (None if empty or thread subscriptions are disabled) + """ + if not thread_subscriptions_request.enabled: + return None + + limit = thread_subscriptions_request.limit + + if from_token: + from_stream_id = from_token.stream_token.thread_subscriptions_key + else: + from_stream_id = StreamToken.START.thread_subscriptions_key + + to_stream_id = to_token.thread_subscriptions_key + + updates = await self.store.get_latest_updated_thread_subscriptions_for_user( + user_id=sync_config.user.to_string(), + from_id=from_stream_id, + to_id=to_stream_id, + limit=limit, + ) + + if len(updates) == 0: + return None + + subscribed_threads: Dict[str, Dict[str, _ThreadSubscription]] = {} + unsubscribed_threads: Dict[str, Dict[str, _ThreadUnsubscription]] = {} + for stream_id, room_id, thread_root_id, subscribed, automatic in updates: + if subscribed: + subscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + _ThreadSubscription( + automatic=automatic, + bump_stamp=stream_id, + ) + ) + else: + unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + _ThreadUnsubscription(bump_stamp=stream_id) + ) + + prev_batch = None + if len(updates) == limit: + # Tell the client about a potential gap where there may be more + # thread subscriptions for it to backpaginate. + # We subtract one because the 'later in the stream' bound is inclusive, + # and we already saw the element at index 0. + prev_batch = ThreadSubscriptionsToken(updates[0][0] - 1) + + return SlidingSyncResult.Extensions.ThreadSubscriptionsExtension( + subscribed=subscribed_threads, + unsubscribed=unsubscribed_threads, + prev_batch=prev_batch, + ) diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index bda43429491..d56c915e0a5 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -9,7 +9,7 @@ AutomaticSubscriptionConflicted, ThreadSubscription, ) -from synapse.types import EventOrderings, UserID +from synapse.types import EventOrderings, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -22,6 +22,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() + self._notifier = hs.get_notifier() async def get_thread_subscription_settings( self, @@ -132,6 +133,15 @@ async def subscribe_user_to_thread( errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, ) + if outcome is not None: + # wake up user streams (e.g. sliding sync) on the same worker + self._notifier.on_new_event( + StreamKeyType.THREAD_SUBSCRIPTIONS, + # outcome is a stream_id + outcome, + users=[user_id.to_string()], + ) + return outcome async def unsubscribe_user_from_thread( @@ -162,8 +172,19 @@ async def unsubscribe_user_from_thread( logger.info("rejecting thread subscriptions change (thread not accessible)") raise NotFoundError("No such thread root") - return await self.store.unsubscribe_user_from_thread( + outcome = await self.store.unsubscribe_user_from_thread( user_id.to_string(), event.room_id, thread_root_event_id, ) + + if outcome is not None: + # wake up user streams (e.g. 
sliding sync) on the same worker + self._notifier.on_new_event( + StreamKeyType.THREAD_SUBSCRIPTIONS, + # outcome is a stream_id + outcome, + users=[user_id.to_string()], + ) + + return outcome diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 47d8bd5eaf1..69bdce2b834 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -130,6 +130,16 @@ def parse_integer( return parse_integer_from_args(args, name, default, required, negative) +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: int, + required: Literal[False] = False, + negative: bool = False, +) -> int: ... + + @overload def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], diff --git a/synapse/notifier.py b/synapse/notifier.py index 7782c9ca659..e684df4866b 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -532,6 +532,7 @@ def on_new_event( StreamKeyType.TO_DEVICE, StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, + StreamKeyType.THREAD_SUBSCRIPTIONS, ], new_token: int, users: Optional[Collection[Union[str, UserID]]] = None, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index ee9250cf7d5..7a86b2e65ee 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -44,6 +44,7 @@ UnPartialStatedEventStream, UnPartialStatedRoomStream, ) +from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -255,6 +256,12 @@ async def on_rdata( self._state_storage_controller.notify_event_un_partial_stated( row.event_id ) + elif stream_name == ThreadSubscriptionsStream.NAME: + self.notifier.on_new_event( + StreamKeyType.THREAD_SUBSCRIPTIONS, + token, + users=[row.user_id for row in rows], + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 6f2f6642bed..c424ca53254 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -23,6 +23,8 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +import attr + from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection @@ -632,12 +634,21 @@ async def encode_room( class SlidingSyncRestServlet(RestServlet): """ - API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a + API endpoint for MSC4186 Simplified Sliding Sync `/sync`, which was historically derived + from MSC3575 (Sliding Sync; now abandoned). Allows for clients to request a subset (sliding window) of rooms, state, and timeline events (just what they need) in order to bootstrap quickly and subscribe to only what the client cares about. Because the client can specify what it cares about, we can respond quickly and skip all of the work we would normally have to do with a sync v2 response. + Extensions of various features are defined in: + - to-device messaging (MSC3885) + - end-to-end encryption (MSC3884) + - typing notifications (MSC3961) + - receipts (MSC3960) + - account data (MSC3959) + - thread subscriptions (MSC4308) + Request query parameters: timeout: How long to wait for new events in milliseconds. pos: Stream position token when asking for incremental deltas. 
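For reference, the wire shape this extension adds under the unstable prefix, based on the serialisation code below and the new tests in this PR (room IDs, event IDs, bump_stamp values and the prev_batch token are illustrative):

    # Request: enable the extension (limit defaults to 100).
    request_extensions = {
        "io.element.msc4308.thread_subscriptions": {
            "enabled": True,
            "limit": 100,
        }
    }

    # Response fragment under extensions["io.element.msc4308.thread_subscriptions"].
    response_extension = {
        "subscribed": {
            "!room:example.org": {
                "$thread_root_a": {"automatic": False, "bump_stamp": 7},
            }
        },
        "unsubscribed": {
            "!room:example.org": {
                "$thread_root_b": {"bump_stamp": 8},
            }
        },
        # Only present when `limit` was hit and a gap may remain to backfill.
        "prev_batch": "ts6",
    }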
@@ -1074,9 +1085,48 @@ async def encode_extensions( "rooms": extensions.typing.room_id_to_typing_map, } + # excludes both None and falsy `thread_subscriptions` + if extensions.thread_subscriptions: + serialized_extensions["io.element.msc4308.thread_subscriptions"] = ( + _serialise_thread_subscriptions(extensions.thread_subscriptions) + ) + return serialized_extensions +def _serialise_thread_subscriptions( + thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension, +) -> JsonDict: + out: JsonDict = {} + + if thread_subscriptions.subscribed: + out["subscribed"] = { + room_id: { + thread_root_id: attr.asdict( + change, filter=lambda _attr, v: v is not None + ) + for thread_root_id, change in room_threads.items() + } + for room_id, room_threads in thread_subscriptions.subscribed.items() + } + + if thread_subscriptions.unsubscribed: + out["unsubscribed"] = { + room_id: { + thread_root_id: attr.asdict( + change, filter=lambda _attr, v: v is not None + ) + for thread_root_id, change in room_threads.items() + } + for room_id, room_threads in thread_subscriptions.unsubscribed.items() + } + + if thread_subscriptions.prev_batch: + out["prev_batch"] = thread_subscriptions.prev_batch.to_string() + + return out + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py index 4e7b5d06dbe..039aba1721c 100644 --- a/synapse/rest/client/thread_subscriptions.py +++ b/synapse/rest/client/thread_subscriptions.py @@ -1,21 +1,39 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import attr +from typing_extensions import TypeAlias from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_and_validate_json_object_from_request, + parse_integer, + parse_string, ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, RoomID +from synapse.types import ( + JsonDict, + RoomID, + SlidingSyncStreamToken, + ThreadSubscriptionsToken, +) +from synapse.types.handlers.sliding_sync import SlidingSyncResult from synapse.types.rest import RequestBodyModel from synapse.util.pydantic_models import AnyEventId if TYPE_CHECKING: from synapse.server import HomeServer +_ThreadSubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription +) +_ThreadUnsubscription: TypeAlias = ( + SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription +) + class ThreadSubscriptionsRestServlet(RestServlet): PATTERNS = client_patterns( @@ -100,6 +118,130 @@ async def on_DELETE( return HTTPStatus.OK, {} +class ThreadSubscriptionsPaginationRestServlet(RestServlet): + PATTERNS = client_patterns( + "/io.element.msc4308/thread_subscriptions$", + unstable=True, + releases=(), + ) + CATEGORY = "Thread Subscriptions requests (unstable)" + + # Maximum number of thread subscriptions to return in one request. 
+ MAX_LIMIT = 512 + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastores().main + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = min( + parse_integer(request, "limit", default=100, negative=False), + ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT, + ) + from_end_opt = parse_string(request, "from", required=False) + to_start_opt = parse_string(request, "to", required=False) + _direction = parse_string(request, "dir", required=True, allowed_values=("b",)) + + if limit <= 0: + # condition needed because `negative=False` still allows 0 + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "limit must be greater than 0", + errcode=Codes.INVALID_PARAM, + ) + + if from_end_opt is not None: + try: + # because of backwards pagination, the `from` token is actually the + # bound closest to the end of the stream + end_stream_id = ThreadSubscriptionsToken.from_string( + from_end_opt + ).stream_id + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "`from` is not a valid token", + errcode=Codes.INVALID_PARAM, + ) + else: + end_stream_id = self.store.get_max_thread_subscriptions_stream_id() + + if to_start_opt is not None: + # because of backwards pagination, the `to` token is actually the + # bound closest to the start of the stream + try: + start_stream_id = ThreadSubscriptionsToken.from_string( + to_start_opt + ).stream_id + except ValueError: + # we also accept sliding sync `pos` tokens on this parameter + try: + sliding_sync_pos = await SlidingSyncStreamToken.from_string( + self.store, to_start_opt + ) + start_stream_id = ( + sliding_sync_pos.stream_token.thread_subscriptions_key + ) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "`to` is not a valid token", + errcode=Codes.INVALID_PARAM, + ) + else: + # the start of time is ID 1; the lower bound is exclusive though + start_stream_id = 0 + + subscriptions = ( + await self.store.get_latest_updated_thread_subscriptions_for_user( + requester.user.to_string(), + from_id=start_stream_id, + to_id=end_stream_id, + limit=limit, + ) + ) + + subscribed_threads: Dict[str, Dict[str, JsonDict]] = {} + unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {} + for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions: + if subscribed: + subscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + attr.asdict( + _ThreadSubscription( + automatic=automatic, + bump_stamp=stream_id, + ) + ) + ) + else: + unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = ( + attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id)) + ) + + result: JsonDict = {} + if subscribed_threads: + result["subscribed"] = subscribed_threads + if unsubscribed_threads: + result["unsubscribed"] = unsubscribed_threads + + if len(subscriptions) == limit: + # We hit the limit, so there might be more entries to return. + # Generate a new token that has moved backwards, ready for the next + # request. + min_returned_stream_id, _, _, _, _ = subscriptions[0] + result["end"] = ThreadSubscriptionsToken( + # We subtract one because the 'later in the stream' bound is inclusive, + # and we already saw the element at index 0. 
+ stream_id=min_returned_stream_id - 1 + ).to_string() + + return HTTPStatus.OK, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc4306_enabled: ThreadSubscriptionsRestServlet(hs).register(http_server) + ThreadSubscriptionsPaginationRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 5edac56ec3c..ea746e05118 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -53,7 +53,7 @@ generate_pagination_where_clause, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -316,17 +316,8 @@ def _get_recent_references_for_event_txn( StreamKeyType.ROOM, next_key ) else: - next_token = StreamToken( - room_key=next_key, - presence_key=0, - typing_key=0, - receipt_key=MultiWriterStreamToken(stream=0), - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=MultiWriterStreamToken(stream=0), - groups_key=0, - un_partial_stated_rooms_key=0, + next_token = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, next_key ) return events[:limit], next_token diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index 6a62b11d1ed..72ec8e6b909 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -492,7 +492,7 @@ class PerConnectionStateDB: """An equivalent to `PerConnectionState` that holds data in a format stored in the DB. - The principle difference is that the tokens for the different streams are + The principal difference is that the tokens for the different streams are serialized to strings. When persisting this *only* contains updates to the state. diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index 24a99cf4490..50084887a4e 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -505,6 +505,9 @@ def get_max_thread_subscriptions_stream_id(self) -> int: """ return self._thread_subscriptions_id_gen.get_current_token() + def get_thread_subscriptions_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._thread_subscriptions_id_gen + async def get_updated_thread_subscriptions( self, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str, str]]: @@ -538,34 +541,52 @@ def get_updated_thread_subscriptions_txn( get_updated_thread_subscriptions_txn, ) - async def get_updated_thread_subscriptions_for_user( + async def get_latest_updated_thread_subscriptions_for_user( self, user_id: str, *, from_id: int, to_id: int, limit: int - ) -> List[Tuple[int, str, str]]: - """Get updates to thread subscriptions for a specific user. + ) -> List[Tuple[int, str, str, bool, Optional[bool]]]: + """Get the latest updates to thread subscriptions for a specific user. Args: user_id: The ID of the user from_id: The starting stream ID (exclusive) to_id: The ending stream ID (inclusive) limit: The maximum number of rows to return + If there are too many rows to return, rows from the start (closer to `from_id`) + will be omitted. 
Returns: - A list of (stream_id, room_id, thread_root_event_id) tuples. + A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples. + The row with lowest `stream_id` is the first row. """ def get_updated_thread_subscriptions_for_user_txn( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str]]: + ) -> List[Tuple[int, str, str, bool, Optional[bool]]]: sql = """ - SELECT stream_id, room_id, event_id - FROM thread_subscriptions - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + WITH the_updates AS ( + SELECT stream_id, room_id, event_id, subscribed, automatic + FROM thread_subscriptions + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id DESC + LIMIT ? + ) + SELECT stream_id, room_id, event_id, subscribed, automatic + FROM the_updates ORDER BY stream_id ASC - LIMIT ? """ txn.execute(sql, (user_id, from_id, to_id, limit)) - return [(row[0], row[1], row[2]) for row in txn] + return [ + ( + stream_id, + room_id, + event_id, + # SQLite integer to boolean conversions + bool(subscribed), + bool(automatic) if subscribed else None, + ) + for (stream_id, room_id, event_id, subscribed, automatic) in txn + ] return await self.db_pool.runInteraction( "get_updated_thread_subscriptions_for_user", diff --git a/synapse/storage/schema/main/delta/92/08_thread_subscriptions_seq_fixup.sql.postgres b/synapse/storage/schema/main/delta/92/08_thread_subscriptions_seq_fixup.sql.postgres new file mode 100644 index 00000000000..d327d1e1654 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/08_thread_subscriptions_seq_fixup.sql.postgres @@ -0,0 +1,19 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Work around https://github.com/element-hq/synapse/issues/18712 by advancing the +-- stream sequence. +-- This makes last_value of the sequence point to a position that will not get later +-- returned by nextval. +-- (For blank thread subscription streams, this means last_value = 2, nextval() = 3 after this line.) +SELECT nextval('thread_subscriptions_sequence'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a15a161ce85..1b7c5dac7a2 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -187,8 +187,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): Warning: Streams using this generator start at ID 2, because ID 1 is always assumed to have been 'seen as persisted'. Unclear if this extant behaviour is desirable for some reason. - When creating a new sequence for a new stream, - it will be necessary to use `START WITH 2`. + When creating a new sequence for a new stream, it will be necessary to advance it + so that position 1 is consumed. + DO NOT USE `START WITH 2` FOR THIS PURPOSE: + see https://github.com/element-hq/synapse/issues/18712 + Instead, use `SELECT nextval('sequence_name');` immediately after the + `CREATE SEQUENCE` statement. 
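A minimal sketch of that pattern as a Python schema delta (the sequence name and delta are illustrative, not part of this change; assumes Synapse's convention of a `run_create` hook in .py delta files):

    def run_create(cur, database_engine):
        # Create the sequence for the hypothetical new stream.
        cur.execute("CREATE SEQUENCE my_new_stream_sequence")
        # Consume position 1 immediately rather than using `START WITH 2`, so
        # last_value points at an ID that nextval() will never hand out
        # (see https://github.com/element-hq/synapse/issues/18712).
        cur.execute("SELECT nextval('my_new_stream_sequence')")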
Args: db_conn diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4534068e7c9..1e4bebe46d5 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -33,7 +33,6 @@ from synapse.streams import EventSource from synapse.types import ( AbstractMultiWriterStreamToken, - MultiWriterStreamToken, StreamKeyType, StreamToken, ) @@ -84,6 +83,7 @@ def get_current_token(self) -> StreamToken: un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token( self._instance_name ) + thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -97,6 +97,7 @@ def get_current_token(self) -> StreamToken: # Groups key is unused. groups_key=0, un_partial_stated_rooms_key=un_partial_stated_rooms_key, + thread_subscriptions_key=thread_subscriptions_key, ) return token @@ -123,6 +124,7 @@ async def bound_future_token(self, token: StreamToken) -> StreamToken: StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(), StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), + StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): @@ -195,16 +197,7 @@ async def get_current_token_for_pagination(self, room_id: str) -> StreamToken: Returns: The current token for pagination. """ - token = StreamToken( - room_key=await self.sources.room.get_current_key_for_room(room_id), - presence_key=0, - typing_key=0, - receipt_key=MultiWriterStreamToken(stream=0), - account_data_key=0, - push_rules_key=0, - to_device_key=0, - device_list_key=MultiWriterStreamToken(stream=0), - groups_key=0, - un_partial_stated_rooms_key=0, + return StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, + await self.sources.room.get_current_key_for_room(room_id), ) - return token diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 943f211b118..2d5b07ab8fa 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -996,6 +996,7 @@ class StreamKeyType(Enum): TO_DEVICE = "to_device_key" DEVICE_LIST = "device_list_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" + THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1003,7 +1004,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -1016,6 +1017,7 @@ class StreamToken: 8. `device_list_key`: `265584` 9. `groups_key`: `1` (note that this key is now unused) 10. `un_partial_stated_rooms_key`: `379` + 11. `thread_subscriptions_key`: 4242 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1074,6 +1076,7 @@ class StreamToken: # Note that the groups key is no longer used and may have bogus values. 
groups_key: int un_partial_stated_rooms_key: int + thread_subscriptions_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1101,6 +1104,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": device_list_key, groups_key, un_partial_stated_rooms_key, + thread_subscriptions_key, ) = keys return cls( @@ -1116,6 +1120,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": ), groups_key=int(groups_key), un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), + thread_subscriptions_key=int(thread_subscriptions_key), ) except CancelledError: raise @@ -1138,6 +1143,7 @@ async def to_string(self, store: "DataStore") -> str: # if additional tokens are added. str(self.groups_key), str(self.un_partial_stated_rooms_key), + str(self.thread_subscriptions_key), ] ) @@ -1202,6 +1208,7 @@ def get_field( StreamKeyType.TO_DEVICE, StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, + StreamKeyType.THREAD_SUBSCRIPTIONS, ], ) -> int: ... @@ -1257,7 +1264,8 @@ def __str__(self) -> str: f"typing: {self.typing_key}, receipt: {self.receipt_key}, " f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " - f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})" + f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," + f"thread_subscriptions: {self.thread_subscriptions_key})" ) @@ -1272,6 +1280,7 @@ def __str__(self) -> str: device_list_key=MultiWriterStreamToken(stream=0), groups_key=0, un_partial_stated_rooms_key=0, + thread_subscriptions_key=0, ) @@ -1318,6 +1327,27 @@ async def to_string(self, store: "DataStore") -> str: return f"{self.connection_position}/{stream_token_str}" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadSubscriptionsToken: + """ + Token for a position in the thread subscriptions stream. 
+ + Format: `ts` + """ + + stream_id: int + + @staticmethod + def from_string(s: str) -> "ThreadSubscriptionsToken": + if not s.startswith("ts"): + raise ValueError("thread subscription token must start with `ts`") + + return ThreadSubscriptionsToken(stream_id=int(s[2:])) + + def to_string(self) -> str: + return f"ts{self.stream_id}" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class PersistedPosition: """Position of a newly persisted row with instance that persisted it.""" diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index 3ebd334a6d5..b7bc565464f 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -50,6 +50,7 @@ SlidingSyncStreamToken, StrCollection, StreamToken, + ThreadSubscriptionsToken, UserID, ) from synapse.types.rest.client import SlidingSyncBody @@ -357,11 +358,50 @@ class TypingExtension: def __bool__(self) -> bool: return bool(self.room_id_to_typing_map) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadSubscriptionsExtension: + """The Thread Subscriptions extension (MSC4308) + + Attributes: + subscribed: map (room_id -> thread_root_id -> info) of new or changed subscriptions + unsubscribed: map (room_id -> thread_root_id -> info) of new unsubscriptions + prev_batch: if present, there is a gap and the client can use this token to backpaginate + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadSubscription: + # always present when `subscribed` + automatic: Optional[bool] + + # the same as our stream_id; useful for clients to resolve + # race conditions locally + bump_stamp: int + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadUnsubscription: + # the same as our stream_id; useful for clients to resolve + # race conditions locally + bump_stamp: int + + # room_id -> event_id (of thread root) -> the subscription change + subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]] + # room_id -> event_id (of thread root) -> the unsubscription + unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]] + prev_batch: Optional[ThreadSubscriptionsToken] + + def __bool__(self) -> bool: + return ( + bool(self.subscribed) + or bool(self.unsubscribed) + or bool(self.prev_batch) + ) + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None + thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None def __bool__(self) -> bool: return bool( @@ -370,6 +410,7 @@ def __bool__(self) -> bool: or self.account_data or self.receipts or self.typing + or self.thread_subscriptions ) next_pos: SlidingSyncStreamToken diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index c739bd16b0c..11d7e59b43a 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -22,6 +22,7 @@ from synapse._pydantic_compat import ( Extra, + Field, StrictBool, StrictInt, StrictStr, @@ -364,11 +365,25 @@ class TypingExtension(RequestBodyModel): # Process all room subscriptions defined in the Room Subscription API. (This is the default.) 
rooms: Optional[List[StrictStr]] = ["*"] + class ThreadSubscriptionsExtension(RequestBodyModel): + """The Thread Subscriptions extension (MSC4308) + + Attributes: + enabled + limit: maximum number of subscription changes to return (default 100) + """ + + enabled: Optional[StrictBool] = False + limit: StrictInt = 100 + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None + thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field( + alias="io.element.msc4308.thread_subscriptions" + ) conn_id: Optional[StrictStr] diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index e596e1ed209..c21b7887f9e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -347,6 +347,7 @@ async def yieldable_gather_results_delaying_cancellation( T3 = TypeVar("T3") T4 = TypeVar("T4") T5 = TypeVar("T5") +T6 = TypeVar("T6") @overload @@ -461,6 +462,23 @@ async def gather_optional_coroutines( ) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ... +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + Optional[Coroutine[Any, Any, T6]], + ] + ], +) -> Tuple[ + Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6] +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], ) -> Tuple[Optional[T1], ...]: diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index b98c53891cb..ee5d0419ab9 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2244,7 +2244,7 @@ def test_timestamp_to_event(self) -> None: def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -2258,7 +2258,7 @@ def test_topo_token_is_accepted(self) -> None: def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py b/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py new file mode 100644 index 00000000000..775c4f96c91 --- /dev/null +++ b/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py @@ -0,0 +1,497 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# +import logging +from http import HTTPStatus +from typing import List, Optional, Tuple, cast + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, room, sync, thread_subscriptions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase + +logger = logging.getLogger(__name__) + + +# The name of the extension. Currently unstable-prefixed. +EXT_NAME = "io.element.msc4308.thread_subscriptions" + + +class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase): + """ + Test the thread subscriptions extension in the Sliding Sync API. + """ + + maxDiff = None + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + thread_subscriptions.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4306_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + super().prepare(reactor, clock, hs) + + def test_no_data_initial_sync(self) -> None: + """ + Test enabling thread subscriptions extension during initial sync with no data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertNotIn(EXT_NAME, response_body["extensions"]) + + def test_no_data_incremental_sync(self) -> None: + """ + Test enabling thread subscriptions extension during incremental sync with no data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + initial_sync_body: JsonDict = { + "lists": {}, + } + + # Initial sync + response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok) + + # Incremental sync with extension enabled + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert + self.assertNotIn( + EXT_NAME, + response_body["extensions"], + response_body, + ) + + def test_thread_subscription_initial_sync(self) -> None: + """ + Test thread subscriptions appear in initial sync response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. 
+ base = self.store.get_max_thread_subscriptions_stream_id() + + self._subscribe_to_thread(user1_id, room_id, thread_root_id) + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id: { + "automatic": False, + "bump_stamp": base + 1, + } + } + } + }, + ) + + def test_thread_subscription_incremental_sync(self) -> None: + """ + Test new thread subscriptions appear in incremental sync response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + # Initial sync + _, sync_pos = self.do_sync(sync_body, tok=user1_tok) + logger.info("Synced to: %r, now subscribing to thread", sync_pos) + + # Subscribe + self._subscribe_to_thread(user1_id, room_id, thread_root_id) + + # Incremental sync + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + logger.info("Synced to: %r", sync_pos) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id: { + "automatic": False, + "bump_stamp": base + 1, + } + } + } + }, + ) + + def test_unsubscribe_from_thread(self) -> None: + """ + Test unsubscribing from a thread. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + self._subscribe_to_thread(user1_id, room_id, thread_root_id) + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok) + + # Assert: Subscription present + self.assertIn(EXT_NAME, response_body["extensions"]) + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id: {"automatic": False, "bump_stamp": base + 1} + } + } + }, + ) + + # Unsubscribe + self._unsubscribe_from_thread(user1_id, room_id, thread_root_id) + + # Incremental sync + response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos) + + # Assert: Unsubscription present + self.assertEqual( + response_body["extensions"][EXT_NAME], + {"unsubscribed": {room_id: {thread_root_id: {"bump_stamp": base + 2}}}}, + ) + + def test_multiple_thread_subscriptions(self) -> None: + """ + Test handling of multiple thread subscriptions. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread roots + thread_root_resp1 = self.helper.send( + room_id, body="Thread root 1", tok=user1_tok + ) + thread_root_id1 = thread_root_resp1["event_id"] + thread_root_resp2 = self.helper.send( + room_id, body="Thread root 2", tok=user1_tok + ) + thread_root_id2 = thread_root_resp2["event_id"] + thread_root_resp3 = self.helper.send( + room_id, body="Thread root 3", tok=user1_tok + ) + thread_root_id3 = thread_root_resp3["event_id"] + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. + base = self.store.get_max_thread_subscriptions_stream_id() + + # Subscribe to threads + self._subscribe_to_thread(user1_id, room_id, thread_root_id1) + self._subscribe_to_thread(user1_id, room_id, thread_root_id2) + self._subscribe_to_thread(user1_id, room_id, thread_root_id3) + + sync_body = { + "lists": {}, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + self.assertEqual( + response_body["extensions"][EXT_NAME], + { + "subscribed": { + room_id: { + thread_root_id1: { + "automatic": False, + "bump_stamp": base + 1, + }, + thread_root_id2: { + "automatic": False, + "bump_stamp": base + 2, + }, + thread_root_id3: { + "automatic": False, + "bump_stamp": base + 3, + }, + } + } + }, + ) + + def test_limit_parameter(self) -> None: + """ + Test limit parameter in thread subscriptions extension. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create 5 thread roots and subscribe to each + thread_root_ids = [] + for i in range(5): + thread_root_resp = self.helper.send( + room_id, body=f"Thread root {i}", tok=user1_tok + ) + thread_root_ids.append(thread_root_resp["event_id"]) + self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1]) + + sync_body = { + "lists": {}, + "extensions": {EXT_NAME: {"enabled": True, "limit": 3}}, + } + + # Sync + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Assert + thread_subscriptions = response_body["extensions"][EXT_NAME] + self.assertEqual( + len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions + ) + + def test_limit_and_companion_backpagination(self) -> None: + """ + Create 1 thread subscription, do a sync, create 4 more, + then sync with a limit of 2 and fill in the gap + using the companion /thread_subscriptions endpoint. + """ + + thread_root_ids: List[str] = [] + + def make_subscription() -> None: + thread_root_resp = self.helper.send( + room_id, body="Some thread root", tok=user1_tok + ) + thread_root_ids.append(thread_root_resp["event_id"]) + self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1]) + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # get the baseline stream_id of the thread_subscriptions stream + # before we write any data. + # Required because the initial value differs between SQLite and Postgres. 
+ base = self.store.get_max_thread_subscriptions_stream_id() + + # Make our first subscription + make_subscription() + + # Sync for the first time + sync_body = { + "lists": {}, + "extensions": {EXT_NAME: {"enabled": True, "limit": 2}}, + } + + sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok) + + thread_subscriptions = sync_resp["extensions"][EXT_NAME] + self.assertEqual( + thread_subscriptions["subscribed"], + { + room_id: { + thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1}, + } + }, + ) + + # Get our pos for the next sync + first_sync_pos = sync_resp["pos"] + + # Create 5 more thread subscriptions and subscribe to each + for _ in range(5): + make_subscription() + + # Now sync again. Our limit is 2, + # so we should get the latest 2 subscriptions, + # with a gap of 3 more subscriptions in the middle + sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos) + + thread_subscriptions = sync_resp["extensions"][EXT_NAME] + self.assertEqual( + thread_subscriptions["subscribed"], + { + room_id: { + thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5}, + thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6}, + } + }, + ) + # 1st backpagination: expecting a page with 2 subscriptions + page, end_tok = self._do_backpaginate( + from_tok=thread_subscriptions["prev_batch"], + to_tok=first_sync_pos, + limit=2, + access_token=user1_tok, + ) + self.assertIsNotNone(end_tok, "backpagination should continue") + self.assertEqual( + page["subscribed"], + { + room_id: { + thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3}, + thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4}, + } + }, + ) + + # 2nd backpagination: expecting a page with only 1 subscription + # and no other token for further backpagination + assert end_tok is not None + page, end_tok = self._do_backpaginate( + from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok + ) + self.assertIsNone(end_tok, "backpagination should have finished") + self.assertEqual( + page["subscribed"], + { + room_id: { + thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2}, + } + }, + ) + + def _do_backpaginate( + self, *, from_tok: str, to_tok: str, limit: int, access_token: str + ) -> Tuple[JsonDict, Optional[str]]: + channel = self.make_request( + "GET", + "/_matrix/client/unstable/io.element.msc4308/thread_subscriptions" + f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b", + access_token=access_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + body = channel.json_body + return body, cast(Optional[str], body.get("end")) + + def _subscribe_to_thread( + self, user_id: str, room_id: str, thread_root_id: str + ) -> None: + """ + Helper method to subscribe a user to a thread. + """ + self.get_success( + self.store.subscribe_user_to_thread( + user_id=user_id, + room_id=room_id, + thread_root_event_id=thread_root_id, + automatic_event_orderings=None, + ) + ) + + def _unsubscribe_from_thread( + self, user_id: str, room_id: str, thread_root_id: str + ) -> None: + """ + Helper method to unsubscribe a user from a thread. 
+ """ + self.get_success( + self.store.unsubscribe_user_from_thread( + user_id=user_id, + room_id=room_id, + thread_root_event_id=thread_root_id, + ) + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 24a28fbdd28..d3b5e26132d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2245,7 +2245,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2256,7 +2256,7 @@ def test_topo_token_is_accepted(self) -> None: self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 2a5c440cf49..2ce369247fb 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -189,19 +189,19 @@ def test_purge_thread_subscriptions_for_user(self) -> None: self._subscribe(self.other_thread_root_id, automatic_event_orderings=None) subscriptions = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( self.user_id, from_id=0, to_id=50, limit=50, ) ) - min_id = min(id for (id, _, _) in subscriptions) + min_id = min(id for (id, _, _, _, _) in subscriptions) self.assertEqual( subscriptions, [ - (min_id, self.room_id, self.thread_root_id), - (min_id + 1, self.room_id, self.other_thread_root_id), + (min_id, self.room_id, self.thread_root_id, True, True), + (min_id + 1, self.room_id, self.other_thread_root_id, True, False), ], ) @@ -212,7 +212,7 @@ def test_purge_thread_subscriptions_for_user(self) -> None: # Check user has no subscriptions subscriptions = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( self.user_id, from_id=0, to_id=50, @@ -280,20 +280,22 @@ def test_get_updated_thread_subscriptions_for_user(self) -> None: # Get updates for main user updates = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( self.user_id, from_id=0, to_id=stream_id2, limit=10 ) ) - self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) + self.assertEqual( + updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)] + ) # Get updates for other user updates = self.get_success( - self.store.get_updated_thread_subscriptions_for_user( + self.store.get_latest_updated_thread_subscriptions_for_user( other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10 ) ) self.assertEqual( - updates, [(stream_id2, self.room_id, self.other_thread_root_id)] + updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)] ) def test_should_skip_autosubscription_after_unsubscription(self) -> None: