Skip to content

Commit cd8dd95

Browse files
authored
feat: add stream flag (#53)
* Add stream feat * Use EvaluationFlag, fix keep alive ms, rename defaults, fix tests * Remove unused imports * Fix union type * Fix typing
1 parent d60913b commit cd8dd95

File tree

12 files changed

+834
-80
lines changed

12 files changed

+834
-80
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
amplitude_analytics~=1.1.1
22
dataclasses-json~=0.6.7
3+
sseclient-py~=1.8.0

src/amplitude_experiment/deployment/deployment_runner.py

Lines changed: 27 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
from typing import Optional
33
import threading
44

5+
from ..flag.flag_config_updater import FlagConfigPoller, FlagConfigStreamer, FlagConfigUpdaterFallbackRetryWrapper
56
from ..local.config import LocalEvaluationConfig
67
from ..cohort.cohort_loader import CohortLoader
78
from ..cohort.cohort_storage import CohortStorage
8-
from ..flag.flag_config_api import FlagConfigApi
9+
from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi
910
from ..flag.flag_config_storage import FlagConfigStorage
1011
from ..local.poller import Poller
11-
from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags
12+
from ..util.flag_config import get_all_cohort_ids_from_flags
13+
14+
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS = 15000
15+
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS = 1000
1216

1317

1418
class DeploymentRunner:
1519
def __init__(
1620
self,
1721
config: LocalEvaluationConfig,
1822
flag_config_api: FlagConfigApi,
23+
flag_config_stream_api: Optional[FlagConfigStreamApi],
1924
flag_config_storage: FlagConfigStorage,
2025
cohort_storage: CohortStorage,
2126
logger: logging.Logger,
@@ -27,88 +32,41 @@ def __init__(
2732
self.cohort_storage = cohort_storage
2833
self.cohort_loader = cohort_loader
2934
self.lock = threading.Lock()
30-
self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)
35+
self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper(
36+
FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, config, logger),
37+
None,
38+
0, 0, config.flag_config_polling_interval_millis, 0,
39+
logger
40+
)
41+
if flag_config_stream_api:
42+
self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper(
43+
FlagConfigStreamer(flag_config_stream_api, flag_config_storage, cohort_loader, cohort_storage, logger),
44+
self.flag_updater,
45+
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS, DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS,
46+
config.flag_config_polling_interval_millis, 0,
47+
logger
48+
)
49+
50+
self.cohort_poller = None
3151
if self.cohort_loader:
3252
self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000,
3353
self.__update_cohorts)
3454
self.logger = logger
3555

3656
def start(self):
3757
with self.lock:
38-
self.__update_flag_configs()
39-
self.flag_poller.start()
58+
self.flag_updater.start(None)
4059
if self.cohort_loader:
4160
self.cohort_poller.start()
4261

4362
def stop(self):
44-
self.flag_poller.stop()
45-
46-
def __periodic_flag_update(self):
47-
try:
48-
self.__update_flag_configs()
49-
except Exception as e:
50-
self.logger.warning(f"Error while updating flags: {e}")
51-
52-
def __update_flag_configs(self):
53-
try:
54-
flag_configs = self.flag_config_api.get_flag_configs()
55-
except Exception as e:
56-
self.logger.warning(f'Failed to fetch flag configs: {e}')
57-
raise e
58-
59-
flag_keys = {flag.key for flag in flag_configs}
60-
self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys)
61-
62-
if not self.cohort_loader:
63-
for flag_config in flag_configs:
64-
self.logger.debug(f"Putting non-cohort flag {flag_config.key}")
65-
self.flag_config_storage.put_flag_config(flag_config)
66-
return
67-
68-
new_cohort_ids = set()
69-
for flag_config in flag_configs:
70-
new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config))
71-
72-
existing_cohort_ids = self.cohort_storage.get_cohort_ids()
73-
cohort_ids_to_download = new_cohort_ids - existing_cohort_ids
74-
75-
# download all new cohorts
76-
try:
77-
self.cohort_loader.download_cohorts(cohort_ids_to_download).result()
78-
except Exception as e:
79-
self.logger.warning(f"Error while downloading cohorts: {e}")
80-
81-
# get updated set of cohort ids
82-
updated_cohort_ids = self.cohort_storage.get_cohort_ids()
83-
# iterate through new flag configs and check if their required cohorts exist
84-
for flag_config in flag_configs:
85-
cohort_ids = get_all_cohort_ids_from_flag(flag_config)
86-
self.logger.debug(f"Storing flag {flag_config.key}")
87-
self.flag_config_storage.put_flag_config(flag_config)
88-
missing_cohorts = cohort_ids - updated_cohort_ids
89-
if missing_cohorts:
90-
self.logger.warning(f"Flag {flag_config.key} - failed to load cohorts: {missing_cohorts}")
91-
92-
# delete unused cohorts
93-
self._delete_unused_cohorts()
94-
self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")
63+
self.flag_updater.stop()
64+
if self.cohort_poller:
65+
self.cohort_poller.stop()
9566

