27
27
cast ,
28
28
)
29
29
30
- from typing_extensions import assert_never
30
+ from typing_extensions import TypeAlias , assert_never
31
31
32
32
from synapse .api .constants import AccountDataTypes , EduTypes
33
33
from synapse .handlers .receipts import ReceiptEventSource
40
40
SlidingSyncStreamToken ,
41
41
StrCollection ,
42
42
StreamToken ,
43
+ ThreadSubscriptionsToken ,
43
44
)
44
45
from synapse .types .handlers .sliding_sync import (
45
46
HaveSentRoomFlag ,
54
55
gather_optional_coroutines ,
55
56
)
56
57
58
+ _ThreadSubscription : TypeAlias = (
59
+ SlidingSyncResult .Extensions .ThreadSubscriptionsExtension .ThreadSubscription
60
+ )
61
+ _ThreadUnsubscription : TypeAlias = (
62
+ SlidingSyncResult .Extensions .ThreadSubscriptionsExtension .ThreadUnsubscription
63
+ )
64
+
57
65
if TYPE_CHECKING :
58
66
from synapse .server import HomeServer
59
67
@@ -68,6 +76,7 @@ def __init__(self, hs: "HomeServer"):
68
76
self .event_sources = hs .get_event_sources ()
69
77
self .device_handler = hs .get_device_handler ()
70
78
self .push_rules_handler = hs .get_push_rules_handler ()
79
+ self ._enable_thread_subscriptions = hs .config .experimental .msc4306_enabled
71
80
72
81
@trace
73
82
async def get_extensions_response (
@@ -93,7 +102,7 @@ async def get_extensions_response(
93
102
actual_room_ids: The actual room IDs in the the Sliding Sync response.
94
103
actual_room_response_map: A map of room ID to room results in the the
95
104
Sliding Sync response.
96
- to_token: The point in the stream to sync up to.
105
+ to_token: The latest point in the stream to sync up to.
97
106
from_token: The point in the stream to sync from.
98
107
"""
99
108
@@ -156,18 +165,32 @@ async def get_extensions_response(
156
165
from_token = from_token ,
157
166
)
158
167
168
+ thread_subs_coro = None
169
+ if (
170
+ sync_config .extensions .thread_subscriptions is not None
171
+ and self ._enable_thread_subscriptions
172
+ ):
173
+ thread_subs_coro = self .get_thread_subscriptions_extension_response (
174
+ sync_config = sync_config ,
175
+ thread_subscriptions_request = sync_config .extensions .thread_subscriptions ,
176
+ to_token = to_token ,
177
+ from_token = from_token ,
178
+ )
179
+
159
180
(
160
181
to_device_response ,
161
182
e2ee_response ,
162
183
account_data_response ,
163
184
receipts_response ,
164
185
typing_response ,
186
+ thread_subs_response ,
165
187
) = await gather_optional_coroutines (
166
188
to_device_coro ,
167
189
e2ee_coro ,
168
190
account_data_coro ,
169
191
receipts_coro ,
170
192
typing_coro ,
193
+ thread_subs_coro ,
171
194
)
172
195
173
196
return SlidingSyncResult .Extensions (
@@ -176,6 +199,7 @@ async def get_extensions_response(
176
199
account_data = account_data_response ,
177
200
receipts = receipts_response ,
178
201
typing = typing_response ,
202
+ thread_subscriptions = thread_subs_response ,
179
203
)
180
204
181
205
def find_relevant_room_ids_for_extension (
@@ -877,3 +901,72 @@ async def get_typing_extension_response(
877
901
return SlidingSyncResult .Extensions .TypingExtension (
878
902
room_id_to_typing_map = room_id_to_typing_map ,
879
903
)
904
+
905
+ async def get_thread_subscriptions_extension_response (
906
+ self ,
907
+ sync_config : SlidingSyncConfig ,
908
+ thread_subscriptions_request : SlidingSyncConfig .Extensions .ThreadSubscriptionsExtension ,
909
+ to_token : StreamToken ,
910
+ from_token : Optional [SlidingSyncStreamToken ],
911
+ ) -> Optional [SlidingSyncResult .Extensions .ThreadSubscriptionsExtension ]:
912
+ """Handle Thread Subscriptions extension (MSC4308)
913
+
914
+ Args:
915
+ sync_config: Sync configuration
916
+ thread_subscriptions_request: The thread_subscriptions extension from the request
917
+ to_token: The point in the stream to sync up to.
918
+ from_token: The point in the stream to sync from.
919
+
920
+ Returns:
921
+ the response (None if empty or thread subscriptions are disabled)
922
+ """
923
+ if not thread_subscriptions_request .enabled :
924
+ return None
925
+
926
+ limit = thread_subscriptions_request .limit
927
+
928
+ if from_token :
929
+ from_stream_id = from_token .stream_token .thread_subscriptions_key
930
+ else :
931
+ from_stream_id = StreamToken .START .thread_subscriptions_key
932
+
933
+ to_stream_id = to_token .thread_subscriptions_key
934
+
935
+ updates = await self .store .get_latest_updated_thread_subscriptions_for_user (
936
+ user_id = sync_config .user .to_string (),
937
+ from_id = from_stream_id ,
938
+ to_id = to_stream_id ,
939
+ limit = limit ,
940
+ )
941
+
942
+ if len (updates ) == 0 :
943
+ return None
944
+
945
+ subscribed_threads : Dict [str , Dict [str , _ThreadSubscription ]] = {}
946
+ unsubscribed_threads : Dict [str , Dict [str , _ThreadUnsubscription ]] = {}
947
+ for stream_id , room_id , thread_root_id , subscribed , automatic in updates :
948
+ if subscribed :
949
+ subscribed_threads .setdefault (room_id , {})[thread_root_id ] = (
950
+ _ThreadSubscription (
951
+ automatic = automatic ,
952
+ bump_stamp = stream_id ,
953
+ )
954
+ )
955
+ else :
956
+ unsubscribed_threads .setdefault (room_id , {})[thread_root_id ] = (
957
+ _ThreadUnsubscription (bump_stamp = stream_id )
958
+ )
959
+
960
+ prev_batch = None
961
+ if len (updates ) == limit :
962
+ # Tell the client about a potential gap where there may be more
963
+ # thread subscriptions for it to backpaginate.
964
+ # We subtract one because the 'later in the stream' bound is inclusive,
965
+ # and we already saw the element at index 0.
966
+ prev_batch = ThreadSubscriptionsToken (updates [0 ][0 ] - 1 )
967
+
968
+ return SlidingSyncResult .Extensions .ThreadSubscriptionsExtension (
969
+ subscribed = subscribed_threads ,
970
+ unsubscribed = unsubscribed_threads ,
971
+ prev_batch = prev_batch ,
972
+ )
0 commit comments