|
| 1 | +import time |
| 2 | +import logging |
| 3 | +import base64 |
| 4 | +import json |
| 5 | +from http.client import HTTPResponse |
| 6 | +from typing import Optional |
| 7 | +from ..version import __version__ |
| 8 | + |
| 9 | +from .cohort import Cohort |
| 10 | +from ..connection_pool import HTTPConnectionPool |
| 11 | +from ..exception import HTTPErrorResponseException, CohortTooLargeException |
| 12 | + |
# Pause between consecutive cohort download attempts, in milliseconds
# (converted to seconds for time.sleep in the retry loop).
COHORT_REQUEST_RETRY_DELAY_MILLIS = 100
| 14 | + |
| 15 | + |
class CohortDownloadApi:
    """Interface for objects that can download a cohort by id."""

    def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Optional[Cohort]:
        """Download the cohort identified by ``cohort_id``.

        Args:
            cohort_id: Identifier of the cohort to download.
            cohort: A previously downloaded copy of the cohort, or ``None``
                when there is none.

        Returns:
            The downloaded cohort, or ``None`` when an implementation has
            nothing new to return (the direct implementation in this module
            returns ``None`` when the server reports the cohort as not
            modified).

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
| 20 | + |
| 21 | + |
class DirectCohortDownloadApi(CohortDownloadApi):
    """Cohort download API that requests cohorts from ``server_url`` over HTTP.

    Requests are authenticated with HTTP basic auth built from the API key and
    secret key, and are issued through a pooled connection.
    """

    def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger):
        """Store the credentials and limits, then create the connection pool.

        Args:
            api_key: User name part of the basic-auth credentials.
            secret_key: Password part of the basic-auth credentials.
            max_cohort_size: Value sent as the ``maxCohortSize`` query parameter.
            server_url: Base URL of the cohort server, ``<scheme>://<host>``.
            logger: Logger used for debug output.

        Raises:
            ValueError: If ``server_url`` does not contain ``<scheme>://<host>``.
        """
        super().__init__()
        self.api_key = api_key
        self.secret_key = secret_key
        self.max_cohort_size = max_cohort_size
        self.server_url = server_url
        self.logger = logger
        self.__setup_connection_pool()

    def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Optional[Cohort]:
        """Download a cohort, retrying until a final answer is obtained.

        Args:
            cohort_id: Identifier of the cohort to download.
            cohort: Previously downloaded copy; its ``last_modified`` value is
                sent with the request. ``None`` when there is no previous copy.

        Returns:
            A new ``Cohort`` built from the response body on HTTP 200, or
            ``None`` on HTTP 204 (cohort not modified).

        Raises:
            CohortTooLargeException: On HTTP 413; raised immediately, no retry.
            Exception: The most recent failure once three failures have been
                counted. HTTP 429 responses are retried without being counted.
        """
        self.logger.debug(f"getCohortMembers({cohort_id}): start")
        # Invariant across retries, so compute it once outside the loop.
        last_modified = None if cohort is None else cohort.last_modified
        errors = 0
        while True:
            response = None
            try:
                response = self._get_cohort_members_request(cohort_id, last_modified)
                self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}")
                if response.status == 200:
                    cohort_info = json.loads(response.read().decode("utf8"))
                    self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}")
                    return Cohort(
                        id=cohort_info['cohortId'],
                        last_modified=cohort_info['lastModified'],
                        size=cohort_info['size'],
                        member_ids=set(cohort_info['memberIds']),
                        group_type=cohort_info['groupType'],
                    )
                elif response.status == 204:
                    self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified")
                    return None
                elif response.status == 413:
                    raise CohortTooLargeException(
                        f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}")
                elif response.status != 202:
                    raise HTTPErrorResponseException(response.status,
                                                     f"Unexpected response code: {response.status}")
                # HTTP 202 falls through: wait and request again without
                # counting an error (presumably the cohort is not ready yet).
            except Exception as e:
                # Count every failure except an HTTP 429. Failures raised by
                # the request itself (response is still None) must be counted
                # too, otherwise a persistent connection error retries forever.
                rate_limited = (
                    response is not None
                    and isinstance(e, HTTPErrorResponseException)
                    and response.status == 429
                )
                if not rate_limited:
                    errors += 1
                self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
                if errors >= 3 or isinstance(e, CohortTooLargeException):
                    raise
            time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000)

    def _get_cohort_members_request(self, cohort_id: str, last_modified: Optional[int]) -> HTTPResponse:
        """Issue one GET request for the cohort on a pooled connection.

        Args:
            cohort_id: Identifier placed in the request path.
            last_modified: Sent as the ``lastModified`` query parameter when
                it is not ``None``.

        Returns:
            The response object returned by the pooled connection.
        """
        headers = {
            'Authorization': f'Basic {self._get_basic_auth()}',
            'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
        }
        conn = self._connection_pool.acquire()
        try:
            url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}'
            if last_modified is not None:
                url += f'&lastModified={last_modified}'
            return conn.request('GET', url, headers=headers)
        finally:
            # Always hand the connection back, even when the request raises.
            self._connection_pool.release(conn)

    def _get_basic_auth(self) -> str:
        """Return ``api_key:secret_key`` encoded as base64 text."""
        credentials = f'{self.api_key}:{self.secret_key}'
        return base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

    def __setup_connection_pool(self):
        """Create the HTTP connection pool from the parts of ``server_url``.

        Raises:
            ValueError: If ``server_url`` does not contain ``<scheme>://<host>``.
        """
        # '<scheme>://<host>[/<path>]' splits into ['<scheme>:', '', '<host>', ...].
        # Anything after the host (trailing slash or path) is ignored instead
        # of breaking the unpacking.
        parts = self.server_url.split('/', 3)
        if len(parts) < 3:
            raise ValueError(f"Invalid server_url, expected '<scheme>://<host>': {self.server_url!r}")
        # NOTE(review): scheme keeps its trailing ':' exactly as before this
        # change - confirm that HTTPConnectionPool expects that form.
        scheme, host = parts[0], parts[2]
        timeout = 10
        self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout,
                                                   scheme=scheme)
0 commit comments