9667
def __update_cohorts(self):
9768
cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values()))
9869
try:
9970
self.cohort_loader.download_cohorts(cohort_ids).result()
10071
except Exception as e:
10172
self.logger.warning(f"Error while updating cohorts: {e}")
102-
103-
def _delete_unused_cohorts(self):
104-
flag_cohort_ids = set()
105-
for flag in self.flag_config_storage.get_flag_configs().values():
106-
flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag))
107-
108-
storage_cohorts = self.cohort_storage.get_cohorts()
109-
deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids
110-
111-
for deleted_cohort_id in deleted_cohort_ids:
112-
deleted_cohort = storage_cohorts.get(deleted_cohort_id)
113-
if deleted_cohort is not None:
114-
self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .flag_config_api import FlagConfigStreamApi
2+
from .flag_config_updater import FlagConfigStreamer

src/amplitude_experiment/flag/flag_config_api.py

Lines changed: 182 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import json
2-
from typing import List
2+
import threading
3+
from http.client import HTTPResponse, HTTPConnection, HTTPSConnection
4+
from typing import List, Optional, Callable, Mapping, Union, Tuple
35

4-
from ..evaluation.types import EvaluationFlag
5-
from ..version import __version__
6+
import sseclient
67

78
from ..connection_pool import HTTPConnectionPool
8-
9+
from ..util.updater import get_duration_with_jitter
10+
from ..evaluation.types import EvaluationFlag
11+
from ..version import __version__
912

1013
class FlagConfigApi:
1114
def get_flag_configs(self) -> List[EvaluationFlag]:
@@ -46,3 +49,178 @@ def __setup_connection_pool(self):
4649
timeout = self.flag_config_poller_request_timeout_millis / 1000
4750
self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30,
4851
read_timeout=timeout, scheme=scheme)
52+
53+
54+
# Defaults for the streaming connection, all in milliseconds.
DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS = 17000
DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS = 15 * 60 * 1000
DEFAULT_STREAM_MAX_JITTER_MILLIS = 5000


