Skip to content

Commit b6676e7

Browse files
committed
Added SSE client
1 parent cd1c4f8 commit b6676e7

File tree

1 file changed

+209
-1
lines changed

1 file changed

+209
-1
lines changed

growthbook/growthbook.py

Lines changed: 209 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import json
1111
from abc import ABC, abstractmethod
12+
import threading
1213
import logging
1314

1415
from typing import Optional, Any, Set, Tuple, List, Dict
@@ -23,7 +24,9 @@
2324
from base64 import b64decode
2425
from time import time
2526
import aiohttp
27+
import asyncio
2628

29+
from aiohttp.client_exceptions import ClientConnectorError, ClientResponseError, ClientPayloadError
2730
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
2831
from cryptography.hazmat.primitives import padding
2932
from urllib3 import PoolManager
@@ -805,10 +808,151 @@ def destroy(self) -> None:
805808
self.docs.clear()
806809

807810

811+
class SSEClient:
812+
def __init__(self, api_host, client_key, on_event, reconnect_delay=5, headers=None):
813+
self.api_host = api_host
814+
self.client_key = client_key
815+
816+
self.on_event = on_event
817+
self.on_connect = None
818+
self.on_disconnect = None
819+
self.reconnect_delay = reconnect_delay
820+
821+
self._sse_session = None
822+
self._sse_thread = None
823+
self._loop = None
824+
825+
self.is_running = False
826+
827+
self.headers = {
828+
"Accept": "application/json; q=0.5, text/event-stream",
829+
"Cache-Control": "no-cache",
830+
}
831+
832+
if headers:
833+
self.headers.update(headers)
834+
835+
def connect(self):
836+
if self.is_running:
837+
logger.debug("Streaming session is already running.")
838+
return
839+
840+
self.is_running = True
841+
self._sse_thread = threading.Thread(target=self._run_sse_channel)
842+
self._sse_thread.start()
843+
844+
def disconnect(self):
845+
self.is_running = False
846+
if self._loop and self._loop.is_running():
847+
future = asyncio.run_coroutine_threadsafe(self._stop_session(), self._loop)
848+
try:
849+
future.result()
850+
except Exception as e:
851+
logger.error(f"Streaming disconnect error: {e}")
852+
853+
if self._sse_thread:
854+
self._sse_thread.join(timeout=5)
855+
856+
logger.debug("Streaming session disconnected")
857+
858+
def onConnect(self, cb):
859+
self.on_connect = cb
860+
861+
def onDisconnect(self, cb):
862+
self.on_disconnect = cb
863+
864+
def _get_sse_url(self, api_host: str, client_key: str) -> str:
865+
api_host = (api_host or "https://cdn.growthbook.io").rstrip("/")
866+
return f"{api_host}/sub/{client_key}"
867+
868+
async def _init_session(self):
869+
url = self._get_sse_url(self.api_host, self.client_key)
870+
871+
while self.is_running:
872+
try:
873+
async with aiohttp.ClientSession(headers=self.headers) as session:
874+
self._sse_session = session
875+
if self.on_connect:
876+
self.on_connect()
877+
878+
async with session.get(url) as response:
879+
response.raise_for_status()
880+
await self._process_response(response)
881+
except ClientResponseError as e:
882+
logger.error(f"Streaming error, closing connection: {e.status} {e.message}")
883+
self.is_running = False
884+
break
885+
except (ClientConnectorError, ClientPayloadError) as e:
886+
logger.error(f"Streaming error: {e}")
887+
if not self.is_running:
888+
break
889+
await self._wait_for_reconnect()
890+
except TimeoutError:
891+
logger.warning(f"Streaming connection timed out after {self.timeout} seconds.")
892+
await self._wait_for_reconnect()
893+
except asyncio.CancelledError:
894+
logger.debug("Streaming was cancelled.")
895+
break
896+
finally:
897+
if self.on_disconnect:
898+
self.on_disconnect()
899+
await self._close_session()
900+
901+
async def _process_response(self, response):
902+
event_data = {}
903+
async for line in response.content:
904+
decoded_line = line.decode('utf-8').strip()
905+
if decoded_line.startswith("event:"):
906+
event_data['type'] = decoded_line[len("event:"):].strip()
907+
elif decoded_line.startswith("data:"):
908+
event_data['data'] = event_data.get('data', '') + f"\n{decoded_line[len('data:'):].strip()}"
909+
elif not decoded_line:
910+
if 'type' in event_data and 'data' in event_data:
911+
self.on_event(event_data)
912+
event_data = {}
913+
914+
if 'type' in event_data and 'data' in event_data:
915+
self.on_event(event_data)
916+
917+
async def _wait_for_reconnect(self):
918+
logger.debug(f"Attempting to reconnect streaming in {self.reconnect_delay}")
919+
await asyncio.sleep(self.reconnect_delay)
920+
921+
async def _close_session(self):
922+
if self._sse_session:
923+
await self._sse_session.close()
924+
logger.debug("Streaming session closed.")
925+
926+
def _run_sse_channel(self):
927+
self._loop = asyncio.new_event_loop()
928+
# asyncio.set_event_loop(self._loop)
929+
930+
try:
931+
self._loop.run_until_complete(self._init_session())
932+
except asyncio.CancelledError:
933+
pass
934+
finally:
935+
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
936+
self._loop.close()
937+
938+
async def _stop_session(self):
939+
if self._sse_session:
940+
await self._sse_session.close()
941+
942+
if self._loop and self._loop.is_running():
943+
tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()]
944+
for task in tasks:
945+
task.cancel()
946+
try:
947+
await task
948+
except asyncio.CancelledError:
949+
pass
950+
808951
class FeatureRepository(object):
809952
def __init__(self) -> None:
810953
self.cache: AbstractFeatureCache = InMemoryFeatureCache()
811954
self.http: Optional[PoolManager] = None
955+
self.sse_client: SSEClient = None
812956

