Skip to content
8 changes: 7 additions & 1 deletion ads/aqua/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from logging import getLogger

from ads import logger, set_auth
from ads.aqua.client.client import AsyncClient, Client
from ads.aqua.client.client import (
AsyncClient,
Client,
HttpxOCIAuth,
get_async_httpx_client,
get_httpx_client,
)
from ads.aqua.common.utils import fetch_service_compartment
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION

Expand Down
59 changes: 48 additions & 11 deletions ads/aqua/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,23 @@
logger = logging.getLogger(__name__)


class OCIAuth(httpx.Auth):
class HttpxOCIAuth(httpx.Auth):
"""
Custom HTTPX authentication class that uses the OCI Signer for request signing.

Attributes:
signer (oci.signer.Signer): The OCI signer used to sign requests.
"""

def __init__(self, signer: oci.signer.Signer):
def __init__(self, signer: Optional[oci.signer.Signer] = None):
"""
Initialize the OCIAuth instance.
Initialize the HttpxOCIAuth instance.

Args:
signer (oci.signer.Signer): The OCI signer to use for signing requests.
"""
self.signer = signer

self.signer = signer or authutil.default_signer().get("signer")

def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
"""
Expand Down Expand Up @@ -256,7 +257,7 @@ def __init__(
auth = auth or authutil.default_signer()
if not callable(auth.get("signer")):
raise ValueError("Auth object must have a 'signer' callable attribute.")
self.auth = OCIAuth(auth["signer"])
self.auth = HttpxOCIAuth(auth["signer"])

logger.debug(
f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, "
Expand Down Expand Up @@ -352,7 +353,7 @@ def __init__(self, *args, **kwargs) -> None:
**kwargs: Keyword arguments forwarded to BaseClient.
"""
super().__init__(*args, **kwargs)
self._client = httpx.Client(timeout=self.timeout)
self._client = httpx.Client(timeout=self.timeout, auth=self.auth)