class EventSource:
    """Server-sent-events client that reads a stream on a background thread.

    ``start`` opens an HTTP(S) connection to ``server_url`` + ``path``, wraps
    the response in an ``sseclient.SSEClient`` and spawns a thread that
    forwards each event's data to ``on_update``. A keep-alive timer is re-armed
    on every received event; if it fires, the source stops itself and reports
    through ``on_error``. Events whose data is a single space only re-arm the
    timer and are not forwarded.
    """

    def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int,
                 max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS,
                 max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS,
                 keep_alive_timeout_millis: int = DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS):
        """Store the connection settings; no I/O happens until ``start``.

        Args:
            server_url: Base URL such as ``https://host[:port]``. Only the
                scheme and host are used; a trailing slash or path is ignored.
            path: Request path of the stream endpoint.
            headers: Headers sent with the GET request.
            conn_timeout_millis: Stored on the instance; not otherwise used by
                this class.
            max_conn_duration_millis: Base value for the connection's socket
                timeout (see ``_get_conn``).
            max_jitter_millis: Jitter bound passed to
                ``get_duration_with_jitter`` together with the duration.
            keep_alive_timeout_millis: Delay after the last received event
                before the stream is considered dead.
        """
        self.keep_alive_timer: Optional[threading.Timer] = None
        self.server_url = server_url
        self.path = path
        self.headers = headers
        self.conn_timeout_millis = conn_timeout_millis
        self.max_conn_duration_millis = max_conn_duration_millis
        self.max_jitter_millis = max_jitter_millis
        self.keep_alive_timeout_millis = keep_alive_timeout_millis

        self.sse: Optional[sseclient.SSEClient] = None
        self.conn: Optional[Union[HTTPConnection, HTTPSConnection]] = None
        self.thread: Optional[threading.Thread] = None
        self._stopped = False
        # Re-entrant because stop()/start()/reset_keep_alive_timer() are called
        # from code paths that already hold the lock.
        self.lock = threading.RLock()

    def start(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
        """Open the stream and begin delivering events.

        Any previous SSE client/connection is closed first. A non-200 response
        is reported through ``on_error`` and nothing is started.

        Args:
            on_update: Called with the data of each non-keep-alive event.
            on_error: Called with a message when the stream fails.

        Raises:
            Exception: Whatever ``_get_conn`` raises while connecting.
        """
        with self.lock:
            if self.sse is not None:
                self.sse.close()
            if self.conn is not None:
                self.conn.close()

            self.conn, response = self._get_conn()
            if response.status != 200:
                body = response.read().decode('utf-8')
                # The connection is of no further use; release the socket
                # instead of keeping it open until the next stop().
                self.conn.close()
                self.conn = None
                on_error(f"[Experiment] Stream flagConfigs - received error response: {response.status}: {body}")
                return

            self.sse = sseclient.SSEClient(response, char_enc='utf-8')
            self._stopped = False
            self.thread = threading.Thread(target=self._run, args=[on_update, on_error])
            self.thread.start()
            self.reset_keep_alive_timer(on_error)

    def stop(self):
        """Mark the source stopped, close the stream and cancel the keep-alive timer."""
        with self.lock:
            self._stopped = True
            if self.sse:
                self.sse.close()
            if self.conn:
                self.conn.close()
            if self.keep_alive_timer:
                self.keep_alive_timer.cancel()
            self.sse = None
            self.conn = None
            # No way to stop self.thread directly; closing the connection makes
            # the read loop in the thread raise, which terminates the thread.

    def reset_keep_alive_timer(self, on_error: Callable[[str], None]):
        """Cancel any pending keep-alive timer and arm a new one."""
        with self.lock:
            if self.keep_alive_timer:
                self.keep_alive_timer.cancel()
            self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000,
                                                    self.keep_alive_timed_out, args=[on_error])
            self.keep_alive_timer.start()

    def keep_alive_timed_out(self, on_error: Callable[[str], None]):
        """Timer callback: stop the stream and report the timeout, unless already stopped."""
        with self.lock:
            if not self._stopped:
                self.stop()
                on_error("[Experiment] Stream flagConfigs - Keep alive timed out")

    def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
        """Reader-thread body: forward events until stopped or the stream fails."""
        # socket.timeout is only an alias of TimeoutError from Python 3.10 on;
        # catching both keeps the reconnect path working on older interpreters.
        import socket

        try:
            for event in self.sse.events():
                with self.lock:
                    if self._stopped:
                        return
                    self.reset_keep_alive_timer(on_error)
                    if event.data == ' ':
                        # Keep-alive payload: timer already re-armed, nothing to deliver.
                        continue
                    on_update(event.data)
        except (TimeoutError, socket.timeout):
            # Due to connection max time reached, open another one.
            with self.lock:
                if self._stopped:
                    return
                self.stop()
                try:
                    self.start(on_update, on_error)
                except Exception as e:
                    # Without this the reader thread would die silently and,
                    # with the keep-alive timer cancelled by stop(), nobody
                    # would ever be told that the stream is gone.
                    on_error(f"[Experiment] Stream flagConfigs - Failed to reconnect: {e}")
        except Exception as e:
            # Closing connection can result in exception here as a way to stop generator.
            with self.lock:
                if self._stopped:
                    return
                on_error(f"[Experiment] Stream flagConfigs - Unexpected exception: {e}")

    @staticmethod
    def _parse_server_url(server_url: str) -> Tuple[type, str]:
        """Return the connection class and ``host[:port]`` for ``server_url``.

        ``http`` selects ``HTTPConnection``; any other scheme selects
        ``HTTPSConnection``.

        Raises:
            ValueError: If the URL has no ``scheme://host`` part.
        """
        from urllib.parse import urlsplit

        parts = urlsplit(server_url)
        if not parts.netloc:
            raise ValueError(f"Invalid server url: {server_url!r}")
        connection = HTTPConnection if parts.scheme == 'http' else HTTPSConnection
        return connection, parts.netloc

    def _get_conn(self) -> Tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]:
        """Connect, send the GET request and return ``(connection, response)``.

        The connection is closed before re-raising if the request fails.
        """
        connection, host = self._parse_server_url(self.server_url)

        # NOTE(review): the jittered "max connection duration" is applied as the
        # socket timeout, i.e. a limit on each blocking read rather than on the
        # total lifetime of the connection - confirm this is the intent.
        timeout_seconds = get_duration_with_jitter(self.max_conn_duration_millis, self.max_jitter_millis) / 1000
        conn = connection(host, timeout=timeout_seconds)
        try:
            conn.request('GET', self.path, None, self.headers)
            response = conn.getresponse()
        except Exception:
            conn.close()
            raise

        return conn, response
