|
1 | 1 | import json |
2 | | -from typing import List |
| 2 | +import threading |
| 3 | +from http.client import HTTPResponse, HTTPConnection, HTTPSConnection |
| 4 | +from typing import List, Optional, Callable, Mapping, Union, Tuple |
3 | 5 |
|
4 | | -from ..evaluation.types import EvaluationFlag |
5 | | -from ..version import __version__ |
| 6 | +import sseclient |
6 | 7 |
|
7 | 8 | from ..connection_pool import HTTPConnectionPool |
8 | | - |
| 9 | +from ..util.updater import get_duration_with_jitter |
| 10 | +from ..evaluation.types import EvaluationFlag |
| 11 | +from ..version import __version__ |
9 | 12 |
|
10 | 13 | class FlagConfigApi: |
11 | 14 | def get_flag_configs(self) -> List[EvaluationFlag]: |
@@ -46,3 +49,178 @@ def __setup_connection_pool(self): |
46 | 49 | timeout = self.flag_config_poller_request_timeout_millis / 1000 |
47 | 50 | self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, |
48 | 51 | read_timeout=timeout, scheme=scheme) |
| 52 | + |
| 53 | + |
| 54 | +DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS = 17000 |
| 55 | +DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS = 15 * 60 * 1000 |
| 56 | +DEFAULT_STREAM_MAX_JITTER_MILLIS = 5000 |
| 57 | + |
| 58 | + |
| 59 | +class EventSource: |
| 60 | + def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int, |
| 61 | + max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS, |
| 62 | + max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS, |
| 63 | + keep_alive_timeout_millis: int = DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS): |
| 64 | + self.keep_alive_timer: Optional[threading.Timer] = None |
| 65 | + self.server_url = server_url |
| 66 | + self.path = path |
| 67 | + self.headers = headers |
| 68 | + self.conn_timeout_millis = conn_timeout_millis |
| 69 | + self.max_conn_duration_millis = max_conn_duration_millis |
| 70 | + self.max_jitter_millis = max_jitter_millis |
| 71 | + self.keep_alive_timeout_millis = keep_alive_timeout_millis |
| 72 | + |
| 73 | + self.sse: Optional[sseclient.SSEClient] = None |
| 74 | + self.conn: Optional[HTTPConnection | HTTPSConnection] = None |
| 75 | + self.thread: Optional[threading.Thread] = None |
| 76 | + self._stopped = False |
| 77 | + self.lock = threading.RLock() |
| 78 | + |
| 79 | + def start(self, on_update: Callable[[str], None], on_error: Callable[[str], None]): |
| 80 | + with self.lock: |
| 81 | + if self.sse is not None: |
| 82 | + self.sse.close() |
| 83 | + if self.conn is not None: |
| 84 | + self.conn.close() |
| 85 | + |
| 86 | + self.conn, response = self._get_conn() |
| 87 | + if response.status != 200: |
| 88 | + on_error(f"[Experiment] Stream flagConfigs - received error response: ${response.status}: ${response.read().decode('utf-8')}") |
| 89 | + return |
| 90 | + |
| 91 | + self.sse = sseclient.SSEClient(response, char_enc='utf-8') |
| 92 | + self._stopped = False |
| 93 | + self.thread = threading.Thread(target=self._run, args=[on_update, on_error]) |
| 94 | + self.thread.start() |
| 95 | + self.reset_keep_alive_timer(on_error) |
| 96 | + |
| 97 | + def stop(self): |
| 98 | + with self.lock: |
| 99 | + self._stopped = True |
| 100 | + if self.sse: |
| 101 | + self.sse.close() |
| 102 | + if self.conn: |
| 103 | + self.conn.close() |
| 104 | + if self.keep_alive_timer: |
| 105 | + self.keep_alive_timer.cancel() |
| 106 | + self.sse = None |
| 107 | + self.conn = None |
| 108 | + # No way to stop self.thread, on self.conn.close(), |
| 109 | + # the loop in thread will raise exception, which will terminate the thread. |
| 110 | + |
| 111 | + def reset_keep_alive_timer(self, on_error: Callable[[str], None]): |
| 112 | + with self.lock: |
| 113 | + if self.keep_alive_timer: |
| 114 | + self.keep_alive_timer.cancel() |
| 115 | + self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000, self.keep_alive_timed_out, |
| 116 | + args=[on_error]) |
| 117 | + self.keep_alive_timer.start() |
| 118 | + |
| 119 | + def keep_alive_timed_out(self, on_error: Callable[[str], None]): |
| 120 | + with self.lock: |
| 121 | + if not self._stopped: |
| 122 | + self.stop() |
| 123 | + on_error("[Experiment] Stream flagConfigs - Keep alive timed out") |
| 124 | + |
| 125 | + def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None]): |
| 126 | + try: |
| 127 | + for event in self.sse.events(): |
| 128 | + with self.lock: |
| 129 | + if self._stopped: |
| 130 | + return |
| 131 | + self.reset_keep_alive_timer(on_error) |
| 132 | + if event.data == ' ': |
| 133 | + continue |
| 134 | + on_update(event.data) |
| 135 | + except TimeoutError: |
| 136 | + # Due to connection max time reached, open another one. |
| 137 | + with self.lock: |
| 138 | + if self._stopped: |
| 139 | + return |
| 140 | + self.stop() |
| 141 | + self.start(on_update, on_error) |
| 142 | + except Exception as e: |
| 143 | + # Closing connection can result in exception here as a way to stop generator. |
| 144 | + with self.lock: |
| 145 | + if self._stopped: |
| 146 | + return |
| 147 | + on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e)) |
| 148 | + |
| 149 | + def _get_conn(self) -> Tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]: |
| 150 | + scheme, _, host = self.server_url.split('/', 3) |
| 151 | + connection = HTTPConnection if scheme == 'http:' else HTTPSConnection |
| 152 | + |
| 153 | + body = None |
| 154 | + |
| 155 | + conn = connection(host, timeout=get_duration_with_jitter(self.max_conn_duration_millis, self.max_jitter_millis) / 1000) |
| 156 | + try: |
| 157 | + conn.request('GET', self.path, body, self.headers) |
| 158 | + response = conn.getresponse() |
| 159 | + except Exception as e: |
| 160 | + conn.close() |
| 161 | + raise e |
| 162 | + |
| 163 | + return conn, response |
| 164 | + |
| 165 | + |
| 166 | +class FlagConfigStreamApi: |
| 167 | + def __init__(self, |
| 168 | + deployment_key: str, |
| 169 | + server_url: str, |
| 170 | + conn_timeout_millis: int, |
| 171 | + max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS, |
| 172 | + max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS): |
| 173 | + self.deployment_key = deployment_key |
| 174 | + self.server_url = server_url |
| 175 | + self.conn_timeout_millis = conn_timeout_millis |
| 176 | + self.max_conn_duration_millis = max_conn_duration_millis |
| 177 | + self.max_jitter_millis = max_jitter_millis |
| 178 | + |
| 179 | + self.lock = threading.RLock() |
| 180 | + |
| 181 | + headers = { |
| 182 | + 'Authorization': f"Api-Key {self.deployment_key}", |
| 183 | + 'Content-Type': 'application/json;charset=utf-8', |
| 184 | + 'X-Amp-Exp-Library': f"experiment-python-server/{__version__}" |
| 185 | + } |
| 186 | + |
| 187 | + self.eventsource = EventSource(self.server_url, "/sdk/stream/v1/flags", headers, conn_timeout_millis) |
| 188 | + |
| 189 | + def start(self, on_update: Callable[[List[EvaluationFlag]], None], on_error: Callable[[str], None]): |
| 190 | + with self.lock: |
| 191 | + init_finished_event = threading.Event() |
| 192 | + init_error_event = threading.Event() |
| 193 | + init_updated_event = threading.Event() |
| 194 | + |
| 195 | + def _on_update(data): |
| 196 | + response_json = json.loads(data) |
| 197 | + flags = EvaluationFlag.schema().load(response_json, many=True) |
| 198 | + if init_finished_event.is_set(): |
| 199 | + on_update(flags) |
| 200 | + else: |
| 201 | + init_finished_event.set() |
| 202 | + on_update(flags) |
| 203 | + init_updated_event.set() |
| 204 | + |
| 205 | + def _on_error(data): |
| 206 | + if init_finished_event.is_set(): |
| 207 | + on_error(data) |
| 208 | + else: |
| 209 | + init_error_event.set() |
| 210 | + init_finished_event.set() |
| 211 | + on_error(data) |
| 212 | + |
| 213 | + t = threading.Thread(target=self.eventsource.start, args=[_on_update, _on_error]) |
| 214 | + t.start() |
| 215 | + init_finished_event.wait(self.conn_timeout_millis / 1000) |
| 216 | + if t.is_alive() or not init_finished_event.is_set() or init_error_event.is_set(): |
| 217 | + self.stop() |
| 218 | + on_error("stream connection timeout error") |
| 219 | + return |
| 220 | + |
| 221 | + # Wait for first update callback to finish before returning. |
| 222 | + init_updated_event.wait() |
| 223 | + |
| 224 | + def stop(self): |
| 225 | + with self.lock: |
| 226 | + threading.Thread(target=lambda: self.eventsource.stop()).start() |
0 commit comments