Skip to content

Commit 015ab58

Browse files
authored
Merge pull request #23 from growthbook/feature/streaming
Added SSE client
2 parents cd1c4f8 + fa5431d commit 015ab58

File tree

1 file changed

+189
-1
lines changed

1 file changed

+189
-1
lines changed

growthbook/growthbook.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import json
1111
from abc import ABC, abstractmethod
12+
import threading
1213
import logging
1314

1415
from typing import Optional, Any, Set, Tuple, List, Dict
@@ -23,7 +24,9 @@
2324
from base64 import b64decode
2425
from time import time
2526
import aiohttp
27+
import asyncio
2628

29+
from aiohttp.client_exceptions import ClientConnectorError, ClientResponseError, ClientPayloadError
2730
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
2831
from cryptography.hazmat.primitives import padding
2932
from urllib3 import PoolManager
@@ -805,10 +808,138 @@ def destroy(self) -> None:
805808
self.docs.clear()
806809

807810

811+
class SSEClient:
812+
def __init__(self, api_host, client_key, on_event, reconnect_delay=5, headers=None):
813+
self.api_host = api_host
814+
self.client_key = client_key
815+
816+
self.on_event = on_event
817+
self.reconnect_delay = reconnect_delay
818+
819+
self._sse_session = None
820+
self._sse_thread = None
821+
self._loop = None
822+
823+
self.is_running = False
824+
825+
self.headers = {
826+
"Accept": "application/json; q=0.5, text/event-stream",
827+
"Cache-Control": "no-cache",
828+
}
829+
830+
if headers:
831+
self.headers.update(headers)
832+
833+
def connect(self):
834+
if self.is_running:
835+
logger.debug("Streaming session is already running.")
836+
return
837+
838+
self.is_running = True
839+
self._sse_thread = threading.Thread(target=self._run_sse_channel)
840+
self._sse_thread.start()
841+
842+
def disconnect(self):
843+
self.is_running = False
844+
if self._loop and self._loop.is_running():
845+
future = asyncio.run_coroutine_threadsafe(self._stop_session(), self._loop)
846+
try:
847+
future.result()
848+
except Exception as e:
849+
logger.error(f"Streaming disconnect error: {e}")
850+
851+
if self._sse_thread:
852+
self._sse_thread.join(timeout=5)
853+
854+
logger.debug("Streaming session disconnected")
855+
856+
def _get_sse_url(self, api_host: str, client_key: str) -> str:
857+
api_host = (api_host or "https://cdn.growthbook.io").rstrip("/")
858+
return f"{api_host}/sub/{client_key}"
859+
860+
async def _init_session(self):
861+
url = self._get_sse_url(self.api_host, self.client_key)
862+
863+
while self.is_running:
864+
try:
865+
async with aiohttp.ClientSession(headers=self.headers) as session:
866+
self._sse_session = session
867+
868+
async with session.get(url) as response:
869+
response.raise_for_status()
870+
await self._process_response(response)
871+
except ClientResponseError as e:
872+
logger.error(f"Streaming error, closing connection: {e.status} {e.message}")
873+
self.is_running = False
874+
break
875+
except (ClientConnectorError, ClientPayloadError) as e:
876+
logger.error(f"Streaming error: {e}")
877+
if not self.is_running:
878+
break
879+
await self._wait_for_reconnect()
880+
except TimeoutError:
881+
logger.warning(f"Streaming connection timed out after {self.timeout} seconds.")
882+
await self._wait_for_reconnect()
883+
except asyncio.CancelledError:
884+
logger.debug("Streaming was cancelled.")
885+
break
886+
finally:
887+
await self._close_session()
888+
889+
async def _process_response(self, response):
890+
event_data = {}
891+
async for line in response.content:
892+
decoded_line = line.decode('utf-8').strip()
893+
if decoded_line.startswith("event:"):
894+
event_data['type'] = decoded_line[len("event:"):].strip()
895+
elif decoded_line.startswith("data:"):
896+
event_data['data'] = event_data.get('data', '') + f"\n{decoded_line[len('data:'):].strip()}"
897+
elif not decoded_line:
898+
if 'type' in event_data and 'data' in event_data:
899+
self.on_event(event_data)
900+
event_data = {}
901+
902+
if 'type' in event_data and 'data' in event_data:
903+
self.on_event(event_data)
904+
905+
async def _wait_for_reconnect(self):
906+
logger.debug(f"Attempting to reconnect streaming in {self.reconnect_delay}")
907+
await asyncio.sleep(self.reconnect_delay)
908+
909+
async def _close_session(self):
910+
if self._sse_session:
911+
await self._sse_session.close()
912+
logger.debug("Streaming session closed.")
913+
914+
def _run_sse_channel(self):
915+
self._loop = asyncio.new_event_loop()
916+
917+
try:
918+
self._loop.run_until_complete(self._init_session())
919+
except asyncio.CancelledError:
920+
pass
921+
finally:
922+
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
923+
self._loop.close()
924+
925+
async def _stop_session(self):
926+
if self._sse_session:
927+
await self._sse_session.close()
928+
929+
if self._loop and self._loop.is_running():
930+
tasks = [task for task in asyncio.all_tasks(self._loop) if not task.done()]
931+
for task in tasks:
932+
task.cancel()
933+
try:
934+
await task
935+
except asyncio.CancelledError:
936+
pass
937+
808938
class FeatureRepository(object):
809939
def __init__(self) -> None:
810940
self.cache: AbstractFeatureCache = InMemoryFeatureCache()
811941
self.http: Optional[PoolManager] = None
942+
self.sse_client: Optional[SSEClient] = None
812943