164+
165+
166+
class FlagConfigStreamApi:
    """Streams flag configurations from ``/sdk/stream/v1/flags`` via an ``EventSource``.

    Each streamed payload is parsed as JSON and loaded into a list of
    ``EvaluationFlag`` before being handed to the caller's ``on_update``.
    """

    def __init__(self,
                 deployment_key: str,
                 server_url: str,
                 conn_timeout_millis: int,
                 max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS,
                 max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS):
        """Build the request headers and the underlying event source.

        Args:
            deployment_key: Sent as ``Authorization: Api-Key <key>``.
            server_url: Base URL of the streaming server.
            conn_timeout_millis: How long ``start`` waits for the first update
                or error before giving up.
            max_conn_duration_millis: Forwarded to the ``EventSource``.
            max_jitter_millis: Forwarded to the ``EventSource``.
        """
        self.deployment_key = deployment_key
        self.server_url = server_url
        self.conn_timeout_millis = conn_timeout_millis
        self.max_conn_duration_millis = max_conn_duration_millis
        self.max_jitter_millis = max_jitter_millis

        self.lock = threading.RLock()

        headers = {
            'Authorization': f"Api-Key {self.deployment_key}",
            'Content-Type': 'application/json;charset=utf-8',
            'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
        }

        # Pass the duration/jitter settings through; otherwise the event source
        # silently falls back to its own defaults and the arguments are ignored.
        self.eventsource = EventSource(self.server_url, "/sdk/stream/v1/flags", headers, conn_timeout_millis,
                                       max_conn_duration_millis, max_jitter_millis)

    def start(self, on_update: Callable[[List[EvaluationFlag]], None], on_error: Callable[[str], None]):
        """Start streaming and block until the first update has been delivered.

        The event source is started on a helper thread. If no update arrives
        within ``conn_timeout_millis``, or an error is reported first, the
        stream is stopped and ``on_error`` is called with
        ``"stream connection timeout error"``.

        Args:
            on_update: Called with the parsed flags for every streamed payload.
            on_error: Called with a message when the stream fails.
        """
        with self.lock:
            init_finished_event = threading.Event()
            init_error_event = threading.Event()
            init_updated_event = threading.Event()

            def _on_update(data):
                response_json = json.loads(data)
                flags = EvaluationFlag.schema().load(response_json, many=True)
                if init_finished_event.is_set():
                    on_update(flags)
                    return
                init_finished_event.set()
                try:
                    on_update(flags)
                finally:
                    # Always release the waiter below, even when the callback
                    # raises; otherwise start() would block forever while
                    # holding self.lock.
                    init_updated_event.set()

            def _on_error(data):
                if not init_finished_event.is_set():
                    # First signal from the stream is an error: flag it so the
                    # waiter below treats initialisation as failed.
                    init_error_event.set()
                    init_finished_event.set()
                on_error(data)

            starter = threading.Thread(target=self.eventsource.start, args=[_on_update, _on_error])
            starter.start()
            init_finished_event.wait(self.conn_timeout_millis / 1000)
            if starter.is_alive() or not init_finished_event.is_set() or init_error_event.is_set():
                self.stop()
                on_error("stream connection timeout error")
                return

            # Wait for first update callback to finish before returning.
            init_updated_event.wait()

    def stop(self):
        """Stop the event source on a separate thread, without waiting for it to finish."""
        with self.lock:
            threading.Thread(target=self.eventsource.stop).start()

0 commit comments

Comments
 (0)