@@ -164,24 +164,22 @@ def __init__(
164
164
prefilled_cache = user_signature_stream_prefill ,
165
165
)
166
166
167
- self ._device_list_federation_stream_cache = None
168
- if hs .should_send_federation ():
169
- (
170
- device_list_federation_prefill ,
171
- device_list_federation_list_id ,
172
- ) = self .db_pool .get_cache_dict (
173
- db_conn ,
174
- "device_lists_outbound_pokes" ,
175
- entity_column = "destination" ,
176
- stream_column = "stream_id" ,
177
- max_value = device_list_max ,
178
- limit = 10000 ,
179
- )
180
- self ._device_list_federation_stream_cache = StreamChangeCache (
181
- "DeviceListFederationStreamChangeCache" ,
182
- device_list_federation_list_id ,
183
- prefilled_cache = device_list_federation_prefill ,
184
- )
167
+ (
168
+ device_list_federation_prefill ,
169
+ device_list_federation_list_id ,
170
+ ) = self .db_pool .get_cache_dict (
171
+ db_conn ,
172
+ "device_lists_outbound_pokes" ,
173
+ entity_column = "destination" ,
174
+ stream_column = "stream_id" ,
175
+ max_value = device_list_max ,
176
+ limit = 10000 ,
177
+ )
178
+ self ._device_list_federation_stream_cache = StreamChangeCache (
179
+ "DeviceListFederationStreamChangeCache" ,
180
+ device_list_federation_list_id ,
181
+ prefilled_cache = device_list_federation_prefill ,
182
+ )
185
183
186
184
if hs .config .worker .run_background_tasks :
187
185
self ._clock .looping_call (
@@ -209,29 +207,22 @@ def _invalidate_caches_for_devices(
209
207
) -> None :
210
208
for row in rows :
211
209
if row .is_signature :
212
- self ._user_signature_stream_cache .entity_has_changed (row .user_id , token )
210
+ self ._user_signature_stream_cache .entity_has_changed (row .entity , token )
213
211
continue
214
212
215
213
# The entities are either user IDs (starting with '@') whose devices
216
214
# have changed, or remote servers that we need to tell about
217
215
# changes.
218
- if not row .hosts_calculated :
219
- self ._device_list_stream_cache .entity_has_changed (row .user_id , token )
220
- self .get_cached_devices_for_user .invalidate ((row .user_id ,))
221
- self ._get_cached_user_device .invalidate ((row .user_id ,))
222
- self .get_device_list_last_stream_id_for_remote .invalidate (
223
- (row .user_id ,)
224
- )
216
+ if row .entity .startswith ("@" ):
217
+ self ._device_list_stream_cache .entity_has_changed (row .entity , token )
218
+ self .get_cached_devices_for_user .invalidate ((row .entity ,))
219
+ self ._get_cached_user_device .invalidate ((row .entity ,))
220
+ self .get_device_list_last_stream_id_for_remote .invalidate ((row .entity ,))
225
221
226
- def device_lists_outbound_pokes_have_changed (
227
- self , destinations : StrCollection , token : int
228
- ) -> None :
229
- assert self ._device_list_federation_stream_cache is not None
230
-
231
- for destination in destinations :
232
- self ._device_list_federation_stream_cache .entity_has_changed (
233
- destination , token
234
- )
222
+ else :
223
+ self ._device_list_federation_stream_cache .entity_has_changed (
224
+ row .entity , token
225
+ )
235
226
236
227
def device_lists_in_rooms_have_changed (
237
228
self , room_ids : StrCollection , token : int
@@ -372,11 +363,6 @@ async def get_device_updates_by_remote(
372
363
EDU contents.
373
364
"""
374
365
now_stream_id = self .get_device_stream_token ()
375
- if from_stream_id == now_stream_id :
376
- return now_stream_id , []
377
-
378
- if self ._device_list_federation_stream_cache is None :
379
- raise Exception ("Func can only be used on federation senders" )
380
366
381
367
has_changed = self ._device_list_federation_stream_cache .has_entity_changed (
382
368
destination , int (from_stream_id )
@@ -1032,10 +1018,10 @@ def _get_all_device_list_changes_for_remotes(
1032
1018
# This query Does The Right Thing where it'll correctly apply the
1033
1019
# bounds to the inner queries.
1034
1020
sql = """
1035
- SELECT stream_id, user_id, hosts FROM (
1036
- SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
1021
+ SELECT stream_id, entity FROM (
1022
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
1037
1023
UNION ALL
1038
- SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
1024
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
1039
1025
) AS e
1040
1026
WHERE ? < stream_id AND stream_id <= ?
1041
1027
ORDER BY stream_id ASC
@@ -1591,14 +1577,6 @@ def get_device_list_changes_in_room_txn(
1591
1577
get_device_list_changes_in_room_txn ,
1592
1578
)
1593
1579
1594
- async def get_destinations_for_device (self , stream_id : int ) -> StrCollection :
1595
- return await self .db_pool .simple_select_onecol (
1596
- table = "device_lists_outbound_pokes" ,
1597
- keyvalues = {"stream_id" : stream_id },
1598
- retcol = "destination" ,
1599
- desc = "get_destinations_for_device" ,
1600
- )
1601
-
1602
1580
1603
1581
class DeviceBackgroundUpdateStore (SQLBaseStore ):
1604
1582
def __init__ (
@@ -2134,13 +2112,12 @@ def _add_device_outbound_poke_to_stream_txn(
2134
2112
stream_ids : List [int ],
2135
2113
context : Optional [Dict [str , str ]],
2136
2114
) -> None :
2137
- if self ._device_list_federation_stream_cache :
2138
- for host in hosts :
2139
- txn .call_after (
2140
- self ._device_list_federation_stream_cache .entity_has_changed ,
2141
- host ,
2142
- stream_ids [- 1 ],
2143
- )
2115
+ for host in hosts :
2116
+ txn .call_after (
2117
+ self ._device_list_federation_stream_cache .entity_has_changed ,
2118
+ host ,
2119
+ stream_ids [- 1 ],
2120
+ )
2144
2121
2145
2122
now = self ._clock .time_msec ()
2146
2123
stream_id_iterator = iter (stream_ids )
0 commit comments