diff --git a/growthbook/growthbook.py b/growthbook/growthbook.py
index be54520..7b67394 100644
--- a/growthbook/growthbook.py
+++ b/growthbook/growthbook.py
@@ -5,6 +5,8 @@
 More info at https://www.growthbook.io
 """
 
+import hashlib
+from pathlib import Path
 import sys
 import json
 import threading
@@ -13,14 +15,14 @@
 from abc import ABC, abstractmethod
 from typing import Optional, Any, Set, Tuple, List, Dict, Callable
 
-from .common_types import ( EvaluationContext,
-    Experiment,
-    FeatureResult,
+from .common_types import ( EvaluationContext,
+    Experiment,
+    FeatureResult,
     Feature,
-    GlobalContext,
-    Options,
-    Result, StackContext,
-    UserContext,
+    GlobalContext,
+    Options,
+    Result, StackContext,
+    UserContext,
     AbstractStickyBucketService,
     FeatureRule
 )
@@ -62,7 +64,22 @@ def decrypt(encrypted_str: str, key_str: str) -> str:
     return bytestring.decode("utf-8")
 
+
+class CacheEntry(object):
+    def __init__(self, value: Dict, ttl: int) -> None:
+        self.value = value
+        self.ttl = ttl
+        self.expires = time() + ttl
+
+    def update(self, value: Dict):
+        self.value = value
+        self.expires = time() + self.ttl
+
 class AbstractFeatureCache(ABC):
+    @abstractmethod
+    def get_all_entries(self) -> Dict[str, CacheEntry]:
+        pass
+
     @abstractmethod
     def get(self, key: str) -> Optional[Dict]:
         pass
@@ -74,36 +91,109 @@ def set(self, key: str, value: Dict, ttl: int) -> None:
         pass
 
     def clear(self) -> None:
         pass
 
+class AbstractPersistentFeatureCache(AbstractFeatureCache):
 
-class CacheEntry(object):
-    def __init__(self, value: Dict, ttl: int) -> None:
-        self.value = value
-        self.ttl = ttl
-        self.expires = time() + ttl
-
-    def update(self, value: Dict):
-        self.value = value
-        self.expires = time() + self.ttl
+    @abstractmethod
+    def update_cache(self, cache: Dict[str, CacheEntry]) -> None:
+        pass
 
 
 class InMemoryFeatureCache(AbstractFeatureCache):
     def __init__(self) -> None:
-        self.cache: Dict[str, CacheEntry] = {}
+        self._cache: Dict[str, CacheEntry] = {}
 
     def get(self, key: str) -> Optional[Dict]:
-        if key in self.cache:
-            entry = self.cache[key]
+        if key in self._cache:
+            entry = self._cache[key]
             if entry.expires >= time():
                 return entry.value
         return None
 
     def set(self, key: str, value: Dict, ttl: int) -> None:
-        if key in self.cache:
-            self.cache[key].update(value)
-        self.cache[key] = CacheEntry(value, ttl)
+        if key in self._cache:
+            self._cache[key].update(value)
+        self._cache[key] = CacheEntry(value, ttl)
+
+    def clear(self) -> None:
+        self._cache.clear()
+
+    def get_all_entries(self) -> Dict[str, CacheEntry]:
+        return self._cache
+
+class FileFeatureCache(AbstractPersistentFeatureCache):
+    def __init__(self, cache_file: str, base_directory: Optional[str] = None):
+        self._cache_file = cache_file
+        self._cache: Dict[str, CacheEntry] = {}
+        self._base_directory = base_directory
+        self.load()
+
+    def _get_base_path(self) -> Path:
+        base_path = Path(self._base_directory or "./GrowthBook-Cache")
+        base_path.mkdir(parents=True, exist_ok=True)
+        return base_path
+
+    def get_all_entries(self) -> Dict[str, CacheEntry]:
+        return self._cache
+
+    def update_cache(self, cache: Dict[str, CacheEntry]) -> None:
+        try:
+            cache_path = self._get_base_path() / f"{self._cache_file}"
+            raw_cache = {
+                key: {
+                    "value": entry.value,
+                    "expires": entry.expires
+                }
+                for key, entry in cache.items()
+                if entry.expires > time()
+            }
+            with open(cache_path, "w") as f:
+                json.dump(raw_cache, f)
+        except Exception as e:
+            logger.warning(f"Failed to update persistent cache: {e}")
+
+    def load(self):
+        try:
+            cache_path = self._get_base_path() / f"{self._cache_file}"
+            if not cache_path.exists():
+                return
+            with open(cache_path, "r") as f:
+                raw_cache = json.load(f)
+            now = time()
+            for key, entry_data in raw_cache.items():
+                self._cache[key] = CacheEntry(
+                    value=entry_data["value"], ttl=entry_data["expires"] - now
+                )
+        except Exception as e:
+            logger.warning(f"Failed to load persistent cache: {e}")
+
+    def _save_cache(self):
+        cache_path = self._get_base_path() / f"{self._cache_file}"
+        raw_cache = {
+            key: {"value": entry.value, "expires": entry.expires}
+            for key, entry in self._cache.items()
+            if entry.expires > time()
+        }
+        try:
+            with open(cache_path, "w") as f:
+                json.dump(raw_cache, f)
+        except Exception as e:
+            logger.warning(f"Failed to save persistent cache: {e}")
+
+    def get(self, key: str) -> Optional[Dict]:
+        entry = self._cache.get(key)
+        if entry and entry.expires >= time():
+            return entry.value
+        return None
+
+    def set(self, key: str, value: Dict, ttl: int) -> None:
+        self._cache[key] = CacheEntry(value, ttl)
+        self._save_cache()
 
     def clear(self) -> None:
-        self.cache.clear()
+        self._cache.clear()
+        self._save_cache()
+
 
 class InMemoryStickyBucketService(AbstractStickyBucketService):
     def __init__(self) -> None:
@@ -170,7 +260,7 @@ def _get_sse_url(self, api_host: str, client_key: str) -> str:
 
     async def _init_session(self):
         url = self._get_sse_url(self.api_host, self.client_key)
-        
+
         while self.is_running:
             try:
                 async with aiohttp.ClientSession(headers=self.headers) as session:
@@ -224,7 +314,7 @@ async def _close_session(self):
 
     def _run_sse_channel(self):
         self._loop = asyncio.new_event_loop()
-        
+
         try:
             self._loop.run_until_complete(self._init_session())
         except asyncio.CancelledError:
@@ -246,21 +336,68 @@ async def _stop_session(self):
         except asyncio.CancelledError:
             pass
 
+class LayeredFeatureCache(AbstractFeatureCache):
+    def __init__(self, primary_cache: AbstractFeatureCache, secondary_cache: AbstractPersistentFeatureCache):
+        self._primary_cache = primary_cache
+        self._secondary_cache = secondary_cache
+        self._secondary_cache.load()
+        self._sync_caches_on_init()
+
+    def _sync_caches_on_init(self):
+        for key, entry in self._secondary_cache.get_all_entries().items():
+            if entry.expires > time():
+                self._primary_cache.set(key, entry.value, int(entry.expires - time()))
+
+    def get_all_entries(self) -> Dict[str, CacheEntry]:
+        return self._primary_cache.get_all_entries()
+
+    def get(self, key: str) -> Optional[Dict]:
+        value = self._primary_cache.get(key)
+        if value is not None:
+            return value
+
+        value = self._secondary_cache.get(key)
+        if value is not None:
+            entry = self._secondary_cache.get_all_entries().get(key)
+            if entry:
+                self._primary_cache.set(key, value, int(entry.expires - time()))
+            return value
+        return None
+
+    def set(self, key: str, value: Dict, ttl: int) -> None:
+        self._primary_cache.set(key, value, ttl)
+        self._secondary_cache.set(key, value, ttl)
+
+    def clear(self) -> None:
+        self._primary_cache.clear()
+        self._secondary_cache.clear()
+
+    def update_secondary_cache_from_primary(self):
+        self._secondary_cache.update_cache(self._primary_cache.get_all_entries())
+
 class FeatureRepository(object):
-    def __init__(self) -> None:
-        self.cache: AbstractFeatureCache = InMemoryFeatureCache()
+    def __init__(self, cache_base_dir: Optional[str] = None) -> None:
+        self._in_memory = InMemoryFeatureCache()
+        self._file_cache = FileFeatureCache(cache_file="features_cache.json", base_directory=cache_base_dir)
+        self.feature_cache = LayeredFeatureCache(
+            primary_cache=self._in_memory,
+            secondary_cache=self._file_cache
+        )
         self.http: Optional[PoolManager] = None
         self.sse_client: Optional[SSEClient] = None
         self._feature_update_callbacks: List[Callable[[Dict], None]] = []
 
-    def set_cache(self, cache: AbstractFeatureCache) -> None:
-        self.cache = cache
+    def load_features_from_persistent_cache(self) -> None:
+        """No-op hook: persisted entries are already loaded when LayeredFeatureCache is constructed."""
+
+    def set_persistent_cache(self, cache: AbstractPersistentFeatureCache) -> None:
+        self.persistent_cache = cache
 
     def clear_cache(self):
-        self.cache.clear()
+        self.feature_cache.clear()
 
     def save_in_cache(self, key: str, res, ttl: int = 600):
-        self.cache.set(key, res, ttl)
+        self.feature_cache.set(key, res, ttl)
 
     def add_feature_update_callback(self, callback: Callable[[Dict], None]) -> None:
         """Add a callback to be notified when features are updated due to cache expiry"""
@@ -286,30 +423,32 @@ def load_features(
     ) -> Optional[Dict]:
         if not client_key:
             raise ValueError("Must specify `client_key` to refresh features")
-        
+
         key = api_host + "::" + client_key
-        cached = self.cache.get(key)
+        cached = self.feature_cache.get(key)
         if not cached:
             res = self._fetch_features(api_host, client_key, decryption_key)
             if res is not None:
-                self.cache.set(key, res, ttl)
+                self.feature_cache.set(key, res, ttl)
+                self.feature_cache.update_secondary_cache_from_primary()
                 logger.debug("Fetched features from API, stored in cache")
                 # Notify callbacks about fresh features
                 self._notify_feature_update_callbacks(res)
                 return res
         return cached
-    
+
     async def load_features_async(
        self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 600
     ) -> Optional[Dict]:
         key = api_host + "::" + client_key
-        cached = self.cache.get(key)
+        cached = self.feature_cache.get(key)
         if not cached:
             res = await self._fetch_features_async(api_host, client_key, decryption_key)
             if res is not None:
-                self.cache.set(key, res, ttl)
+                self.feature_cache.set(key, res, ttl)
+                self.feature_cache.update_secondary_cache_from_primary()
                 logger.debug("Fetched features from API, stored in cache")
                 # Notify callbacks about fresh features
                 self._notify_feature_update_callbacks(res)
@@ -320,7 +459,7 @@ async def load_features_async(
     def _get(self, url: str):
         self.http = self.http or PoolManager()
         return self.http.request("GET", url)
-    
+
     def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]:
         try:
             r = self._get(self._get_features_url(api_host, client_key))
@@ -334,7 +473,7 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]:
         except Exception:
             logger.warning("Failed to decode feature JSON from GrowthBook API")
             return None
-    
+
     async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optional[Dict]:
         try:
             url = self._get_features_url(api_host, client_key)
@@ -351,7 +490,7 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optional[Dict]:
         except Exception as e:
             logger.warning("Failed to decode feature JSON from GrowthBook API: %s", e)
             return None
-    
+
     def decrypt_response(self, data, decryption_key: str):
         if "encryptedFeatures" in data:
             if not decryption_key:
@@ -367,7 +506,7 @@ def decrypt_response(self, data, decryption_key: str):
                 return None
         elif "features" not in data:
             logger.warning("GrowthBook API response missing features")
-        
+
         if "encryptedSavedGroups" in data:
             if not decryption_key:
                 raise ValueError("Must specify decryption_key")
@@ -380,7 +519,7 @@ def decrypt_response(self, data, decryption_key: str):
                 logger.warning(
                     "Failed to decrypt saved groups from GrowthBook API response"
                 )
-        
+
         return data
 
     # Fetch features from the GrowthBook API
@@ -394,7 +533,7 @@ def _fetch_features(
 
         data = self.decrypt_response(decoded, decryption_key)
         return data
-    
+
     async def _fetch_features_async(
         self, api_host: str, client_key: str, decryption_key: str = ""
     ) -> Optional[Dict]:
@@ -438,6 +577,7 @@ def __init__(
         client_key: str = "",
         decryption_key: str = "",
         cache_ttl: int = 600,
+        persistent_cache_base_dir: Optional[str] = None,
         forced_variations: dict = {},
         sticky_bucket_service: AbstractStickyBucketService = None,
         sticky_bucket_identifier_attributes: List[str] = None,
@@ -500,7 +640,7 @@ def __init__(
             ),
             features={},
             saved_groups=self._saved_groups
-        )
+        )
         # Create a user context for the current user
         self._user_ctx: UserContext = UserContext(
             url=self._url,
@@ -511,6 +651,8 @@ def __init__(
             sticky_bucket_assignment_docs=self._sticky_bucket_assignment_docs
         )
 
+        feature_repo.load_features_from_persistent_cache()
+
         if features:
             self.setFeatures(features)
@@ -528,6 +670,8 @@ def _on_feature_update(self, features_data: Dict) -> None:
         """Callback to handle automatic feature updates from FeatureRepository"""
         if features_data and "features" in features_data:
             self.set_features(features_data["features"])
+            feature_repo.save_in_cache(self._api_host + "::" + self._client_key, features_data, self._cache_ttl)
+
         if features_data and "savedGroups" in features_data:
             self._saved_groups = features_data["savedGroups"]
@@ -561,7 +705,7 @@ def _features_event_handler(self, features):
         decoded = json.loads(features)
         if not decoded:
             return None
-        
+
         data = feature_repo.decrypt_response(decoded, self._decryption_key)
 
         if data is not None:
@@ -583,9 +727,9 @@ def _dispatch_sse_event(self, event_data):
     def startAutoRefresh(self):
         if not self._client_key:
             raise ValueError("Must specify `client_key` to start features streaming")
-        
+
         feature_repo.startAutoRefresh(
-            api_host=self._api_host,
+            api_host=self._api_host,
             client_key=self._client_key,
             cb=self._dispatch_sse_event
         )
@@ -637,11 +781,11 @@ def get_attributes(self) -> dict:
     def destroy(self) -> None:
         # Clean up plugins first
         self._cleanup_plugins()
-        
+
         # Clean up feature update callback
         if self._client_key:
             feature_repo.remove_feature_update_callback(self._on_feature_update)
-        
+
         self._subscriptions.clear()
         self._tracked.clear()
         self._assigned.clear()
@@ -677,10 +821,10 @@ def get_feature_value(self, key: str, fallback):
     # @deprecated, use eval_feature
     def evalFeature(self, key: str) -> FeatureResult:
         return self.eval_feature(key)
-    
+
     def _ensure_fresh_features(self) -> None:
         """Lazy refresh: Check cache expiry and refresh if needed, but only if client_key is provided"""
-        
+
         if self._streaming or not self._client_key:
             return  # Skip cache checks - SSE handles freshness for streaming users
@@ -692,7 +836,7 @@ def _ensure_fresh_features(self) -> None:
     def _get_eval_context(self) -> EvaluationContext:
         # Lazy refresh: ensure features are fresh before evaluation
         self._ensure_fresh_features()
-        
+
         # use the latest attributes for every evaluation.
         self._user_ctx.attributes = self._attributes
         self._user_ctx.url = self._url
@@ -706,8 +850,8 @@ def _get_eval_context(self) -> EvaluationContext:
         )
 
     def eval_feature(self, key: str) -> FeatureResult:
-        return core_eval_feature(key=key, 
-                                 evalContext=self._get_eval_context(), 
+        return core_eval_feature(key=key,
+                                 evalContext=self._get_eval_context(),
                                  callback_subscription=self._fireSubscriptions,
                                  tracking_cb=self._track
         )
@@ -722,7 +866,7 @@ def get_all_results(self):
     def _fireSubscriptions(self, experiment: Experiment, result: Result):
         if experiment is None:
             return
-        
+
         prev = self._assigned.get(experiment.key, None)
         if (
             not prev
@@ -741,7 +885,7 @@ def _fireSubscriptions(self, experiment: Experiment, result: Result):
 
     def run(self, experiment: Experiment) -> Result:
         # result = self._run(experiment)
-        result = run_experiment(experiment=experiment, 
+        result = run_experiment(experiment=experiment,
                                 evalContext=self._get_eval_context(),
                                 tracking_cb=self._track
         )
diff --git a/growthbook/growthbook_client.py b/growthbook/growthbook_client.py
index 679f17e..ac577db 100644
--- a/growthbook/growthbook_client.py
+++ b/growthbook/growthbook_client.py
@@ -405,6 +405,8 @@ async def initialize(self) -> bool:
         if not self._features_repository:
             logger.error("No features repository available")
             return False
+
+        self._features_repository.load_features_from_persistent_cache()
 
         try:
             # Initial feature load
diff --git a/tests/test_growthbook.py b/tests/test_growthbook.py
index c310bbd..77b3bc0 100644
--- a/tests/test_growthbook.py
+++ b/tests/test_growthbook.py
@@ -160,7 +160,7 @@ def test_stickyBucket(stickyBucket_data):
 
     gb = GrowthBook(**ctx)
     res = gb.eval_feature(key)
-    
+
     if not res.experimentResult:
         assert None == expected_result
     else:
@@ -656,6 +656,8 @@ def __init__(self, status: int, data: str) -> None:
 
 
 def test_feature_repository(mocker):
+    feature_repo.clear_cache()
+
     m = mocker.patch.object(feature_repo, "_get")
     expected = {"features": {"feature": {"defaultValue": 5}}}
     m.return_value = MockHttpResp(200, json.dumps(expected))
@@ -670,9 +672,21 @@ def test_feature_repository(mocker):
     assert features == expected
 
     # Does a new request if cache entry is expired
-    feature_repo.cache.cache["https://cdn.growthbook.io::sdk-abc123"].expires = (
-        time() - 10
-    )
+    cache_key_to_expire = "https://cdn.growthbook.io::sdk-abc123"
+
+    # Reach into the layered cache internals and expire the entry in both the
+    # primary (in-memory) and secondary (file) layers so the next read misses
+    if hasattr(feature_repo.feature_cache, '_primary_cache') and \
+            hasattr(feature_repo.feature_cache._primary_cache, '_cache') and \
+            cache_key_to_expire in feature_repo.feature_cache._primary_cache._cache:
+
+        feature_repo.feature_cache._primary_cache._cache[cache_key_to_expire].expires = (time() - 10)
+        feature_repo.feature_cache._secondary_cache._cache[cache_key_to_expire].expires = (time() - 10)
+        logger.debug(f"Manually expired cache key: {cache_key_to_expire}")
+    else:
+        logger.warning(
+            f"Failed to manually expire cache key {cache_key_to_expire}. Cache structure might have changed or key not found.")
+
     features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123")
     assert m.call_count == 2
     assert features == expected
@@ -681,6 +695,8 @@ def test_feature_repository(mocker):
 
 
 def test_feature_repository_error(mocker):
+    feature_repo.clear_cache()
+
     m = mocker.patch.object(feature_repo, "_get")
     m.return_value = MockHttpResp(400, "400 Error")
     features = feature_repo.load_features("https://cdn.growthbook.io", "sdk-abc123")
@@ -703,6 +719,8 @@ def test_feature_repository_error(mocker):
 
 
 def test_feature_repository_encrypted(mocker):
+    feature_repo.clear_cache()
+
     m = mocker.patch.object(feature_repo, "_get")
     m.return_value = MockHttpResp(
         200,
@@ -728,6 +746,8 @@ def test_feature_repository_encrypted(mocker):
 
 
 def test_load_features(mocker):
+    feature_repo.clear_cache()
+
     m = mocker.patch.object(feature_repo, "_get")
     m.return_value = MockHttpResp(
         200, json.dumps({"features": {"feature": {"defaultValue": 5}}})
     )
@@ -747,6 +767,8 @@ def test_load_features(mocker):
 
 
 def test_loose_unmarshalling(mocker):
+    feature_repo.clear_cache()
+
     m = mocker.patch.object(feature_repo, "_get")
     m.return_value = MockHttpResp(200, json.dumps({
         "features": {
@@ -921,43 +943,56 @@ def test_sticky_bucket_service(mocker):
 
 def test_ttl_automatic_feature_refresh(mocker):
     """Test that GrowthBook instances automatically get updated features when cache expires during evaluation"""
     # Mock responses to simulate feature flag changes
+    feature_repo.clear_cache()
+
     mock_responses = [
         {"features": {"test_feature": {"defaultValue": False}}, "savedGroups": {}},
         {"features": {"test_feature": {"defaultValue": True}}, "savedGroups": {}}
     ]
-    
+
     call_count = 0
     def mock_fetch_features(api_host, client_key, decryption_key=""):
         nonlocal call_count
         response = mock_responses[min(call_count, len(mock_responses) - 1)]
         call_count += 1
         return response
-    
+
     # Clear cache and mock the fetch method
     feature_repo.clear_cache()
     m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features)
-    
+
     # Create GrowthBook instance with short TTL
     gb = GrowthBook(
         api_host="https://cdn.growthbook.io",
         client_key="test-key",
         cache_ttl=1  # 1 second TTL for testing
     )
-    
+
     try:
         # Initial evaluation - should trigger first load
         assert gb.is_on('test_feature') == False
         assert call_count == 1
-        
+
         # Manually expire the cache by setting expiry time to past
-        cache_key = "https://cdn.growthbook.io::test-key"
-        if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache:
-            feature_repo.cache.cache[cache_key].expires = time() - 10
-        
+        cache_key_to_expire = "https://cdn.growthbook.io::test-key"
+
+        # Reach into the layered cache internals and expire the entry in both
+        # the primary (in-memory) and secondary (file) layers
+        if hasattr(feature_repo.feature_cache, '_primary_cache') and \
+                hasattr(feature_repo.feature_cache._primary_cache, '_cache') and \
+                cache_key_to_expire in feature_repo.feature_cache._primary_cache._cache:
+
+            feature_repo.feature_cache._primary_cache._cache[cache_key_to_expire].expires = (time() - 10)
+            feature_repo.feature_cache._secondary_cache._cache[cache_key_to_expire].expires = (time() - 10)
+            logger.debug(f"Manually expired cache key: {cache_key_to_expire}")
+        else:
+            logger.warning(
+                f"Failed to manually expire cache key {cache_key_to_expire}. Cache structure might have changed or key not found.")
+
         # Next evaluation should automatically refresh cache and update features
         assert gb.is_on('test_feature') == True
         assert call_count == 2
-        
+
     finally:
         gb.destroy()
         feature_repo.clear_cache()
@@ -965,47 +1000,61 @@ def mock_fetch_features(api_host, client_key, decryption_key=""):
 
 def test_multiple_instances_get_updated_on_cache_expiry(mocker):
     """Test that multiple GrowthBook instances all get updated when cache expires during evaluation"""
+    feature_repo.clear_cache()
+
     mock_responses = [
         {"features": {"test_feature": {"defaultValue": "v1"}}, "savedGroups": {}},
         {"features": {"test_feature": {"defaultValue": "v2"}}, "savedGroups": {}}
     ]
-    
+
     call_count = 0
     def mock_fetch_features(api_host, client_key, decryption_key=""):
         nonlocal call_count
         response = mock_responses[min(call_count, len(mock_responses) - 1)]
         call_count += 1
         return response
-    
+
     feature_repo.clear_cache()
     m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features)
-    
+
     # Create multiple GrowthBook instances
     gb1 = GrowthBook(api_host="https://cdn.growthbook.io", client_key="test-key")
     gb2 = GrowthBook(api_host="https://cdn.growthbook.io", client_key="test-key")
-    
+
     try:
         # Initial evaluation from first instance - should trigger first load
         assert gb1.get_feature_value('test_feature', 'default') == "v1"
         assert call_count == 1
-        
+
         # Second instance should use cached value (no additional API call)
         assert gb2.get_feature_value('test_feature', 'default') == "v1"
         assert call_count == 1  # Still 1, used cache
-        
+
         # Manually expire the cache
-        cache_key = "https://cdn.growthbook.io::test-key"
-        if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache:
-            feature_repo.cache.cache[cache_key].expires = time() - 10
-        
+        cache_key_to_expire = "https://cdn.growthbook.io::test-key"
+
+        # Reach into the layered cache internals and expire the entry in both
+        # the primary (in-memory) and secondary (file) layers
+        if hasattr(feature_repo.feature_cache, '_primary_cache') and \
+                hasattr(feature_repo.feature_cache._primary_cache, '_cache') and \
+                cache_key_to_expire in feature_repo.feature_cache._primary_cache._cache:
+
+            feature_repo.feature_cache._primary_cache._cache[cache_key_to_expire].expires = (time() - 10)
+            feature_repo.feature_cache._secondary_cache._cache[cache_key_to_expire].expires = (time() - 10)
+
+            logger.debug(f"Manually expired cache key: {cache_key_to_expire}")
+        else:
+            logger.warning(
+                f"Failed to manually expire cache key {cache_key_to_expire}. Cache structure might have changed or key not found.")
+
         # Next evaluation should automatically refresh and notify both instances via callbacks
         assert gb1.get_feature_value('test_feature', 'default') == "v2"
         assert call_count == 2
-        
+
         # Second instance should also have the updated value due to callbacks
         assert gb2.get_feature_value('test_feature', 'default') == "v2"
-        
+
     finally:
         gb1.destroy()
         gb2.destroy()
-        feature_repo.clear_cache()
\ No newline at end of file
+        feature_repo.clear_cache()
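
Usage sketch for the caching layers introduced in this diff — a minimal example that assumes the patch is applied and that InMemoryFeatureCache, FileFeatureCache, and LayeredFeatureCache are importable from the growthbook.growthbook module (adjust the import if they end up re-exported at the package root); the base directory below is an arbitrary example path, not anything mandated by the patch.

    from growthbook.growthbook import (
        InMemoryFeatureCache,
        FileFeatureCache,
        LayeredFeatureCache,
    )

    # Primary layer: fast in-process lookups. Secondary layer: a JSON file on
    # disk that survives process restarts ("/tmp/gb-cache" is an example path).
    cache = LayeredFeatureCache(
        primary_cache=InMemoryFeatureCache(),
        secondary_cache=FileFeatureCache(
            cache_file="features_cache.json",
            base_directory="/tmp/gb-cache",
        ),
    )

    # set() writes through to both layers; get() hits memory first and falls
    # back to the file cache, re-priming the in-memory layer with the
    # remaining TTL before returning.
    cache.set("https://cdn.growthbook.io::sdk-abc123", {"features": {}}, ttl=600)
    assert cache.get("https://cdn.growthbook.io::sdk-abc123") is not None

Reads only touch disk on an in-memory miss, so steady-state evaluation stays in memory while a freshly restarted process can serve the last persisted features before its first network fetch.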