diff --git a/jbi/services/bugzilla.py b/jbi/services/bugzilla.py index 6d1b8a71..4358b5a2 100644 --- a/jbi/services/bugzilla.py +++ b/jbi/services/bugzilla.py @@ -17,7 +17,7 @@ BugzillaWebhooksResponse, ) -from .common import InstrumentedClient, ServiceHealth +from .common import ServiceHealth, instrument settings = environment.get_settings() @@ -28,6 +28,15 @@ class BugzillaClientError(Exception): """Errors raised by `BugzillaClient`.""" +instrumented_method = instrument( + prefix="bugzilla", + exceptions=( + BugzillaClientError, + requests.RequestException, + ), +) + + class BugzillaClient: """A wrapper around `requests` to interact with a Bugzilla REST API.""" @@ -57,6 +66,7 @@ def logged_in(self) -> bool: resp = self._call("GET", f"{self.base_url}/rest/whoami") return "id" in resp + @instrumented_method def get_bug(self, bugid) -> BugzillaBug: """Retrieve details about the specified bug id.""" # https://bugzilla.readthedocs.io/en/latest/api/core/v1/bug.html#rest-single-bug @@ -77,6 +87,7 @@ def get_bug(self, bugid) -> BugzillaBug: bug = bug.copy(update={"comment": found}) return bug + @instrumented_method def get_comments(self, bugid) -> list[BugzillaComment]: """Retrieve the list of comments of the specified bug id.""" # https://bugzilla.readthedocs.io/en/latest/api/core/v1/comment.html#rest-comments @@ -89,6 +100,7 @@ def get_comments(self, bugid) -> list[BugzillaComment]: ) return parse_obj_as(list[BugzillaComment], comments) + @instrumented_method def update_bug(self, bugid, **fields) -> BugzillaBug: """Update the specified fields of the specified bug.""" # https://bugzilla.readthedocs.io/en/latest/api/core/v1/bug.html#rest-update-bug @@ -101,6 +113,7 @@ def update_bug(self, bugid, **fields) -> BugzillaBug: ) return parsed.bugs[0] + @instrumented_method def list_webhooks(self): """List the currently configured webhooks, including their status.""" url = f"{self.base_url}/rest/webhooks/list" @@ -116,23 +129,9 @@ def list_webhooks(self): @lru_cache(maxsize=1) def get_client(): """Get bugzilla service""" - bugzilla_client = BugzillaClient( + return BugzillaClient( settings.bugzilla_base_url, api_key=str(settings.bugzilla_api_key) ) - return InstrumentedClient( - wrapped=bugzilla_client, - prefix="bugzilla", - methods=( - "get_bug", - "get_comments", - "update_bugs", - "list_webhooks", - ), - exceptions=( - BugzillaClientError, - requests.RequestException, - ), - ) def check_health() -> ServiceHealth: diff --git a/jbi/services/common.py b/jbi/services/common.py index 691a9685..67f9837f 100644 --- a/jbi/services/common.py +++ b/jbi/services/common.py @@ -4,6 +4,8 @@ InstrumentedClient: wraps service clients so that we can track their usage """ import logging +from functools import wraps +from typing import Sequence, Type import backoff from statsd.defaults.env import statsd @@ -18,33 +20,26 @@ ServiceHealth = dict[str, bool] -class InstrumentedClient: - """This class wraps an object and increments a counter every time - the specified methods are called, and times their execution. - It retries the methods if the specified exceptions are raised. +def instrument(prefix: str, exceptions: Sequence[Type[Exception]]): + """This decorator wraps a function such that it increments a counter every + time it is called and times its execution. It retries the function if the + specified exceptions are raised. """ - def __init__(self, wrapped, prefix, methods, exceptions): - self.wrapped = wrapped - self.prefix = prefix - self.methods = methods - self.exceptions = exceptions - - def __getattr__(self, attr): - if attr not in self.methods: - return getattr(self.wrapped, attr) - + def decorator(func): + @wraps(func) @backoff.on_exception( backoff.expo, - self.exceptions, + exceptions, max_tries=settings.max_retries + 1, ) - def wrapped_func(*args, **kwargs): + def wrapper(*args, **kwargs): # Increment the call counter. - statsd.incr(f"jbi.{self.prefix}.methods.{attr}.count") + statsd.incr(f"jbi.{prefix}.methods.{func.__name__}.count") # Time its execution. - with statsd.timer(f"jbi.{self.prefix}.methods.{attr}.timer"): - return getattr(self.wrapped, attr)(*args, **kwargs) + with statsd.timer(f"jbi.{prefix}.methods.{func.__name__}.timer"): + return func(*args, **kwargs) + + return wrapper - # The method was not called yet. - return wrapped_func + return decorator diff --git a/jbi/services/jira.py b/jbi/services/jira.py index 39e79ce2..8bd8b992 100644 --- a/jbi/services/jira.py +++ b/jbi/services/jira.py @@ -15,7 +15,7 @@ from jbi import Operation, environment from jbi.models import ActionContext, BugzillaBug -from .common import InstrumentedClient, ServiceHealth +from .common import ServiceHealth, instrument if TYPE_CHECKING: from jbi.models import Actions @@ -27,29 +27,30 @@ JIRA_DESCRIPTION_CHAR_LIMIT = 32767 +instrumented_method = instrument(prefix="jira", exceptions=(errors.ApiError,)) + + +class JiraClient(Jira): + """Adapted Atlassian Jira client that wraps methods in our instrumentation + decorator. + """ + + update_issue_field = instrumented_method(Jira.update_issue_field) + set_issue_status = instrumented_method(Jira.set_issue_status) + issue_add_comment = instrumented_method(Jira.issue_add_comment) + create_issue = instrumented_method(Jira.create_issue) + @lru_cache(maxsize=1) def get_client(): """Get atlassian Jira Service""" - jira_client = Jira( + return JiraClient( url=settings.jira_base_url, username=settings.jira_username, password=settings.jira_api_key, # package calls this param 'password' but actually expects an api key cloud=True, # we run against an instance of Jira cloud ) - return InstrumentedClient( - wrapped=jira_client, - prefix="jira", - methods=( - "update_issue_field", - "set_issue_status", - "issue_add_comment", - "create_issue", - ), - exceptions=(errors.ApiError,), - ) - def fetch_visible_projects() -> list[dict]: """Return list of projects that are visible with the configured Jira credentials""" diff --git a/tests/conftest.py b/tests/conftest.py index ade98d58..725eee0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,12 @@ ) +@pytest.fixture(autouse=True) +def mocked_statsd(): + with mock.patch("jbi.services.common.statsd") as _mocked_statsd: + yield _mocked_statsd + + @pytest.fixture def anon_client(): """A test client with no authorization.""" @@ -66,7 +72,7 @@ def mocked_jira(request): yield None jira.get_client.cache_clear() else: - with mock.patch("jbi.services.jira.Jira") as mocked_jira: + with mock.patch("jbi.services.jira.JiraClient") as mocked_jira: yield mocked_jira() jira.get_client.cache_clear() diff --git a/tests/unit/services/test_bugzilla.py b/tests/unit/services/test_bugzilla.py index f0968735..cf9cccfb 100644 --- a/tests/unit/services/test_bugzilla.py +++ b/tests/unit/services/test_bugzilla.py @@ -8,13 +8,18 @@ from jbi.services.bugzilla import BugzillaClientError, get_client -def test_timer_is_used_on_bugzilla_get_comments(): +@pytest.mark.no_mocked_bugzilla +def test_timer_is_used_on_bugzilla_get_comments(mocked_responses, mocked_statsd): bugzilla_client = get_client() - - with mock.patch("jbi.services.common.statsd") as mocked: - bugzilla_client.get_comments([]) - - mocked.timer.assert_called_with("jbi.bugzilla.methods.get_comments.timer") + mocked_responses.add( + "GET", + f"{get_settings().bugzilla_base_url}/rest/bug/42/comment", + json={ + "bugs": {"42": {"comments": []}}, + }, + ) + bugzilla_client.get_comments(42) + mocked_statsd.timer.assert_called_with("jbi.bugzilla.methods.get_comments.timer") @pytest.mark.no_mocked_bugzilla diff --git a/tests/unit/services/test_jira.py b/tests/unit/services/test_jira.py index 981cff2e..d55b9a22 100644 --- a/tests/unit/services/test_jira.py +++ b/tests/unit/services/test_jira.py @@ -1,5 +1,4 @@ import json -from unittest import mock import pytest import responses @@ -8,22 +7,48 @@ from jbi.services import jira -def test_counter_is_incremented_on_jira_create_issue(): - jira_client = jira.get_client() - - with mock.patch("jbi.services.common.statsd") as mocked: - jira_client.create_issue({}) - - mocked.incr.assert_called_with("jbi.jira.methods.create_issue.count") +@pytest.mark.no_mocked_jira +def test_jira_create_issue_is_instrumented( + mocked_responses, context_create_example, mocked_statsd +): + url = f"{get_settings().jira_base_url}rest/api/2/project/{context_create_example.jira.project}/components" + mocked_responses.add( + responses.GET, + url, + json=[ + { + "id": "10000", + "name": "Component 1", + }, + { + "id": "42", + "name": "Remote Settings", + }, + ], + ) + url = f"{get_settings().jira_base_url}rest/api/2/issue" + mocked_responses.add( + responses.POST, + url, + json={ + "id": "10000", + "key": "ED-24", + }, + ) -def test_timer_is_used_on_jira_create_issue(): + jira.create_jira_issue( + context_create_example, + "Description", + sync_whiteboard_labels=False, + components=["Remote Settings"], + ) jira_client = jira.get_client() - with mock.patch("jbi.services.common.statsd") as mocked: - jira_client.create_issue({}) + jira_client.create_issue({}) - mocked.timer.assert_called_with("jbi.jira.methods.create_issue.timer") + mocked_statsd.incr.assert_called_with("jbi.jira.methods.create_issue.count") + mocked_statsd.timer.assert_called_with("jbi.jira.methods.create_issue.timer") @pytest.mark.no_mocked_jira