9
9
import sys
10
10
import json
11
11
from abc import ABC , abstractmethod
12
+ import threading
12
13
import logging
13
14
14
15
from typing import Optional , Any , Set , Tuple , List , Dict
23
24
from base64 import b64decode
24
25
from time import time
25
26
import aiohttp
27
+ import asyncio
26
28
29
+ from aiohttp .client_exceptions import ClientConnectorError , ClientResponseError , ClientPayloadError
27
30
from cryptography .hazmat .primitives .ciphers import Cipher , algorithms , modes
28
31
from cryptography .hazmat .primitives import padding
29
32
from urllib3 import PoolManager
@@ -805,10 +808,138 @@ def destroy(self) -> None:
805
808
self .docs .clear ()
806
809
807
810
811
+ class SSEClient :
812
+ def __init__ (self , api_host , client_key , on_event , reconnect_delay = 5 , headers = None ):
813
+ self .api_host = api_host
814
+ self .client_key = client_key
815
+
816
+ self .on_event = on_event
817
+ self .reconnect_delay = reconnect_delay
818
+
819
+ self ._sse_session = None
820
+ self ._sse_thread = None
821
+ self ._loop = None
822
+
823
+ self .is_running = False
824
+
825
+ self .headers = {
826
+ "Accept" : "application/json; q=0.5, text/event-stream" ,
827
+ "Cache-Control" : "no-cache" ,
828
+ }
829
+
830
+ if headers :
831
+ self .headers .update (headers )
832
+
833
+ def connect (self ):
834
+ if self .is_running :
835
+ logger .debug ("Streaming session is already running." )
836
+ return
837
+
838
+ self .is_running = True
839
+ self ._sse_thread = threading .Thread (target = self ._run_sse_channel )
840
+ self ._sse_thread .start ()
841
+
842
+ def disconnect (self ):
843
+ self .is_running = False
844
+ if self ._loop and self ._loop .is_running ():
845
+ future = asyncio .run_coroutine_threadsafe (self ._stop_session (), self ._loop )
846
+ try :
847
+ future .result ()
848
+ except Exception as e :
849
+ logger .error (f"Streaming disconnect error: { e } " )
850
+
851
+ if self ._sse_thread :
852
+ self ._sse_thread .join (timeout = 5 )
853
+
854
+ logger .debug ("Streaming session disconnected" )
855
+
856
+ def _get_sse_url (self , api_host : str , client_key : str ) -> str :
857
+ api_host = (api_host or "https://cdn.growthbook.io" ).rstrip ("/" )
858
+ return f"{ api_host } /sub/{ client_key } "
859
+
860
+ async def _init_session (self ):
861
+ url = self ._get_sse_url (self .api_host , self .client_key )
862
+
863
+ while self .is_running :
864
+ try :
865
+ async with aiohttp .ClientSession (headers = self .headers ) as session :
866
+ self ._sse_session = session
867
+
868
+ async with session .get (url ) as response :
869
+ response .raise_for_status ()
870
+ await self ._process_response (response )
871
+ except ClientResponseError as e :
872
+ logger .error (f"Streaming error, closing connection: { e .status } { e .message } " )
873
+ self .is_running = False
874
+ break
875
+ except (ClientConnectorError , ClientPayloadError ) as e :
876
+ logger .error (f"Streaming error: { e } " )
877
+ if not self .is_running :
878
+ break
879
+ await self ._wait_for_reconnect ()
880
+ except TimeoutError :
881
+ logger .warning (f"Streaming connection timed out after { self .timeout } seconds." )
882
+ await self ._wait_for_reconnect ()
883
+ except asyncio .CancelledError :
884
+ logger .debug ("Streaming was cancelled." )
885
+ break
886
+ finally :
887
+ await self ._close_session ()
888
+
889
+ async def _process_response (self , response ):
890
+ event_data = {}
891
+ async for line in response .content :
892
+ decoded_line = line .decode ('utf-8' ).strip ()
893
+ if decoded_line .startswith ("event:" ):
894
+ event_data ['type' ] = decoded_line [len ("event:" ):].strip ()
895
+ elif decoded_line .startswith ("data:" ):
896
+ event_data ['data' ] = event_data .get ('data' , '' ) + f"\n { decoded_line [len ('data:' ):].strip ()} "
897
+ elif not decoded_line :
898
+ if 'type' in event_data and 'data' in event_data :
899
+ self .on_event (event_data )
900
+ event_data = {}
901
+
902
+ if 'type' in event_data and 'data' in event_data :
903
+ self .on_event (event_data )
904
+
905
+ async def _wait_for_reconnect (self ):
906
+ logger .debug (f"Attempting to reconnect streaming in { self .reconnect_delay } " )
907
+ await asyncio .sleep (self .reconnect_delay )
908
+
909
+ async def _close_session (self ):
910
+ if self ._sse_session :
911
+ await self ._sse_session .close ()
912
+ logger .debug ("Streaming session closed." )
913
+
914
+ def _run_sse_channel (self ):
915
+ self ._loop = asyncio .new_event_loop ()
916
+
917
+ try :
918
+ self ._loop .run_until_complete (self ._init_session ())
919
+ except asyncio .CancelledError :
920
+ pass
921
+ finally :
922
+ self ._loop .run_until_complete (self ._loop .shutdown_asyncgens ())
923
+ self ._loop .close ()
924
+
925
+ async def _stop_session (self ):
926
+ if self ._sse_session :
927
+ await self ._sse_session .close ()
928
+
929
+ if self ._loop and self ._loop .is_running ():
930
+ tasks = [task for task in asyncio .all_tasks (self ._loop ) if not task .done ()]
931
+ for task in tasks :
932
+ task .cancel ()
933
+ try :
934
+ await task
935
+ except asyncio .CancelledError :
936
+ pass
937
+
808
938
class FeatureRepository (object ):
809
939
def __init__ (self ) -> None :
810
940
self .cache : AbstractFeatureCache = InMemoryFeatureCache ()
811
941
self .http : Optional [PoolManager ] = None
942
+ self .sse_client : Optional [SSEClient ] = None
812
943
813
944
def set_cache (self , cache : AbstractFeatureCache ) -> None :
814
945
self .cache = cache
@@ -930,6 +1061,14 @@ async def _fetch_features_async(
930
1061
logger .warning ("GrowthBook API response missing features" )
931
1062
return None
932
1063
1064
+
1065
+ def startAutoRefresh (self , api_host , client_key , cb ):
1066
+ self .sse_client = self .sse_client or SSEClient (api_host = api_host , client_key = client_key , on_event = cb )
1067
+ self .sse_client .connect ()
1068
+
1069
+ def stopAutoRefresh (self ):
1070
+ self .sse_client .disconnect ()
1071
+
933
1072
@staticmethod
934
1073
def _get_features_url (api_host : str , client_key : str ) -> str :
935
1074
api_host = (api_host or "https://cdn.growthbook.io" ).rstrip ("/" )
@@ -939,7 +1078,6 @@ def _get_features_url(api_host: str, client_key: str) -> str:
939
1078
# Singleton instance
940
1079
feature_repo = FeatureRepository ()
941
1080
942
-
943
1081
class GrowthBook (object ):
944
1082
def __init__ (
945
1083
self ,
@@ -956,6 +1094,7 @@ def __init__(
956
1094
forced_variations : dict = {},
957
1095
sticky_bucket_service : AbstractStickyBucketService = None ,
958
1096
sticky_bucket_identifier_attributes : List [str ] = None ,
1097
+ streaming : bool = False ,
959
1098
# Deprecated args
960
1099
trackingCallback = None ,
961
1100
qaMode : bool = False ,
@@ -981,6 +1120,8 @@ def __init__(
981
1120
self ._qaMode = qa_mode or qaMode
982
1121
self ._trackingCallback = on_experiment_viewed or trackingCallback
983
1122
1123
+ self ._streaming = streaming
1124
+
984
1125
# Deprecated args
985
1126
self ._user = user
986
1127
self ._groups = groups
@@ -994,6 +1135,10 @@ def __init__(
994
1135
if features :
995
1136
self .setFeatures (features )
996
1137
1138
+ if self ._streaming :
1139
+ self .load_features ()
1140
+ self .startAutoRefresh ()
1141
+
997
1142
def load_features (self ) -> None :
998
1143
if not self ._client_key :
999
1144
raise ValueError ("Must specify `client_key` to refresh features" )
@@ -1014,6 +1159,49 @@ async def load_features_async(self) -> None:
1014
1159
if features is not None :
1015
1160
self .setFeatures (features )
1016
1161
1162
+ def features_event_handler (self , features ):
1163
+ decoded = json .loads (features )
1164
+ if not decoded :
1165
+ return None
1166
+
1167
+ if "encryptedFeatures" in decoded :
1168
+ if not self ._decryption_key :
1169
+ raise ValueError ("Must specify decryption_key" )
1170
+ try :
1171
+ decrypted = decrypt (decoded ["encryptedFeatures" ], self ._decryption_key )
1172
+ return json .loads (decrypted )
1173
+ except Exception :
1174
+ logger .warning (
1175
+ "Failed to decrypt features from GrowthBook API response"
1176
+ )
1177
+ return None
1178
+ elif "features" in decoded :
1179
+ self .set_features (decoded ["features" ])
1180
+ else :
1181
+ logger .warning ("GrowthBook API response missing features" )
1182
+
1183
+ def dispatch_sse_event (self , event_data ):
1184
+ event_type = event_data ['type' ]
1185
+ data = event_data ['data' ]
1186
+ if event_type == 'features-updated' :
1187
+ self .load_features ()
1188
+ elif event_type == 'features' :
1189
+ self .features_event_handler (data )
1190
+
1191
+
1192
+ def startAutoRefresh (self ):
1193
+ if not self ._client_key :
1194
+ raise ValueError ("Must specify `client_key` to start features streaming" )
1195
+
1196
+ feature_repo .startAutoRefresh (
1197
+ api_host = self ._api_host ,
1198
+ client_key = self ._client_key ,
1199
+ cb = self .dispatch_sse_event
1200
+ )
1201
+
1202
+ def stopAutoRefresh (self ):
1203
+ feature_repo .stopAutoRefresh ()
1204
+
1017
1205
# @deprecated, use set_features
1018
1206
def setFeatures (self , features : dict ) -> None :
1019
1207
return self .set_features (features )
0 commit comments