Skip to content

Commit d8c62c4

Browse files
authored
feat: support cohort targeting for local evaluation (#47)
1 parent 4503a60 commit d8c62c4

25 files changed

+1255
-70
lines changed

.github/workflows/test-arm.yml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@ on: [pull_request]
44
jobs:
55
aarch_job:
66
runs-on: ubuntu-latest
7-
name: Test on ubuntu aarch64
7+
environment: Unit Test
8+
name: Test on Ubuntu aarch64
89
steps:
9-
- uses: actions/checkout@v3
10-
- uses: uraimo/run-on-arch-action@v2
11-
name: Run Unit Test
10+
- name: Checkout source code
11+
uses: actions/checkout@v3
12+
13+
- name: Set up and run unit test on aarch64
14+
uses: uraimo/run-on-arch-action@v2
1215
id: runcmd
1316
with:
17+
env: |
18+
API_KEY: ${{ secrets.API_KEY }}
19+
SECRET_KEY: ${{ secrets.SECRET_KEY }}
20+
EU_API_KEY: ${{ secrets.EU_API_KEY }}
21+
EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }}
1422
arch: aarch64
1523
distro: ubuntu20.04
1624
githubToken: ${{ github.token }}

.github/workflows/test.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ on: [pull_request]
55
jobs:
66
test:
77
runs-on: ubuntu-latest
8+
environment: Unit Test
89
strategy:
910
matrix:
1011
python-version: [ "3.7" ]
@@ -19,8 +20,12 @@ jobs:
1920
cache: 'pip'
2021

2122
- name: Install requirements
22-
run: pip install -r requirements.txt
23-
pip install -r requirements-dev.txt
23+
run: pip install -r requirements.txt && pip install -r requirements-dev.txt
2424

