Skip to content
50 changes: 37 additions & 13 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
cast,
Any,
AsyncIterable,
Callable,
Optional,
Set,
Sequence,
Expand Down Expand Up @@ -97,18 +98,24 @@
)
from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
from google.cloud.bigtable.data._async._swappable_channel import (
AsyncSwappableChannel,
AsyncSwappableChannel as SwappableChannelType,
)
from google.cloud.bigtable.data._async.metrics_interceptor import (
AsyncBigtableMetricsInterceptor as MetricsInterceptorType,
)
else:
from typing import Iterable # noqa: F401
from grpc import insecure_channel
from grpc import intercept_channel
from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore
from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( # noqa: F401
SwappableChannel,
SwappableChannel as SwappableChannelType,
)
from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( # noqa: F401
BigtableMetricsInterceptor as MetricsInterceptorType,
)


if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
Expand Down Expand Up @@ -203,7 +210,7 @@ def __init__(
credentials = google.auth.credentials.AnonymousCredentials()
if project is None:
project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT

self._metrics_interceptor = MetricsInterceptorType()
# initialize client
ClientWithProject.__init__(
self,
Expand Down Expand Up @@ -257,12 +264,11 @@ def __init__(
stacklevel=2,
)

@CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannelType:
"""
This method is called by the gapic transport to create a grpc channel.

The init arguments passed down are captured in a partial used by AsyncSwappableChannel
The init arguments passed down are captured in a partial used by SwappableChannel
to create new channel instances in the future, as part of the channel refresh logic

Emulators always use an insecure channel
Expand All @@ -273,12 +279,30 @@ def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
Returns:
a custom wrapped swappable channel
"""
create_channel_fn: Callable[[], Channel]
if self._emulator_host is not None:
# emulators use insecure channel
# Emulators use insecure channels
create_channel_fn = partial(insecure_channel, self._emulator_host)
else:
elif CrossSync.is_async:
# For async client, use the default create_channel.
create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
return AsyncSwappableChannel(create_channel_fn)
else:
# For sync client, wrap create_channel with interceptors.
def sync_create_channel_fn():
return intercept_channel(
TransportType.create_channel(*args, **kwargs),
self._metrics_interceptor,
)

create_channel_fn = sync_create_channel_fn

# Instantiate SwappableChannelType with the determined creation function.
new_channel = SwappableChannelType(create_channel_fn)
if CrossSync.is_async:
# Attach async interceptors to the channel instance itself.
new_channel._unary_unary_interceptors.append(self._metrics_interceptor)
new_channel._unary_stream_interceptors.append(self._metrics_interceptor)
return new_channel

@property
def universe_domain(self) -> str:
Expand Down Expand Up @@ -400,7 +424,7 @@ def _invalidate_channel_stubs(self):
self.transport._stubs = {}
self.transport._prep_wrapped_messages(self.client_info)

@CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
@CrossSync.convert
async def _manage_channel(
self,
refresh_interval_min: float = 60 * 35,
Expand All @@ -425,10 +449,10 @@ async def _manage_channel(
grace_period: time to allow previous channel to serve existing
requests before closing, in seconds
"""
if not isinstance(self.transport.grpc_channel, AsyncSwappableChannel):
if not isinstance(self.transport.grpc_channel, SwappableChannelType):
warnings.warn("Channel does not support auto-refresh.")
return
super_channel: AsyncSwappableChannel = self.transport.grpc_channel
super_channel: SwappableChannelType = self.transport.grpc_channel
first_refresh = self._channel_init_time + random.uniform(
refresh_interval_min, refresh_interval_max
)
Expand Down
78 changes: 78 additions & 0 deletions google/cloud/bigtable/data/_async/metrics_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from google.cloud.bigtable.data._cross_sync import CrossSync

if CrossSync.is_async:
from grpc.aio import UnaryUnaryClientInterceptor
from grpc.aio import UnaryStreamClientInterceptor
else:
from grpc import UnaryUnaryClientInterceptor
from grpc import UnaryStreamClientInterceptor


__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.metrics_interceptor"


@CrossSync.convert_class(sync_name="BigtableMetricsInterceptor")
class AsyncBigtableMetricsInterceptor(
    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
):
    """
    Async client-side gRPC interceptor for the Bigtable data client.

    Registered on the client channel so every unary and streaming RPC passes
    through it; intended as the hook point for client-side metrics collection.
    The current implementation forwards calls unchanged — the try/except
    blocks are scaffolding marking where metrics/error handling will go.

    NOTE: the sync counterpart (BigtableMetricsInterceptor) is generated from
    this class by CrossSync (see __CROSS_SYNC_OUTPUT__); make behavior changes
    here, not in the generated file.
    """

    @CrossSync.convert
    async def intercept_unary_unary(self, continuation, client_call_details, request):
        """
        Interceptor for unary rpcs:
        - MutateRow
        - CheckAndMutateRow
        - ReadModifyWriteRow

        Args:
            continuation: callable that proceeds with the outgoing RPC
            client_call_details: details of the outgoing RPC (method, metadata, ...)
            request: the request message for the RPC
        Returns:
            the call object produced by the continuation, unchanged
        """
        try:
            call = await continuation(client_call_details, request)
            return call
        except Exception as rpc_error:
            # placeholder: error handling for failed unary calls will hook in here
            raise rpc_error

    @CrossSync.convert
    async def intercept_unary_stream(self, continuation, client_call_details, request):
        """
        Interceptor for streaming rpcs:
        - ReadRows
        - MutateRows
        - SampleRowKeys

        The stream returned by the continuation is wrapped in
        _streaming_generator_wrapper so per-response handling can be added
        without changing callers.
        """
        try:
            return self._streaming_generator_wrapper(
                await continuation(client_call_details, request)
            )
        except Exception as rpc_error:
            # handle errors while initializing stream
            raise rpc_error

    @staticmethod
    @CrossSync.convert
    async def _streaming_generator_wrapper(call):
        """
        Wrapped generator to be returned by intercept_unary_stream.

        Yields each response from the underlying call unchanged; the except
        block marks where per-stream error handling/metrics will be added.
        """
        try:
            async for response in call:
                yield response
        except Exception as e:
            # handle errors while processing stream
            raise e
30 changes: 23 additions & 7 deletions google/cloud/bigtable/data/_sync_autogen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# This file is automatically generated by CrossSync. Do not edit manually.

from __future__ import annotations
from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING
from typing import cast, Any, Callable, Optional, Set, Sequence, TYPE_CHECKING
import abc
import time
import warnings
Expand Down Expand Up @@ -75,12 +75,18 @@
from google.cloud.bigtable.data._cross_sync import CrossSync
from typing import Iterable
from grpc import insecure_channel
from grpc import intercept_channel
from google.cloud.bigtable_v2.services.bigtable.transports import (
BigtableGrpcTransport as TransportType,
)
from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
from google.cloud.bigtable.data._sync_autogen._swappable_channel import SwappableChannel
from google.cloud.bigtable.data._sync_autogen._swappable_channel import (
SwappableChannel as SwappableChannelType,
)
from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (
BigtableMetricsInterceptor as MetricsInterceptorType,
)

if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
Expand Down Expand Up @@ -143,6 +149,7 @@ def __init__(
credentials = google.auth.credentials.AnonymousCredentials()
if project is None:
project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT
self._metrics_interceptor = MetricsInterceptorType()
ClientWithProject.__init__(
self,
credentials=credentials,
Expand Down Expand Up @@ -186,7 +193,7 @@ def __init__(
stacklevel=2,
)

def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel:
def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannelType:
"""This method is called by the gapic transport to create a grpc channel.

The init arguments passed down are captured in a partial used by SwappableChannel
Expand All @@ -199,11 +206,20 @@ def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel:
- **kwargs: keyword arguments passed by the gapic layer to create a new channel with
Returns:
a custom wrapped swappable channel"""
create_channel_fn: Callable[[], Channel]
if self._emulator_host is not None:
create_channel_fn = partial(insecure_channel, self._emulator_host)
else:
create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
return SwappableChannel(create_channel_fn)

def sync_create_channel_fn():
return intercept_channel(
TransportType.create_channel(*args, **kwargs),
self._metrics_interceptor,
)

create_channel_fn = sync_create_channel_fn
new_channel = SwappableChannelType(create_channel_fn)
return new_channel

@property
def universe_domain(self) -> str:
Expand Down Expand Up @@ -324,10 +340,10 @@ def _manage_channel(
between `refresh_interval_min` and `refresh_interval_max`
grace_period: time to allow previous channel to serve existing
requests before closing, in seconds"""
if not isinstance(self.transport.grpc_channel, SwappableChannel):
if not isinstance(self.transport.grpc_channel, SwappableChannelType):
warnings.warn("Channel does not support auto-refresh.")
return
super_channel: SwappableChannel = self.transport.grpc_channel
super_channel: SwappableChannelType = self.transport.grpc_channel
first_refresh = self._channel_init_time + random.uniform(
refresh_interval_min, refresh_interval_max
)
Expand Down
59 changes: 59 additions & 0 deletions google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is automatically generated by CrossSync. Do not edit manually.

from __future__ import annotations
from grpc import UnaryUnaryClientInterceptor
from grpc import UnaryStreamClientInterceptor


class BigtableMetricsInterceptor(
    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
):
    """
    Sync client-side gRPC interceptor for the Bigtable data client.

    Auto-generated by CrossSync from AsyncBigtableMetricsInterceptor; make
    behavior changes in the async source class, not here. The current
    implementation forwards all calls unchanged — the try/except blocks are
    scaffolding marking where metrics/error handling will go.
    """

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Interceptor for unary rpcs:
        - MutateRow
        - CheckAndMutateRow
        - ReadModifyWriteRow

        Returns the call object produced by the continuation, unchanged."""
        try:
            call = continuation(client_call_details, request)
            return call
        except Exception as rpc_error:
            # placeholder: error handling for failed unary calls will hook in here
            raise rpc_error

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Interceptor for streaming rpcs:
        - ReadRows
        - MutateRows
        - SampleRowKeys

        The stream is wrapped in _streaming_generator_wrapper so per-response
        handling can be added without changing callers."""
        try:
            return self._streaming_generator_wrapper(
                continuation(client_call_details, request)
            )
        except Exception as rpc_error:
            # handle errors while initializing stream
            raise rpc_error

    @staticmethod
    def _streaming_generator_wrapper(call):
        """Wrapped generator to be returned by intercept_unary_stream.

        Yields each response unchanged; the except block marks where
        per-stream error handling/metrics will be added."""
        try:
            for response in call:
                yield response
        except Exception as e:
            # handle errors while processing stream
            raise e
19 changes: 12 additions & 7 deletions tests/system/data/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,28 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows):
async with client.get_table(instance_id, table_id) as table:
rows = await table.read_rows({})
channel_wrapper = client.transport.grpc_channel
first_channel = client.transport.grpc_channel._channel
first_channel = channel_wrapper._channel
assert len(rows) == 2
await CrossSync.sleep(2)
rows_after_refresh = await table.read_rows({})
assert len(rows_after_refresh) == 2
assert client.transport.grpc_channel is channel_wrapper
assert client.transport.grpc_channel._channel is not first_channel
# ensure gapic's logging interceptor is still active
updated_channel = channel_wrapper._channel
assert updated_channel is not first_channel
# ensure interceptors are kept (gapic's logging interceptor, and metric interceptor)
if CrossSync.is_async:
interceptors = (
client.transport.grpc_channel._channel._unary_unary_interceptors
)
assert GapicInterceptor in [type(i) for i in interceptors]
unary_interceptors = updated_channel._unary_unary_interceptors
assert len(unary_interceptors) == 2
assert GapicInterceptor in [type(i) for i in unary_interceptors]
assert client._metrics_interceptor in unary_interceptors
stream_interceptors = updated_channel._unary_stream_interceptors
assert len(stream_interceptors) == 1
assert client._metrics_interceptor in stream_interceptors
else:
assert isinstance(
client.transport._logged_channel._interceptor, GapicInterceptor
)
assert updated_channel._interceptor == client._metrics_interceptor
finally:
await client.close()

Expand Down
Loading
Loading