9
9
import sys
10
10
import json
11
11
from abc import ABC , abstractmethod
12
+ import threading
12
13
import logging
13
14
14
15
from typing import Optional , Any , Set , Tuple , List , Dict
23
24
from base64 import b64decode
24
25
from time import time
25
26
import aiohttp
27
+ import asyncio
26
28
29
+ from aiohttp .client_exceptions import ClientConnectorError , ClientResponseError , ClientPayloadError
27
30
from cryptography .hazmat .primitives .ciphers import Cipher , algorithms , modes
28
31
from cryptography .hazmat .primitives import padding
29
32
from urllib3 import PoolManager
@@ -805,10 +808,151 @@ def destroy(self) -> None:
805
808
self .docs .clear ()
806
809
807
810
811
+ class SSEClient :
812
+ def __init__ (self , api_host , client_key , on_event , reconnect_delay = 5 , headers = None ):
813
+ self .api_host = api_host
814
+ self .client_key = client_key
815
+
816
+ self .on_event = on_event
817
+ self .on_connect = None
818
+ self .on_disconnect = None
819
+ self .reconnect_delay = reconnect_delay
820
+
821
+ self ._sse_session = None
822
+ self ._sse_thread = None
823
+ self ._loop = None
824
+
825
+ self .is_running = False
826
+
827
+ self .headers = {
828
+ "Accept" : "application/json; q=0.5, text/event-stream" ,
829
+ "Cache-Control" : "no-cache" ,
830
+ }
831
+
832
+ if headers :
833
+ self .headers .update (headers )
834
+
835
+ def connect (self ):
836
+ if self .is_running :
837
+ logger .debug ("Streaming session is already running." )
838
+ return
839
+
840
+ self .is_running = True
841
+ self ._sse_thread = threading .Thread (target = self ._run_sse_channel )
842
+ self ._sse_thread .start ()
843
+
844
+ def disconnect (self ):
845
+ self .is_running = False
846
+ if self ._loop and self ._loop .is_running ():
847
+ future = asyncio .run_coroutine_threadsafe (self ._stop_session (), self ._loop )
848
+ try :
849
+ future .result ()
850
+ except Exception as e :
851
+ logger .error (f"Streaming disconnect error: { e } " )
852
+
853
+ if self ._sse_thread :
854
+ self ._sse_thread .join (timeout = 5 )
855
+
856
+ logger .debug ("Streaming session disconnected" )
857
+
858
+ def onConnect (self , cb ):
859
+ self .on_connect = cb
860
+
861
+ def onDisconnect (self , cb ):
862
+ self .on_disconnect = cb
863
+
864
+ def _get_sse_url (self , api_host : str , client_key : str ) -> str :
865
+ api_host = (api_host or "https://cdn.growthbook.io" ).rstrip ("/" )
866
+ return f"{ api_host } /sub/{ client_key } "
867
+
868
+ async def _init_session (self ):
869
+ url = self ._get_sse_url (self .api_host , self .client_key )
870
+
871
+ while self .is_running :
872
+ try :
873
+ async with aiohttp .ClientSession (headers = self .headers ) as session :
874
+ self ._sse_session = session
875
+ if self .on_connect :
876
+ self .on_connect ()
877
+
878
+ async with session .get (url ) as response :
879
+ response .raise_for_status ()
880
+ await self ._process_response (response )
881
+ except ClientResponseError as e :
882
+ logger .error (f"Streaming error, closing connection: { e .status } { e .message } " )
883
+ self .is_running = False
884
+ break
885
+ except (ClientConnectorError , ClientPayloadError ) as e :
886
+ logger .error (f"Streaming error: { e } " )
887
+ if not self .is_running :
888
+ break
889
+ await self ._wait_for_reconnect ()
890
+ except TimeoutError :
891
+ logger .warning (f"Streaming connection timed out after { self .timeout } seconds." )
892
+ await self ._wait_for_reconnect ()
893
+ except asyncio .CancelledError :
894
+ logger .debug ("Streaming was cancelled." )
895
+ break
896
+ finally :
897
+ if self .on_disconnect :
898
+ self .on_disconnect ()
899
+ await self ._close_session ()
900
+
901
+ async def _process_response (self , response ):
902
+ event_data = {}
903
+ async for line in response .content :
904
+ decoded_line = line .decode ('utf-8' ).strip ()
905
+ if decoded_line .startswith ("event:" ):
906
+ event_data ['type' ] = decoded_line [len ("event:" ):].strip ()
907
+ elif decoded_line .startswith ("data:" ):
908
+ event_data ['data' ] = event_data .get ('data' , '' ) + f"\n { decoded_line [len ('data:' ):].strip ()} "
909
+ elif not decoded_line :
910
+ if 'type' in event_data and 'data' in event_data :
911
+ self .on_event (event_data )
912
+ event_data = {}
913
+
914
+ if 'type' in event_data and 'data' in event_data :
915
+ self .on_event (event_data )
916
+
917
+ async def _wait_for_reconnect (self ):
918
+ logger .debug (f"Attempting to reconnect streaming in { self .reconnect_delay } " )
919
+ await asyncio .sleep (self .reconnect_delay )
920
+
921
+ async def _close_session (self ):
922
+ if self ._sse_session :
923
+ await self ._sse_session .close ()
924
+ logger .debug ("Streaming session closed." )
925
+
926
+ def _run_sse_channel (self ):
927
+ self ._loop = asyncio .new_event_loop ()
928
+ # asyncio.set_event_loop(self._loop)
929
+
930
+ try :
931
+ self ._loop .run_until_complete (self ._init_session ())
932
+ except asyncio .CancelledError :
933
+ pass
934
+ finally :
935
+ self ._loop .run_until_complete (self ._loop .shutdown_asyncgens ())
936
+ self ._loop .close ()
937
+
938
+ async def _stop_session (self ):
939
+ if self ._sse_session :
940
+ await self ._sse_session .close ()
941
+
942
+ if self ._loop and self ._loop .is_running ():
943
+ tasks = [task for task in asyncio .all_tasks (self ._loop ) if not task .done ()]
944
+ for task in tasks :
945
+ task .cancel ()
946
+ try :
947
+ await task
948
+ except asyncio .CancelledError :
949
+ pass
950
+
808
951
class FeatureRepository (object ):
809
952
def __init__ (self ) -> None :
810
953
self .cache : AbstractFeatureCache = InMemoryFeatureCache ()
811
954
self .http : Optional [PoolManager ] = None
955
+ self .sse_client : SSEClient = None
812
956
813
957
def set_cache (self , cache : AbstractFeatureCache ) -> None :
814
958
self .cache = cache
@@ -930,6 +1074,21 @@ async def _fetch_features_async(
930
1074
logger .warning ("GrowthBook API response missing features" )
931
1075
return None
932
1076
1077
+ def onConnect (self ):
1078
+ logger .debug ('Streaming session established' )
1079
+
1080
+ def onDisconnect (self ):
1081
+ logger .debug ('Streaming session disconnected' )
1082
+
1083
+ def startAutoRefresh (self , api_host , client_key , cb ):
1084
+ self .sse_client = self .sse_client or SSEClient (api_host = api_host , client_key = client_key , on_event = cb )
1085
+ self .sse_client .onConnect (self .onConnect )
1086
+ self .sse_client .onDisconnect (self .onDisconnect )
1087
+ self .sse_client .connect ()
1088
+
1089
+ def stopAutoRefresh (self ):
1090
+ self .sse_client .disconnect ()
1091
+
933
1092
@staticmethod
934
1093
def _get_features_url (api_host : str , client_key : str ) -> str :
935
1094
api_host = (api_host or "https://cdn.growthbook.io" ).rstrip ("/" )
@@ -939,7 +1098,6 @@ def _get_features_url(api_host: str, client_key: str) -> str:
939
1098
# Singleton instance
940
1099
feature_repo = FeatureRepository ()
941
1100
942
-
943
1101
class GrowthBook (object ):
944
1102
def __init__ (
945
1103
self ,
@@ -956,6 +1114,7 @@ def __init__(
956
1114
forced_variations : dict = {},
957
1115
sticky_bucket_service : AbstractStickyBucketService = None ,
958
1116
sticky_bucket_identifier_attributes : List [str ] = None ,
1117
+ streaming : bool = False ,
959
1118
# Deprecated args
960
1119
trackingCallback = None ,
961
1120
qaMode : bool = False ,
@@ -981,6 +1140,8 @@ def __init__(
981
1140
self ._qaMode = qa_mode or qaMode
982
1141
self ._trackingCallback = on_experiment_viewed or trackingCallback
983
1142
1143
+ self ._streaming = streaming
1144
+
984
1145
# Deprecated args
985
1146
self ._user = user
986
1147
self ._groups = groups
@@ -994,6 +1155,10 @@ def __init__(
994
1155
if features :
995
1156
self .setFeatures (features )
996
1157
1158
+ if self ._streaming :
1159
+ self .load_features ()
1160
+ self .startAutoRefresh ()
1161
+
997
1162
def load_features (self ) -> None :
998
1163
if not self ._client_key :
999
1164
raise ValueError ("Must specify `client_key` to refresh features" )
@@ -1014,6 +1179,49 @@ async def load_features_async(self) -> None:
1014
1179
if features is not None :
1015
1180
self .setFeatures (features )
1016
1181
1182
+ def features_event_handler (self , features ):
1183
+ decoded = json .loads (features )
1184
+ if not decoded :
1185
+ return None
1186
+
1187
+ if "encryptedFeatures" in decoded :
1188
+ if not self ._decryption_key :
1189
+ raise ValueError ("Must specify decryption_key" )
1190
+ try :
1191
+ decrypted = decrypt (decoded ["encryptedFeatures" ], self ._decryption_key )
1192
+ return json .loads (decrypted )
1193
+ except Exception :
1194
+ logger .warning (
1195
+ "Failed to decrypt features from GrowthBook API response"
1196
+ )
1197
+ return None
1198
+ elif "features" in decoded :
1199
+ self .set_features (decoded ["features" ])
1200
+ else :
1201
+ logger .warning ("GrowthBook API response missing features" )
1202
+
1203
+ def dispatch_sse_event (self , event_data ):
1204
+ event_type = event_data ['type' ]
1205
+ data = event_data ['data' ]
1206
+ if event_type == 'features-updated' :
1207
+ self .load_features ()
1208
+ elif event_type == 'features' :
1209
+ self .features_event_handler (data )
1210
+
1211
+
1212
+ def startAutoRefresh (self ):
1213
+ if not self ._client_key :
1214
+ raise ValueError ("Must specify `client_key` to start features streaming" )
1215
+
1216
+ feature_repo .startAutoRefresh (
1217
+ api_host = self ._api_host ,
1218
+ client_key = self ._client_key ,
1219
+ cb = self .dispatch_sse_event
1220
+ )
1221
+
1222
+ def stopAutoRefresh (self ):
1223
+ feature_repo .stopAutoRefresh ()
1224
+
1017
1225
# @deprecated, use set_features
1018
1226
def setFeatures (self , features : dict ) -> None :
1019
1227
return self .set_features (features )
0 commit comments