@@ -164,22 +164,24 @@ def __init__(
164
164
prefilled_cache = user_signature_stream_prefill ,
165
165
)
166
166
167
- (
168
- device_list_federation_prefill ,
169
- device_list_federation_list_id ,
170
- ) = self .db_pool .get_cache_dict (
171
- db_conn ,
172
- "device_lists_outbound_pokes" ,
173
- entity_column = "destination" ,
174
- stream_column = "stream_id" ,
175
- max_value = device_list_max ,
176
- limit = 10000 ,
177
- )
178
- self ._device_list_federation_stream_cache = StreamChangeCache (
179
- "DeviceListFederationStreamChangeCache" ,
180
- device_list_federation_list_id ,
181
- prefilled_cache = device_list_federation_prefill ,
182
- )
167
+ self ._device_list_federation_stream_cache = None
168
+ if hs .should_send_federation ():
169
+ (
170
+ device_list_federation_prefill ,
171
+ device_list_federation_list_id ,
172
+ ) = self .db_pool .get_cache_dict (
173
+ db_conn ,
174
+ "device_lists_outbound_pokes" ,
175
+ entity_column = "destination" ,
176
+ stream_column = "stream_id" ,
177
+ max_value = device_list_max ,
178
+ limit = 10000 ,
179
+ )
180
+ self ._device_list_federation_stream_cache = StreamChangeCache (
181
+ "DeviceListFederationStreamChangeCache" ,
182
+ device_list_federation_list_id ,
183
+ prefilled_cache = device_list_federation_prefill ,
184
+ )
183
185
184
186
if hs .config .worker .run_background_tasks :
185
187
self ._clock .looping_call (
@@ -207,23 +209,30 @@ def _invalidate_caches_for_devices(
207
209
) -> None :
208
210
for row in rows :
209
211
if row .is_signature :
210
- self ._user_signature_stream_cache .entity_has_changed (row .entity , token )
212
+ self ._user_signature_stream_cache .entity_has_changed (row .user_id , token )
211
213
continue
212
214
213
215
# The entities are either user IDs (starting with '@') whose devices
214
216
# have changed, or remote servers that we need to tell about
215
217
# changes.
216
- if row .entity .startswith ("@" ):
217
- self ._device_list_stream_cache .entity_has_changed (row .entity , token )
218
- self .get_cached_devices_for_user .invalidate ((row .entity ,))
219
- self ._get_cached_user_device .invalidate ((row .entity ,))
220
- self .get_device_list_last_stream_id_for_remote .invalidate ((row .entity ,))
221
-
222
- else :
223
- self ._device_list_federation_stream_cache .entity_has_changed (
224
- row .entity , token
218
+ if not row .hosts_calculated :
219
+ self ._device_list_stream_cache .entity_has_changed (row .user_id , token )
220
+ self .get_cached_devices_for_user .invalidate ((row .user_id ,))
221
+ self ._get_cached_user_device .invalidate ((row .user_id ,))
222
+ self .get_device_list_last_stream_id_for_remote .invalidate (
223
+ (row .user_id ,)
225
224
)
226
225
226
+ def device_lists_outbound_pokes_have_changed (
227
+ self , destinations : StrCollection , token : int
228
+ ) -> None :
229
+ assert self ._device_list_federation_stream_cache is not None
230
+
231
+ for destination in destinations :
232
+ self ._device_list_federation_stream_cache .entity_has_changed (
233
+ destination , token
234
+ )
235
+
227
236
def device_lists_in_rooms_have_changed (
228
237
self , room_ids : StrCollection , token : int
229
238
) -> None :
@@ -363,6 +372,11 @@ async def get_device_updates_by_remote(
363
372
EDU contents.
364
373
"""
365
374
now_stream_id = self .get_device_stream_token ()
375
+ if from_stream_id == now_stream_id :
376
+ return now_stream_id , []
377
+
378
+ if self ._device_list_federation_stream_cache is None :
379
+ raise Exception ("Func can only be used on federation senders" )
366
380
367
381
has_changed = self ._device_list_federation_stream_cache .has_entity_changed (
368
382
destination , int (from_stream_id )
@@ -1018,10 +1032,10 @@ def _get_all_device_list_changes_for_remotes(
1018
1032
# This query Does The Right Thing where it'll correctly apply the
1019
1033
# bounds to the inner queries.
1020
1034
sql = """
1021
- SELECT stream_id, entity FROM (
1022
- SELECT stream_id, user_id AS entity FROM device_lists_stream
1035
+ SELECT stream_id, user_id, hosts FROM (
1036
+ SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
1023
1037
UNION ALL
1024
- SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
1038
+ SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
1025
1039
) AS e
1026
1040
WHERE ? < stream_id AND stream_id <= ?
1027
1041
ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ def get_device_list_changes_in_room_txn(
1577
1591
get_device_list_changes_in_room_txn ,
1578
1592
)
1579
1593
1594
+ async def get_destinations_for_device (self , stream_id : int ) -> StrCollection :
1595
+ return await self .db_pool .simple_select_onecol (
1596
+ table = "device_lists_outbound_pokes" ,
1597
+ keyvalues = {"stream_id" : stream_id },
1598
+ retcol = "destination" ,
1599
+ desc = "get_destinations_for_device" ,
1600
+ )
1601
+
1580
1602
1581
1603
class DeviceBackgroundUpdateStore (SQLBaseStore ):
1582
1604
def __init__ (
@@ -2112,12 +2134,13 @@ def _add_device_outbound_poke_to_stream_txn(
2112
2134
stream_ids : List [int ],
2113
2135
context : Optional [Dict [str , str ]],
2114
2136
) -> None :
2115
- for host in hosts :
2116
- txn .call_after (
2117
- self ._device_list_federation_stream_cache .entity_has_changed ,
2118
- host ,
2119
- stream_ids [- 1 ],
2120
- )
2137
+ if self ._device_list_federation_stream_cache :
2138
+ for host in hosts :
2139
+ txn .call_after (
2140
+ self ._device_list_federation_stream_cache .entity_has_changed ,
2141
+ host ,
2142
+ stream_ids [- 1 ],
2143
+ )
2121
2144
2122
2145
now = self ._clock .time_msec ()
2123
2146
stream_id_iterator = iter (stream_ids )
0 commit comments