2525
- name: Unit Test
26+
env:
27+
API_KEY: ${{ secrets.API_KEY }}
28+
SECRET_KEY: ${{ secrets.SECRET_KEY }}
29+
EU_API_KEY: ${{ secrets.EU_API_KEY }}
30+
EU_SECRET_KEY: ${{ secrets.EU_SECRET_KEY }}
2631
run: python -m unittest discover -s ./tests -p '*_test.py'

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ user = User(
7777
variants = experiment.evaluate(user)
7878
```
7979

80+
# Running the unit test suite
81+
To set up for running the tests locally, create a `.env` file with the following
82+
contents, and replace `{API_KEY}` and `{SECRET_KEY}` (or `{EU_API_KEY}` and `{EU_SECRET_KEY}` for the EU data center) with the values for the project under test:
83+
84+
```
85+
API_KEY={API_KEY}
86+
SECRET_KEY={SECRET_KEY}
87+
```
88+
8089
## More Information
8190
Please visit our :100:[Developer Center](https://www.docs.developers.amplitude.com/experiment/sdks/python-sdk/) for more instructions on using our SDK.
8291

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
parameterized~=0.9.0
2+
python-dotenv~=0.21.1

src/amplitude_experiment/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
from .cookie import AmplitudeCookie
1313
from .local.client import LocalEvaluationClient
1414
from .local.config import LocalEvaluationConfig
15+
from .local.config import ServerZone
1516
from .assignment import AssignmentConfig
17+
from .cohort.cohort_sync_config import CohortSyncConfig
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from dataclasses import dataclass, field
2+
from typing import ClassVar, Set
3+
4+
# Group type used for cohorts whose members are individual users. It is also
# the default ``group_type`` of a ``Cohort``.
# (``ClassVar`` is only meaningful inside a class body, so a plain ``str``
# annotation is used for this module-level constant.)
USER_GROUP_TYPE: str = "User"


@dataclass
class Cohort:
    """A cohort and the ids of its members.

    Attributes:
        id: Identifier of the cohort.
        last_modified: Last-modified marker of the cohort.
        size: Number of members reported for the cohort.
        member_ids: Ids of the members that belong to the cohort.
        group_type: Group type the members belong to; defaults to
            ``USER_GROUP_TYPE``.
    """

    id: str
    last_modified: int
    size: int
    member_ids: Set[str]
    group_type: str = USER_GROUP_TYPE
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import time
2+
import logging
3+
import base64
4+
import json
5+
from http.client import HTTPResponse
6+
from typing import Optional
7+
from ..version import __version__
8+
9+
from .cohort import Cohort
10+
from ..connection_pool import HTTPConnectionPool
11+
from ..exception import HTTPErrorResponseException, CohortTooLargeException
12+
13+
# Pause between successive cohort download attempts, in milliseconds.
COHORT_REQUEST_RETRY_DELAY_MILLIS = 100
14+
15+
16+
class CohortDownloadApi:
    """Interface for objects that can fetch a cohort by id."""

    def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:
        """Fetch the cohort identified by ``cohort_id``.

        Args:
            cohort_id: Id of the cohort to fetch.
            cohort: A previously obtained copy of the cohort, or ``None``.

        Raises:
            NotImplementedError: Always; subclasses provide the behavior.
        """
        raise NotImplementedError()
20+
21+
22+
class DirectCohortDownloadApi(CohortDownloadApi):
    """Downloads cohorts directly from the cohort server over HTTP.

    Requests are authenticated with HTTP Basic auth built from the API key and
    secret key, and are issued through a pooled connection to ``server_url``.
    """

    def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger):
        """Store the credentials/limits and create the connection pool.

        Args:
            api_key: Project API key (user part of the Basic auth credentials).
            secret_key: Project secret key (password part of the credentials).
            max_cohort_size: Sent to the server as ``maxCohortSize``.
            server_url: Base URL of the cohort server, e.g. ``https://host``.
            logger: Logger used for debug output.
        """
        super().__init__()
        self.api_key = api_key
        self.secret_key = secret_key
        self.max_cohort_size = max_cohort_size
        self.server_url = server_url
        self.logger = logger
        self.__setup_connection_pool()

    def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Optional[Cohort]:
        """Download a cohort, retrying until a final answer is obtained.

        Status handling:
            * 200 - the response body is parsed and returned as a ``Cohort``.
            * 204 - the cohort was not modified; ``None`` is returned.
            * 202 - the request is repeated after a short delay.
            * 413 - ``CohortTooLargeException`` is raised immediately.
            * anything else - treated as an error and retried.

        Every failure except an HTTP 429 response counts toward a limit of
        three errors, after which the last exception is re-raised. This
        includes failures that produce no response at all (e.g. connection
        errors), so an unreachable server cannot make this method loop forever.

        Args:
            cohort_id: Id of the cohort to download.
            cohort: The currently held copy, if any; its ``last_modified`` is
                sent as ``lastModified`` with the request.

        Returns:
            The downloaded ``Cohort``, or ``None`` on a 204 response.

        Raises:
            CohortTooLargeException: On a 413 response.
            Exception: The last error once three counted errors have occurred.
        """
        self.logger.debug(f"getCohortMembers({cohort_id}): start")
        last_modified = None if cohort is None else cohort.last_modified
        errors = 0
        while True:
            response = None
            try:
                response = self._get_cohort_members_request(cohort_id, last_modified)
                self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}")
                if response.status == 200:
                    cohort_info = json.loads(response.read().decode("utf8"))
                    self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}")
                    return Cohort(
                        id=cohort_info['cohortId'],
                        last_modified=cohort_info['lastModified'],
                        size=cohort_info['size'],
                        member_ids=set(cohort_info['memberIds']),
                        group_type=cohort_info['groupType'],
                    )
                elif response.status == 204:
                    self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified")
                    return None
                elif response.status == 413:
                    raise CohortTooLargeException(
                        f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}")
                elif response.status != 202:
                    raise HTTPErrorResponseException(response.status,
                                                     f"Unexpected response code: {response.status}")
            except Exception as e:
                # A 429 response is retried without consuming the error budget;
                # everything else (including failures before any response was
                # received) is counted.
                rate_limited = (
                    response is not None
                    and isinstance(e, HTTPErrorResponseException)
                    and response.status == 429
                )
                if not rate_limited:
                    errors += 1
                self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
                if errors >= 3 or isinstance(e, CohortTooLargeException):
                    raise
            # Reached after a 202 response or a retryable error.
            time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000)

    def _get_cohort_members_request(self, cohort_id: str, last_modified: Optional[int]) -> HTTPResponse:
        """Issue a single GET for the cohort and return the raw response.

        ``lastModified`` is only added to the query string when
        ``last_modified`` is not ``None``. The connection is always handed
        back to the pool, even if the request raises.
        """
        headers = {
            'Authorization': f'Basic {self._get_basic_auth()}',
            'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
        }
        conn = self._connection_pool.acquire()
        try:
            url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}'
            if last_modified is not None:
                url += f'&lastModified={last_modified}'
            return conn.request('GET', url, headers=headers)
        finally:
            self._connection_pool.release(conn)

    def _get_basic_auth(self) -> str:
        """Return the base64 encoding of ``api_key:secret_key``."""
        credentials = f'{self.api_key}:{self.secret_key}'
        return base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

    def __setup_connection_pool(self):
        """Create the connection pool for the host named in ``server_url``."""
        # "https://host[/path]" splits into ["https:", "", "host", ...]. Only
        # the first three parts are used, so a trailing slash or path no longer
        # breaks the unpacking. The scheme keeps its trailing colon, as before.
        scheme, _, host = self.server_url.split('/', 3)[:3]
        timeout = 10
        self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout,
                                                   scheme=scheme)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import logging
2+
from typing import Dict, Set
3+
from concurrent.futures import ThreadPoolExecutor, Future, as_completed
4+
import threading
5+
6+
from .cohort import Cohort
7+
from .cohort_download_api import CohortDownloadApi
8+
from .cohort_storage import CohortStorage
9+
from ..exception import CohortsDownloadException
10+
11+
12+
class CohortLoader:
    """Downloads cohorts on a thread pool and writes them to cohort storage.

    At most one download job per cohort id is tracked in ``jobs`` at a time;
    concurrent requests for the same id share that job's future.
    """

    def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage):
        """Create the loader.

        Args:
            cohort_download_api: Used to fetch cohorts.
            cohort_storage: Read for the current copy and updated with results.
        """
        self.cohort_download_api = cohort_download_api
        self.cohort_storage = cohort_storage
        self.jobs: Dict[str, Future] = {}
        self.lock_jobs = threading.Lock()
        self.executor = ThreadPoolExecutor(
            max_workers=32,
            thread_name_prefix='CohortLoaderExecutor'
        )

    def load_cohort(self, cohort_id: str) -> Future:
        """Return the future of the download job for ``cohort_id``.

        A new job is submitted only if none is currently tracked for the id.
        """
        new_job = None
        with self.lock_jobs:
            job = self.jobs.get(cohort_id)
            if job is None:
                job = new_job = self.executor.submit(self.__load_cohort_internal, cohort_id)
                self.jobs[cohort_id] = job
        if new_job is not None:
            # Register the cleanup callback only after the job is tracked and
            # the (non-reentrant) lock is released. If the job has already
            # finished, the callback runs immediately in this thread; doing it
            # in this order guarantees the entry is removed rather than a
            # completed future being left behind in ``jobs`` forever.
            new_job.add_done_callback(lambda f: self._remove_job(cohort_id))
        return job

    def _remove_job(self, cohort_id: str):
        """Stop tracking the job for ``cohort_id`` (no-op if not tracked)."""
        with self.lock_jobs:
            self.jobs.pop(cohort_id, None)

    def download_cohort(self, cohort_id: str) -> Cohort:
        """Fetch ``cohort_id`` via the download API, passing the stored copy."""
        cohort = self.cohort_storage.get_cohort(cohort_id)
        return self.cohort_download_api.get_cohort(cohort_id, cohort)

    def download_cohorts(self, cohort_ids: Set[str]) -> Future:
        """Load all ``cohort_ids`` and return a future for the combined result.

        The returned future completes once every cohort job has finished. If
        any job failed, the future raises ``CohortsDownloadException`` carrying
        a list of ``(cohort_id, exception)`` tuples.
        """
        def update_task(task_cohort_ids):
            # Remember which id each future belongs to. The entry in
            # ``self.jobs`` may already be gone by the time a failure is
            # observed, so it cannot be used to recover the id.
            future_to_id = {self.load_cohort(cohort_id): cohort_id for cohort_id in task_cohort_ids}

            errors = []
            for future in as_completed(future_to_id):
                try:
                    future.result()
                except Exception as e:
                    errors.append((future_to_id[future], e))

            if errors:
                raise CohortsDownloadException(errors)

        # NOTE(review): update_task runs on the same executor as the jobs it
        # waits for; confirm callers cannot saturate all workers with
        # update tasks.
        return self.executor.submit(update_task, cohort_ids)

    def __load_cohort_internal(self, cohort_id):
        """Download one cohort and store it; ``None`` results are not stored."""
        cohort = self.download_cohort(cohort_id)
        if cohort is not None:
            self.cohort_storage.put_cohort(cohort)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Dict, Set, Optional
2+
from threading import RLock
3+
4+
from .cohort import Cohort, USER_GROUP_TYPE
5+
6+
7+
class CohortStorage:
    """Interface for cohort stores; every method must be overridden."""

    def get_cohort(self, cohort_id: str):
        """Look up a single cohort by id."""
        raise NotImplementedError()

    def get_cohorts(self):
        """Return all stored cohorts."""
        raise NotImplementedError()

    def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids among ``cohort_ids`` that contain ``user_id``."""
        raise NotImplementedError()

    def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids among ``cohort_ids`` that contain ``group_name``."""
        raise NotImplementedError()

    def put_cohort(self, cohort_description: Cohort):
        """Store a cohort."""
        raise NotImplementedError()

    def delete_cohort(self, group_type: str, cohort_id: str):
        """Remove a cohort."""
        raise NotImplementedError()

    def get_cohort_ids(self) -> Set[str]:
        """Return the ids of all stored cohorts."""
        raise NotImplementedError()
28+
29+
30+
class InMemoryCohortStorage(CohortStorage):
    """Cohort storage backed by in-process dictionaries.

    ``cohort_store`` maps cohort id -> ``Cohort``; ``group_to_cohort_store``
    maps group type -> ids of the stored cohorts of that type. All access to
    either mapping happens while holding ``lock``.
    """

    def __init__(self):
        self.lock = RLock()
        self.group_to_cohort_store: Dict[str, Set[str]] = {}
        self.cohort_store: Dict[str, Cohort] = {}

    def get_cohort(self, cohort_id: str):
        """Return the stored cohort with ``cohort_id``, or ``None``."""
        with self.lock:
            return self.cohort_store.get(cohort_id)

    def get_cohorts(self):
        """Return a shallow copy of the id -> cohort mapping."""
        with self.lock:
            return self.cohort_store.copy()

    def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids among ``cohort_ids`` of user cohorts containing ``user_id``."""
        return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids)

    def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids among ``cohort_ids`` whose cohort contains ``group_name``.

        Only cohorts indexed under ``group_type`` are considered.
        """
        result = set()
        with self.lock:
            for cohort_id in self.group_to_cohort_store.get(group_type, ()):
                # Check the requested ids first so members are only inspected
                # for cohorts the caller asked about.
                if cohort_id not in cohort_ids:
                    continue
                cohort = self.cohort_store.get(cohort_id)
                # Tolerate an index entry without a stored cohort instead of
                # failing on ``None.member_ids``.
                if cohort is not None and group_name in cohort.member_ids:
                    result.add(cohort_id)
        return result

    def put_cohort(self, cohort: Cohort):
        """Store ``cohort``, replacing any cohort with the same id."""
        with self.lock:
            previous = self.cohort_store.get(cohort.id)
            if previous is not None and previous.group_type != cohort.group_type:
                # The cohort changed group type: drop it from the old index so
                # it no longer matches lookups for that type.
                self.group_to_cohort_store.get(previous.group_type, set()).discard(cohort.id)
            self.group_to_cohort_store.setdefault(cohort.group_type, set()).add(cohort.id)
            self.cohort_store[cohort.id] = cohort

    def delete_cohort(self, group_type: str, cohort_id: str):
        """Remove ``cohort_id`` from the store and from the group index."""
        with self.lock:
            group_cohorts = self.group_to_cohort_store.get(group_type)
            if group_cohorts is not None:
                group_cohorts.discard(cohort_id)
            removed = self.cohort_store.pop(cohort_id, None)
            if removed is not None and removed.group_type != group_type:
                # The caller passed a group type that differs from the stored
                # one; clean the real index entry as well.
                self.group_to_cohort_store.get(removed.group_type, set()).discard(cohort_id)

    def get_cohort_ids(self):
        """Return a new set with the ids of all stored cohorts."""
        with self.lock:
            return set(self.cohort_store.keys())
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com'
EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com'


class CohortSyncConfig:
    """Settings for cohort synchronization.

    These values configure the cohort loader, the component that downloads
    cohorts from the server and keeps a local copy of them.

    Parameters:
        api_key (str): API key of the project.
        secret_key (str): Secret key of the project.
        max_cohort_size (int): Largest cohort size that may be downloaded.
        cohort_polling_interval_millis (int): How often, in milliseconds, to
            poll for cohort updates. Values below 60000 are raised to 60000.
        cohort_server_url (str): Endpoint cohorts are requested from.
    """

    def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647,
                 cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
        # Enforce the 60-second floor on the polling interval.
        polling_interval = 60000 if cohort_polling_interval_millis < 60000 else cohort_polling_interval_millis

        self.api_key = api_key
        self.secret_key = secret_key
        self.max_cohort_size = max_cohort_size
        self.cohort_polling_interval_millis = polling_interval
        self.cohort_server_url = cohort_server_url

0 commit comments

Comments
 (0)