diff --git a/ads/aqua/__init__.py b/ads/aqua/__init__.py index afb4f65a0..f2ba2c38b 100644 --- a/ads/aqua/__init__.py +++ b/ads/aqua/__init__.py @@ -7,7 +7,13 @@ from logging import getLogger from ads import logger, set_auth -from ads.aqua.client.client import AsyncClient, Client +from ads.aqua.client.client import ( + AsyncClient, + Client, + HttpxOCIAuth, + get_async_httpx_client, + get_httpx_client, +) from ads.aqua.common.utils import fetch_service_compartment from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION diff --git a/ads/aqua/client/client.py b/ads/aqua/client/client.py index 9ffb12a21..841779aaa 100644 --- a/ads/aqua/client/client.py +++ b/ads/aqua/client/client.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) -class OCIAuth(httpx.Auth): +class HttpxOCIAuth(httpx.Auth): """ Custom HTTPX authentication class that uses the OCI Signer for request signing. @@ -59,14 +59,15 @@ class OCIAuth(httpx.Auth): signer (oci.signer.Signer): The OCI signer used to sign requests. """ - def __init__(self, signer: oci.signer.Signer): + def __init__(self, signer: Optional[oci.signer.Signer] = None): """ - Initialize the OCIAuth instance. + Initialize the HttpxOCIAuth instance. Args: signer (oci.signer.Signer): The OCI signer to use for signing requests. """ - self.signer = signer + + self.signer = signer or authutil.default_signer().get("signer") def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]: """ @@ -256,7 +257,7 @@ def __init__( auth = auth or authutil.default_signer() if not callable(auth.get("signer")): raise ValueError("Auth object must have a 'signer' callable attribute.") - self.auth = OCIAuth(auth["signer"]) + self.auth = HttpxOCIAuth(auth["signer"]) logger.debug( f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, " @@ -352,7 +353,7 @@ def __init__(self, *args, **kwargs) -> None: **kwargs: Keyword arguments forwarded to BaseClient. """ super().__init__(*args, **kwargs) - self._client = httpx.Client(timeout=self.timeout) + self._client = httpx.Client(timeout=self.timeout, auth=self.auth) def is_closed(self) -> bool: return self._client.is_closed @@ -400,7 +401,6 @@ def _request( response = self._client.post( self.endpoint, headers=self._prepare_headers(stream=False, headers=headers), - auth=self.auth, json=payload, ) logger.debug(f"Received response with status code: {response.status_code}") @@ -447,7 +447,6 @@ def _stream( "POST", self.endpoint, headers=self._prepare_headers(stream=True, headers=headers), - auth=self.auth, json={**payload, "stream": True}, ) as response: try: @@ -581,7 +580,7 @@ def __init__(self, *args, **kwargs) -> None: **kwargs: Keyword arguments forwarded to BaseClient. """ super().__init__(*args, **kwargs) - self._client = httpx.AsyncClient(timeout=self.timeout) + self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth) def is_closed(self) -> bool: return self._client.is_closed @@ -637,7 +636,6 @@ async def _request( response = await self._client.post( self.endpoint, headers=self._prepare_headers(stream=False, headers=headers), - auth=self.auth, json=payload, ) logger.debug(f"Received response with status code: {response.status_code}") @@ -683,7 +681,6 @@ async def _stream( "POST", self.endpoint, headers=self._prepare_headers(stream=True, headers=headers), - auth=self.auth, json={**payload, "stream": True}, ) as response: try: @@ -797,3 +794,43 @@ async def embeddings( logger.debug(f"Generating embeddings with input: {input}, payload: {payload}") payload = {**(payload or {}), "input": input} return await self._request(payload=payload, headers=headers) + + +def get_httpx_client(**kwargs: Any) -> httpx.Client: + """ + Creates and returns a synchronous httpx Client configured with OCI authentication signer based + the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE. + More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html + + Parameters + ---------- + **kwargs : Any + Keyword arguments supported by httpx.Client + + Returns + ------- + Client + A configured synchronous httpx Client instance. + """ + kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth() + return httpx.Client(**kwargs) + + +def get_async_httpx_client(**kwargs: Any) -> httpx.AsyncClient: + """ + Creates and returns a synchronous httpx Client configured with OCI authentication signer based + the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE. + More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html + + Parameters + ---------- + **kwargs : Any + Keyword arguments supported by httpx.Client + + Returns + ------- + AsyncClient + A configured asynchronous httpx AsyncClient instance. + """ + kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth() + return httpx.AsyncClient(**kwargs) diff --git a/ads/common/auth.py b/ads/common/auth.py index 6ab7712de..b42dc89f4 100644 --- a/ads/common/auth.py +++ b/ads/common/auth.py @@ -88,13 +88,15 @@ def set_auth( auth: Optional[str] = AuthType.API_KEY, oci_config_location: Optional[str] = DEFAULT_LOCATION, profile: Optional[str] = DEFAULT_PROFILE, - config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]} - if os.environ.get("OCI_RESOURCE_REGION") - else {}, + config: Optional[Dict] = ( + {"region": os.environ["OCI_RESOURCE_REGION"]} + if os.environ.get("OCI_RESOURCE_REGION") + else {} + ), signer: Optional[Any] = None, signer_callable: Optional[Callable] = None, - signer_kwargs: Optional[Dict] = {}, - client_kwargs: Optional[Dict] = {}, + signer_kwargs: Optional[Dict] = None, + client_kwargs: Optional[Dict] = None, ) -> None: """ Sets the default authentication type. @@ -195,6 +197,9 @@ def set_auth( >>> # instance principals authentication dictionary created based on callable with kwargs parameters: >>> ads.set_auth(signer_callable=signer_callable, signer_kwargs=signer_kwargs) """ + signer_kwargs = signer_kwargs or {} + client_kwargs = client_kwargs or {} + auth_state = AuthState() valid_auth_keys = AuthFactory.classes.keys() @@ -258,9 +263,11 @@ def api_keys( """ signer_args = dict( oci_config=oci_config if isinstance(oci_config, Dict) else {}, - oci_config_location=oci_config - if isinstance(oci_config, str) - else os.path.expanduser(DEFAULT_LOCATION), + oci_config_location=( + oci_config + if isinstance(oci_config, str) + else os.path.expanduser(DEFAULT_LOCATION) + ), oci_key_profile=profile, client_kwargs=client_kwargs, ) @@ -334,9 +341,11 @@ def security_token( """ signer_args = dict( oci_config=oci_config if isinstance(oci_config, Dict) else {}, - oci_config_location=oci_config - if isinstance(oci_config, str) - else os.path.expanduser(DEFAULT_LOCATION), + oci_config_location=( + oci_config + if isinstance(oci_config, str) + else os.path.expanduser(DEFAULT_LOCATION) + ), oci_key_profile=profile, client_kwargs=client_kwargs, ) diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst index ad3d7038e..45bf40578 100644 --- a/docs/source/user_guide/large_language_model/aqua_client.rst +++ b/docs/source/user_guide/large_language_model/aqua_client.rst @@ -129,3 +129,45 @@ The following examples demonstrate how to perform the same operations using the input=["one", "two"] ) print(response) + + +HTTPX Client Integration with OCI Authentication +================================================ + +.. versionadded:: 2.13.1 + +The latest client release now includes streamlined support for OCI authentication with HTTPX. Our helper functions for creating synchronous and asynchronous HTTPX clients automatically configure authentication based on your default settings. Additionally, you can pass extra keyword arguments to further customize the HTTPX client (e.g., timeouts, proxies, etc.), making it fully compatible with OCI Model Deployment service and third-party libraries (e.g., the OpenAI client). + +Usage +----- + +**Synchronous HTTPX Client** + +.. code-block:: python3 + + import ads + + ads.set_auth(auth="security_token", profile="") + + client = ads.aqua.get_httpx_client(timeout=10.0) + + response = client.post( + url="https:///predict", + json={ + "model": "odsc-llm", + "prompt": "Tell me a joke." + }, + ) + + response.raise_for_status() + json_response = response.json() + +**Asynchronous HTTPX Client** + +.. code-block:: python3 + + import ads + + ads.set_auth(auth="security_token", profile="") + + async_client = client = ads.aqua.get_async_httpx_client(timeout=10.0) diff --git a/tests/unitary/default_setup/auth/test_auth.py b/tests/unitary/default_setup/auth/test_auth.py index b9f695b2c..2c00b48c7 100644 --- a/tests/unitary/default_setup/auth/test_auth.py +++ b/tests/unitary/default_setup/auth/test_auth.py @@ -4,7 +4,6 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os -from mock import MagicMock import pytest from unittest import TestCase, mock diff --git a/tests/unitary/with_extras/aqua/test_client.py b/tests/unitary/with_extras/aqua/test_client.py index 016f92faa..197ae4e2a 100644 --- a/tests/unitary/with_extras/aqua/test_client.py +++ b/tests/unitary/with_extras/aqua/test_client.py @@ -15,7 +15,7 @@ BaseClient, Client, ExtendedRequestError, - OCIAuth, + HttpxOCIAuth, _create_retry_decorator, _retry_decorator, _should_retry_exception, @@ -23,12 +23,12 @@ from ads.common import auth as authutil -class TestOCIAuth: - """Unit tests for OCIAuth class.""" +class TestHttpxOCIAuth: + """Unit tests for HttpxOCIAuth class.""" def setup_method(self): self.signer_mock = Mock() - self.oci_auth = OCIAuth(self.signer_mock) + self.oci_auth = HttpxOCIAuth(self.signer_mock) def test_auth_flow(self): """Ensures that the auth_flow signs the request correctly.""" @@ -226,7 +226,7 @@ def test_init(self): assert self.base_client.retries == self.retries assert self.base_client.backoff_factor == self.backoff_factor assert self.base_client.timeout == self.timeout - assert isinstance(self.base_client.auth, OCIAuth) + assert isinstance(self.base_client.auth, HttpxOCIAuth) def test_init_default_auth(self): """Ensures that default auth is used when auth is None."""