def is_closed(self) -> bool:
return self._client.is_closed
Expand Down Expand Up @@ -400,7 +401,6 @@ def _request(
response = self._client.post(
self.endpoint,
headers=self._prepare_headers(stream=False, headers=headers),
auth=self.auth,
json=payload,
)
logger.debug(f"Received response with status code: {response.status_code}")
Expand Down Expand Up @@ -447,7 +447,6 @@ def _stream(
"POST",
self.endpoint,
headers=self._prepare_headers(stream=True, headers=headers),
auth=self.auth,
json={**payload, "stream": True},
) as response:
try:
Expand Down Expand Up @@ -581,7 +580,7 @@ def __init__(self, *args, **kwargs) -> None:
**kwargs: Keyword arguments forwarded to BaseClient.
"""
super().__init__(*args, **kwargs)
self._client = httpx.AsyncClient(timeout=self.timeout)
self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth)

def is_closed(self) -> bool:
return self._client.is_closed
Expand Down Expand Up @@ -637,7 +636,6 @@ async def _request(
response = await self._client.post(
self.endpoint,
headers=self._prepare_headers(stream=False, headers=headers),
auth=self.auth,
json=payload,
)
logger.debug(f"Received response with status code: {response.status_code}")
Expand Down Expand Up @@ -683,7 +681,6 @@ async def _stream(
"POST",
self.endpoint,
headers=self._prepare_headers(stream=True, headers=headers),
auth=self.auth,
json={**payload, "stream": True},
) as response:
try:
Expand Down Expand Up @@ -797,3 +794,43 @@ async def embeddings(
logger.debug(f"Generating embeddings with input: {input}, payload: {payload}")
payload = {**(payload or {}), "input": input}
return await self._request(payload=payload, headers=headers)


def get_httpx_client(**kwargs: Any) -> httpx.Client:
"""
Creates and returns a synchronous httpx Client configured with OCI authentication signer based
the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html

Parameters
----------
**kwargs : Any
Keyword arguments supported by httpx.Client

Returns
-------
Client
A configured synchronous httpx Client instance.
"""
kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
return httpx.Client(**kwargs)


def get_async_httpx_client(**kwargs: Any) -> httpx.AsyncClient:
"""
Creates and returns a synchronous httpx Client configured with OCI authentication signer based
the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html

Parameters
----------
**kwargs : Any
Keyword arguments supported by httpx.Client

Returns
-------
AsyncClient
A configured asynchronous httpx AsyncClient instance.
"""
kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
return httpx.AsyncClient(**kwargs)
31 changes: 20 additions & 11 deletions ads/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,15 @@ def set_auth(
auth: Optional[str] = AuthType.API_KEY,
oci_config_location: Optional[str] = DEFAULT_LOCATION,
profile: Optional[str] = DEFAULT_PROFILE,
config: Optional[Dict] = {"region": os.environ["OCI_RESOURCE_REGION"]}
if os.environ.get("OCI_RESOURCE_REGION")
else {},
config: Optional[Dict] = (
{"region": os.environ["OCI_RESOURCE_REGION"]}
if os.environ.get("OCI_RESOURCE_REGION")
else {}
),
signer: Optional[Any] = None,
signer_callable: Optional[Callable] = None,
signer_kwargs: Optional[Dict] = {},
client_kwargs: Optional[Dict] = {},
signer_kwargs: Optional[Dict] = None,
client_kwargs: Optional[Dict] = None,
) -> None:
"""
Sets the default authentication type.
Expand Down Expand Up @@ -195,6 +197,9 @@ def set_auth(
>>> # instance principals authentication dictionary created based on callable with kwargs parameters:
>>> ads.set_auth(signer_callable=signer_callable, signer_kwargs=signer_kwargs)
"""
signer_kwargs = signer_kwargs or {}
client_kwargs = client_kwargs or {}

auth_state = AuthState()

valid_auth_keys = AuthFactory.classes.keys()
Expand Down Expand Up @@ -258,9 +263,11 @@ def api_keys(
"""
signer_args = dict(
oci_config=oci_config if isinstance(oci_config, Dict) else {},
oci_config_location=oci_config
if isinstance(oci_config, str)
else os.path.expanduser(DEFAULT_LOCATION),
oci_config_location=(
oci_config
if isinstance(oci_config, str)
else os.path.expanduser(DEFAULT_LOCATION)
),
oci_key_profile=profile,
client_kwargs=client_kwargs,
)
Expand Down Expand Up @@ -334,9 +341,11 @@ def security_token(
"""
signer_args = dict(
oci_config=oci_config if isinstance(oci_config, Dict) else {},
oci_config_location=oci_config
if isinstance(oci_config, str)
else os.path.expanduser(DEFAULT_LOCATION),
oci_config_location=(
oci_config
if isinstance(oci_config, str)
else os.path.expanduser(DEFAULT_LOCATION)
),
oci_key_profile=profile,
client_kwargs=client_kwargs,
)
Expand Down
42 changes: 42 additions & 0 deletions docs/source/user_guide/large_language_model/aqua_client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,45 @@ The following examples demonstrate how to perform the same operations using the
input=["one", "two"]
)
print(response)


HTTPX Client Integration with OCI Authentication
================================================

.. versionadded:: 2.13.1

The latest client release now includes streamlined support for OCI authentication with HTTPX. Our helper functions for creating synchronous and asynchronous HTTPX clients automatically configure authentication based on your default settings. Additionally, you can pass extra keyword arguments to further customize the HTTPX client (e.g., timeouts, proxies, etc.), making it fully compatible with OCI Model Deployment service and third-party libraries (e.g., the OpenAI client).

Usage
-----

**Synchronous HTTPX Client**

.. code-block:: python3

import ads

ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")

client = ads.aqua.get_httpx_client(timeout=10.0)

response = client.post(
url="https://<MD_OCID>/predict",
json={
"model": "odsc-llm",
"prompt": "Tell me a joke."
},
)

response.raise_for_status()
json_response = response.json()

**Asynchronous HTTPX Client**

.. code-block:: python3

import ads

ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")

async_client = client = ads.aqua.get_async_httpx_client(timeout=10.0)
1 change: 0 additions & 1 deletion tests/unitary/default_setup/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
from mock import MagicMock
import pytest
from unittest import TestCase, mock

Expand Down
10 changes: 5 additions & 5 deletions tests/unitary/with_extras/aqua/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
BaseClient,
Client,
ExtendedRequestError,
OCIAuth,
HttpxOCIAuth,
_create_retry_decorator,
_retry_decorator,
_should_retry_exception,
)
from ads.common import auth as authutil


class TestOCIAuth:
"""Unit tests for OCIAuth class."""
class TestHttpxOCIAuth:
"""Unit tests for HttpxOCIAuth class."""

def setup_method(self):
self.signer_mock = Mock()
self.oci_auth = OCIAuth(self.signer_mock)
self.oci_auth = HttpxOCIAuth(self.signer_mock)

def test_auth_flow(self):
"""Ensures that the auth_flow signs the request correctly."""
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_init(self):
assert self.base_client.retries == self.retries
assert self.base_client.backoff_factor == self.backoff_factor
assert self.base_client.timeout == self.timeout
assert isinstance(self.base_client.auth, OCIAuth)
assert isinstance(self.base_client.auth, HttpxOCIAuth)

def test_init_default_auth(self):
"""Ensures that default auth is used when auth is None."""
Expand Down