Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/19005.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add experimental support for MSC4360: Sliding Sync Threads Extension.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,6 @@ def read_config(
# MSC4306: Thread Subscriptions
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

# MSC4360: Threads Extension to Sliding Sync
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
2 changes: 0 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ async def get_relations(
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.

TODO Accept a PaginationConfig instead of individual pagination parameters.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been done alread, the comment just wasn't updated.


Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
Expand Down
77 changes: 77 additions & 0 deletions synapse/handlers/sliding_sync/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -76,7 +77,9 @@ def __init__(self, hs: "HomeServer"):
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.relations_handler = hs.get_relations_handler()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_threads_ext = hs.config.experimental.msc4360_enabled

@trace
async def get_extensions_response(
Expand Down Expand Up @@ -177,20 +180,31 @@ async def get_extensions_response(
from_token=from_token,
)

threads_coro = None
if sync_config.extensions.threads is not None and self._enable_threads_ext:
threads_coro = self.get_threads_extension_response(
sync_config=sync_config,
threads_request=sync_config.extensions.threads,
to_token=to_token,
from_token=from_token,
)

(
to_device_response,
e2ee_response,
account_data_response,
receipts_response,
typing_response,
thread_subs_response,
threads_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
account_data_coro,
receipts_coro,
typing_coro,
thread_subs_coro,
threads_coro,
)

return SlidingSyncResult.Extensions(
Expand All @@ -200,6 +214,7 @@ async def get_extensions_response(
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
threads=threads_response,
)

def find_relevant_room_ids_for_extension(
Expand Down Expand Up @@ -970,3 +985,65 @@ async def get_thread_subscriptions_extension_response(
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)

async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]:
"""Handle Threads extension (MSC4360)

Args:
sync_config: Sync configuration
threads_request: The threads extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.

Returns:
the response (None if empty or threads extension is disabled)
"""
if not threads_request.enabled:
return None

limit = threads_request.limit

# TODO: is the `room_key` the right thing to use here?
# ie. does it translate into /relations

updates, prev_batch = await self.store.get_thread_updates_for_user(
user_id=sync_config.user.to_string(),
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=limit,
include_thread_roots=threads_request.include_roots,
)

if len(updates) == 0:
return None

# Collect thread root events and get bundled aggregations
thread_root_events = [event for _, _, event in updates if event]
aggregations_map = {}
if thread_root_events:
aggregations_map = await self.relations_handler.get_bundled_aggregations(
thread_root_events,
sync_config.user.to_string(),
)

thread_updates: Dict[str, Dict[str, _ThreadUpdate]] = {}
for thread_root_id, room_id, thread_root_event in updates:
bundled_aggs = (
aggregations_map.get(thread_root_id) if thread_root_event else None
)
thread_updates.setdefault(room_id, {})[thread_root_id] = _ThreadUpdate(
thread_root=thread_root_event,
prev_batch=None,
bundled_aggregations=bundled_aggs,
)

return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates,
prev_batch=prev_batch,
)
94 changes: 88 additions & 6 deletions synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from synapse.api.presence import UserPresenceState
from synapse.api.ratelimiting import Ratelimiter
from synapse.events.utils import (
EventClientSerializer,
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
format_event_raw,
Expand Down Expand Up @@ -648,6 +649,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- threads (MSC4360)

Request query parameters:
timeout: How long to wait for new events in milliseconds.
Expand Down Expand Up @@ -851,7 +853,10 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
logger.info("Client has disconnected; not serializing response.")
return 200, {}

response_content = await self.encode_response(requester, sliding_sync_results)
time_now = self.clock.time_msec()
response_content = await self.encode_response(
requester, sliding_sync_results, time_now
)

return 200, response_content

Expand All @@ -860,6 +865,7 @@ async def encode_response(
self,
requester: Requester,
sliding_sync_result: SlidingSyncResult,
time_now: int,
) -> JsonDict:
response: JsonDict = defaultdict(dict)

Expand All @@ -868,10 +874,10 @@ async def encode_response(
if serialized_lists:
response["lists"] = serialized_lists
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
requester, sliding_sync_result.rooms, time_now
)
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
requester, sliding_sync_result.extensions, time_now
)

return response
Expand Down Expand Up @@ -903,9 +909,8 @@ async def encode_rooms(
self,
requester: Requester,
rooms: Dict[str, SlidingSyncResult.RoomResult],
time_now: int,
) -> JsonDict:
time_now = self.clock.time_msec()

serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
Expand Down Expand Up @@ -1021,7 +1026,10 @@ async def encode_rooms(

@trace_with_opname("sliding_sync.encode_extensions")
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
self,
requester: Requester,
extensions: SlidingSyncResult.Extensions,
time_now: int,
) -> JsonDict:
serialized_extensions: JsonDict = {}

Expand Down Expand Up @@ -1091,6 +1099,16 @@ async def encode_extensions(
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)

# excludes both None and falsy `threads`
if extensions.threads:
serialized_extensions[
"io.element.msc4360.threads"
] = await _serialise_threads(
self.event_serializer,
time_now,
extensions.threads,
)

return serialized_extensions


Expand Down Expand Up @@ -1127,6 +1145,70 @@ def _serialise_thread_subscriptions(
return out


async def _serialise_threads(
event_serializer: EventClientSerializer,
time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension,
) -> JsonDict:
"""
Serialize the threads extension response for sliding sync.

Args:
event_serializer: The event serializer to use for serializing thread root events.
time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens.

Returns:
A JSON-serializable dict containing:
- "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
Each thread update may contain:
- "thread_root": The serialized thread root event (if include_roots was True),
with bundled aggregations including the latest_event in unsigned.m.relations.m.thread.
- "prev_batch": A pagination token for fetching older events in the thread.
- "prev_batch": A pagination token for fetching older thread updates (if available).
"""
out: JsonDict = {}

if threads.updates:
updates_dict: JsonDict = {}
for room_id, thread_updates in threads.updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in thread_updates.items():
# Serialize the update
update_dict: JsonDict = {}

# Serialize the thread_root event if present
if update.thread_root is not None:
# Create a mapping of event_id to bundled_aggregations
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]

# Add prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = str(update.prev_batch)

room_updates[thread_root_id] = update_dict

updates_dict[room_id] = room_updates

out["updates"] = updates_dict

if threads.prev_batch:
out["prev_batch"] = str(threads.prev_batch)

return out


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

Expand Down
Loading
Loading