813957
def set_cache(self, cache: AbstractFeatureCache) -> None:
814958
self.cache = cache
@@ -930,6 +1074,21 @@ async def _fetch_features_async(
9301074
logger.warning("GrowthBook API response missing features")
9311075
return None
9321076

1077+
def onConnect(self):
1078+
logger.debug('Streaming session established')
1079+
1080+
def onDisconnect(self):
1081+
logger.debug('Streaming session disconnected')
1082+
1083+
def startAutoRefresh(self, api_host, client_key, cb):
1084+
self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb)
1085+
self.sse_client.onConnect(self.onConnect)
1086+
self.sse_client.onDisconnect(self.onDisconnect)
1087+
self.sse_client.connect()
1088+
1089+
def stopAutoRefresh(self):
1090+
self.sse_client.disconnect()
1091+
9331092
@staticmethod
9341093
def _get_features_url(api_host: str, client_key: str) -> str:
9351094
api_host = (api_host or "https://cdn.growthbook.io").rstrip("/")
@@ -939,7 +1098,6 @@ def _get_features_url(api_host: str, client_key: str) -> str:
9391098
# Singleton instance
9401099
feature_repo = FeatureRepository()
9411100

942-
9431101
class GrowthBook(object):
9441102
def __init__(
9451103
self,
@@ -956,6 +1114,7 @@ def __init__(
9561114
forced_variations: dict = {},
9571115
sticky_bucket_service: AbstractStickyBucketService = None,
9581116
sticky_bucket_identifier_attributes: List[str] = None,
1117+
streaming: bool = False,
9591118
# Deprecated args
9601119
trackingCallback=None,
9611120
qaMode: bool = False,
@@ -981,6 +1140,8 @@ def __init__(
9811140
self._qaMode = qa_mode or qaMode
9821141
self._trackingCallback = on_experiment_viewed or trackingCallback
9831142

1143+
self._streaming = streaming
1144+
9841145
# Deprecated args
9851146
self._user = user
9861147
self._groups = groups
@@ -994,6 +1155,10 @@ def __init__(
9941155
if features:
9951156
self.setFeatures(features)
9961157

1158+
if self._streaming:
1159+
self.load_features()
1160+
self.startAutoRefresh()
1161+
9971162
def load_features(self) -> None:
9981163
if not self._client_key:
9991164
raise ValueError("Must specify `client_key` to refresh features")
@@ -1014,6 +1179,49 @@ async def load_features_async(self) -> None:
10141179
if features is not None:
10151180
self.setFeatures(features)
10161181

1182+
def features_event_handler(self, features):
1183+
decoded = json.loads(features)
1184+
if not decoded:
1185+
return None
1186+
1187+
if "encryptedFeatures" in decoded:
1188+
if not self._decryption_key:
1189+
raise ValueError("Must specify decryption_key")
1190+
try:
1191+
decrypted = decrypt(decoded["encryptedFeatures"], self._decryption_key)
1192+
return json.loads(decrypted)
1193+
except Exception:
1194+
logger.warning(
1195+
"Failed to decrypt features from GrowthBook API response"
1196+
)
1197+
return None
1198+
elif "features" in decoded:
1199+
self.set_features(decoded["features"])
1200+
else:
1201+
logger.warning("GrowthBook API response missing features")
1202+
1203+
def dispatch_sse_event(self, event_data):
1204+
event_type = event_data['type']
1205+
data = event_data['data']
1206+
if event_type == 'features-updated':
1207+
self.load_features()
1208+
elif event_type == 'features':
1209+
self.features_event_handler(data)
1210+
1211+
1212+
def startAutoRefresh(self):
1213+
if not self._client_key:
1214+
raise ValueError("Must specify `client_key` to start features streaming")
1215+
1216+
feature_repo.startAutoRefresh(
1217+
api_host=self._api_host,
1218+
client_key=self._client_key,
1219+
cb=self.dispatch_sse_event
1220+
)
1221+
1222+
def stopAutoRefresh(self):
1223+
feature_repo.stopAutoRefresh()
1224+
10171225
# @deprecated, use set_features
10181226
def setFeatures(self, features: dict) -> None:
10191227
return self.set_features(features)

0 commit comments

Comments
 (0)