@@ -430,6 +430,7 @@ async def get_local_media(
430
430
media_id : str ,
431
431
name : Optional [str ],
432
432
max_timeout_ms : int ,
433
+ allow_authenticated : bool = True ,
433
434
federation : bool = False ,
434
435
) -> None :
435
436
"""Responds to requests for local media, if exists, or returns 404.
@@ -442,6 +443,7 @@ async def get_local_media(
442
443
the filename in the Content-Disposition header of the response.
443
444
max_timeout_ms: the maximum number of milliseconds to wait for the
444
445
media to be uploaded.
446
+ allow_authenticated: whether media marked as authenticated may be served to this request
445
447
federation: whether the local media being fetched is for a federation request
446
448
447
449
Returns:
@@ -451,6 +453,10 @@ async def get_local_media(
451
453
if not media_info :
452
454
return
453
455
456
+ if self .hs .config .media .enable_authenticated_media and not allow_authenticated :
457
+ if media_info .authenticated :
458
+ raise NotFoundError ()
459
+
454
460
self .mark_recently_accessed (None , media_id )
455
461
456
462
media_type = media_info .media_type
@@ -481,6 +487,7 @@ async def get_remote_media(
481
487
max_timeout_ms : int ,
482
488
ip_address : str ,
483
489
use_federation_endpoint : bool ,
490
+ allow_authenticated : bool = True ,
484
491
) -> None :
485
492
"""Respond to requests for remote media.
486
493
@@ -495,6 +502,8 @@ async def get_remote_media(
495
502
ip_address: the IP address of the requester
496
503
use_federation_endpoint: whether to request the remote media over the new
497
504
federation `/download` endpoint
505
+ allow_authenticated: whether media marked as authenticated may be served to this
506
+ request
498
507
499
508
Returns:
500
509
Resolves once a response has successfully been written to request
@@ -526,6 +535,7 @@ async def get_remote_media(
526
535
self .download_ratelimiter ,
527
536
ip_address ,
528
537
use_federation_endpoint ,
538
+ allow_authenticated ,
529
539
)
530
540
531
541
# We deliberately stream the file outside the lock
@@ -548,6 +558,7 @@ async def get_remote_media_info(
548
558
max_timeout_ms : int ,
549
559
ip_address : str ,
550
560
use_federation : bool ,
561
+ allow_authenticated : bool ,
551
562
) -> RemoteMedia :
552
563
"""Gets the media info associated with the remote file, downloading
553
564
if necessary.
@@ -560,6 +571,8 @@ async def get_remote_media_info(
560
571
ip_address: IP address of the requester
561
572
use_federation: if a download is necessary, whether to request the remote file
562
573
over the federation `/download` endpoint
574
+ allow_authenticated: whether media marked as authenticated may be served to this
575
+ request
563
576
564
577
Returns:
565
578
The media info of the file
@@ -581,6 +594,7 @@ async def get_remote_media_info(
581
594
self .download_ratelimiter ,
582
595
ip_address ,
583
596
use_federation ,
597
+ allow_authenticated ,
584
598
)
585
599
586
600
# Ensure we actually use the responder so that it releases resources
@@ -598,6 +612,7 @@ async def _get_remote_media_impl(
598
612
download_ratelimiter : Ratelimiter ,
599
613
ip_address : str ,
600
614
use_federation_endpoint : bool ,
615
+ allow_authenticated : bool ,
601
616
) -> Tuple [Optional [Responder ], RemoteMedia ]:
602
617
"""Looks for media in local cache, if not there then attempt to
603
618
download from remote server.
@@ -619,6 +634,11 @@ async def _get_remote_media_impl(
619
634
"""
620
635
media_info = await self .store .get_cached_remote_media (server_name , media_id )
621
636
637
+ if self .hs .config .media .enable_authenticated_media and not allow_authenticated :
638
+ # if it isn't cached then don't fetch it or if it's authenticated then don't serve it
639
+ if not media_info or media_info .authenticated :
640
+ raise NotFoundError ()
641
+
622
642
# file_id is the ID we use to track the file locally. If we've already
623
643
# seen the file then reuse the existing ID, otherwise generate a new
624
644
# one.
@@ -792,6 +812,11 @@ async def _download_remote_file(
792
812
793
813
logger .info ("Stored remote media in file %r" , fname )
794
814
815
+ if self .hs .config .media .enable_authenticated_media :
816
+ authenticated = True
817
+ else :
818
+ authenticated = False
819
+
795
820
return RemoteMedia (
796
821
media_origin = server_name ,
797
822
media_id = media_id ,
@@ -802,6 +827,7 @@ async def _download_remote_file(
802
827
filesystem_id = file_id ,
803
828
last_access_ts = time_now_ms ,
804
829
quarantined_by = None ,
830
+ authenticated = authenticated ,
805
831
)
806
832
807
833
async def _federation_download_remote_file (
@@ -915,6 +941,11 @@ async def _federation_download_remote_file(
915
941
916
942
logger .debug ("Stored remote media in file %r" , fname )
917
943
944
+ if self .hs .config .media .enable_authenticated_media :
945
+ authenticated = True
946
+ else :
947
+ authenticated = False
948
+
918
949
return RemoteMedia (
919
950
media_origin = server_name ,
920
951
media_id = media_id ,
@@ -925,6 +956,7 @@ async def _federation_download_remote_file(
925
956
filesystem_id = file_id ,
926
957
last_access_ts = time_now_ms ,
927
958
quarantined_by = None ,
959
+ authenticated = authenticated ,
928
960
)
929
961
930
962
def _get_thumbnail_requirements (
@@ -1030,7 +1062,12 @@ async def generate_local_exact_thumbnail(
1030
1062
t_len = os .path .getsize (output_path )
1031
1063
1032
1064
await self .store .store_local_thumbnail (
1033
- media_id , t_width , t_height , t_type , t_method , t_len
1065
+ media_id ,
1066
+ t_width ,
1067
+ t_height ,
1068
+ t_type ,
1069
+ t_method ,
1070
+ t_len ,
1034
1071
)
1035
1072
1036
1073
return output_path
0 commit comments