813944
def set_cache(self, cache: AbstractFeatureCache) -> None:
814945
self.cache = cache
@@ -930,6 +1061,14 @@ async def _fetch_features_async(
9301061
logger.warning("GrowthBook API response missing features")
9311062
return None
9321063

1064+
1065+
def startAutoRefresh(self, api_host, client_key, cb):
1066+
self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb)
1067+
self.sse_client.connect()
1068+
1069+
def stopAutoRefresh(self):
1070+
self.sse_client.disconnect()
1071+
9331072
@staticmethod
9341073
def _get_features_url(api_host: str, client_key: str) -> str:
9351074
api_host = (api_host or "https://cdn.growthbook.io").rstrip("/")
@@ -939,7 +1078,6 @@ def _get_features_url(api_host: str, client_key: str) -> str:
9391078
# Singleton instance
9401079
feature_repo = FeatureRepository()
9411080

942-
9431081
class GrowthBook(object):
9441082
def __init__(
9451083
self,
@@ -956,6 +1094,7 @@ def __init__(
9561094
forced_variations: dict = {},
9571095
sticky_bucket_service: AbstractStickyBucketService = None,
9581096
sticky_bucket_identifier_attributes: List[str] = None,
1097+
streaming: bool = False,
9591098
# Deprecated args
9601099
trackingCallback=None,
9611100
qaMode: bool = False,
@@ -981,6 +1120,8 @@ def __init__(
9811120
self._qaMode = qa_mode or qaMode
9821121
self._trackingCallback = on_experiment_viewed or trackingCallback
9831122

1123+
self._streaming = streaming
1124+
9841125
# Deprecated args
9851126
self._user = user
9861127
self._groups = groups
@@ -994,6 +1135,10 @@ def __init__(
9941135
if features:
9951136
self.setFeatures(features)
9961137

1138+
if self._streaming:
1139+
self.load_features()
1140+
self.startAutoRefresh()
1141+
9971142
def load_features(self) -> None:
9981143
if not self._client_key:
9991144
raise ValueError("Must specify `client_key` to refresh features")
@@ -1014,6 +1159,49 @@ async def load_features_async(self) -> None:
10141159
if features is not None:
10151160
self.setFeatures(features)
10161161

1162+
def features_event_handler(self, features):
1163+
decoded = json.loads(features)
1164+
if not decoded:
1165+
return None
1166+
1167+
if "encryptedFeatures" in decoded:
1168+
if not self._decryption_key:
1169+
raise ValueError("Must specify decryption_key")
1170+
try:
1171+
decrypted = decrypt(decoded["encryptedFeatures"], self._decryption_key)
1172+
return json.loads(decrypted)
1173+
except Exception:
1174+
logger.warning(
1175+
"Failed to decrypt features from GrowthBook API response"
1176+
)
1177+
return None
1178+
elif "features" in decoded:
1179+
self.set_features(decoded["features"])
1180+
else:
1181+
logger.warning("GrowthBook API response missing features")
1182+
1183+
def dispatch_sse_event(self, event_data):
1184+
event_type = event_data['type']
1185+
data = event_data['data']
1186+
if event_type == 'features-updated':
1187+
self.load_features()
1188+
elif event_type == 'features':
1189+
self.features_event_handler(data)
1190+
1191+
1192+
def startAutoRefresh(self):
1193+
if not self._client_key:
1194+
raise ValueError("Must specify `client_key` to start features streaming")
1195+
1196+
feature_repo.startAutoRefresh(
1197+
api_host=self._api_host,
1198+
client_key=self._client_key,
1199+
cb=self.dispatch_sse_event
1200+
)
1201+
1202+
def stopAutoRefresh(self):
1203+
feature_repo.stopAutoRefresh()
1204+
10171205
# @deprecated, use set_features
10181206
def setFeatures(self, features: dict) -> None:
10191207
return self.set_features(features)

0 commit comments

Comments
 (0)