From d2175f1849c7d38f13aab6ea649b7739cc269f3f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Jul 2025 16:12:17 -0700 Subject: [PATCH 01/60] use replaceable channel wrapper --- google/cloud/bigtable/data/_async/client.py | 53 ++++++----- .../data/_async/replaceable_channel.py | 87 +++++++++++++++++++ 2 files changed, 113 insertions(+), 27 deletions(-) create mode 100644 google/cloud/bigtable/data/_async/replaceable_channel.py diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6ee21b554..2c6cd207d 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -92,7 +92,11 @@ from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcAsyncIOTransport as TransportType, ) + from google.cloud.bigtable_v2.services.bigtable.transports.grpc_asyncio import ( + _LoggingClientAIOInterceptor + ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE + from google.cloud.bigtable.data._async.replaceable_channel import _AsyncReplaceableChannel else: from typing import Iterable # noqa: F401 from grpc import insecure_channel @@ -182,7 +186,6 @@ def __init__( client_options = cast( Optional[client_options_lib.ClientOptions], client_options ) - custom_channel = None self._emulator_host = os.getenv(BIGTABLE_EMULATOR) if self._emulator_host is not None: warnings.warn( @@ -191,11 +194,11 @@ def __init__( stacklevel=2, ) # use insecure channel if emulator is set - custom_channel = insecure_channel(self._emulator_host) if credentials is None: credentials = google.auth.credentials.AnonymousCredentials() if project is None: project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + # initialize client ClientWithProject.__init__( self, @@ -208,7 +211,7 @@ def __init__( client_options=client_options, client_info=self.client_info, transport=lambda *args, **kwargs: TransportType( - *args, **kwargs, channel=custom_channel + *args, **kwargs, channel=self._build_grpc_channel ), ) self._is_closed = CrossSync.Event() @@ -235,6 +238,23 @@ def __init__( stacklevel=2, ) + def _build_grpc_channel(self, *args, **kwargs): + if self._emulator_host is not None: + # emulators use insecure channel + return insecure_channel(self._emulator_host) + create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) + # interceptors are handled differently between async and sync, because of differences in the grpc and gapic layers + if CrossSync.is_async: + # for async, add interceptors to the creation function + create_channel_fn = partial(create_channel_fn, interceptors=[_LoggingClientAIOInterceptor()]) + return _AsyncReplaceableChannel(create_channel_fn) + else: + # for sync, chain interceptors using grpc.channel.intercept + # LoggingClientInterceptor not needed, since it is chained in the gapic layer + return TransportType.create_channel(*args, **kwargs) + + + @staticmethod def _client_version() -> str: """ @@ -376,32 +396,11 @@ async def _manage_channel( break start_timestamp = time.monotonic() # prepare new channel for use - # TODO: refactor to avoid using internal references: https://github.com/googleapis/python-bigtable/issues/1094 - old_channel = self.transport.grpc_channel - new_channel = self.transport.create_channel() - if CrossSync.is_async: - new_channel._unary_unary_interceptors.append( - self.transport._interceptor - ) - else: - new_channel = intercept_channel( - new_channel, self.transport._interceptor - ) + new_channel = self.transport.grpc_channel.create_channel() 
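        # (aside: the wrapper builds the replacement from the same factory
        # and config as the live channel; the steps below are: warm the new
        # channel, swap it into the wrapper, then drain and close the old one)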
await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure - self.transport._grpc_channel = new_channel - self.transport._logged_channel = new_channel - # invalidate caches - self.transport._stubs = {} - self.transport._prep_wrapped_messages(self.client_info) - # give old_channel a chance to complete existing rpcs - if CrossSync.is_async: - await old_channel.close(grace_period) - else: - if grace_period: - self._is_closed.wait(grace_period) # type: ignore - old_channel.close() # type: ignore - # subtract thed time spent waiting for the channel to be replaced + await self.transport.grpc_channel.replace_channel(new_channel, grace_period) + # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/google/cloud/bigtable/data/_async/replaceable_channel.py b/google/cloud/bigtable/data/_async/replaceable_channel.py new file mode 100644 index 000000000..2105cfd72 --- /dev/null +++ b/google/cloud/bigtable/data/_async/replaceable_channel.py @@ -0,0 +1,87 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +# +from __future__ import annotations + +from typing import Callable + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.bigtable.data._cross_sync import CrossSync + +class _AsyncReplaceableChannel(aio.Channel): + """ + A wrapper around a gRPC channel. All methods are passed + through to the underlying channel. 
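+    The wrapped channel can be rebuilt from its stored factory via
+    create_channel() and swapped out in place via replace_channel().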
+ """ + + def __init__(self, channel_fn: Callable[[], aio.Channel]): + self._channel_fn = channel_fn + self._channel = channel_fn() + + def unary_unary(self, *args, **kwargs): + return self._channel.unary_unary(*args, **kwargs) + + def unary_stream(self, *args, **kwargs): + return self._channel.unary_stream(*args, **kwargs) + + def stream_unary(self, *args, **kwargs): + return self._channel.stream_unary(*args, **kwargs) + + def stream_stream(self, *args, **kwargs): + return self._channel.stream_stream(*args, **kwargs) + + async def close(self, grace=None): + return await self._channel.close(grace=grace) + + async def channel_ready(self): + return await self._channel.channel_ready() + + async def __aenter__(self): + await self._channel.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self._channel.__aexit__(exc_type, exc_val, exc_tb) + + def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + return self._channel.get_state(try_to_connect=try_to_connect) + + async def wait_for_state_change(self, last_observed_state): + return await self._channel.wait_for_state_change(last_observed_state) + + @property + def wrapped_channel(self): + return self._channel + + def create_channel(self) -> aio.Channel: + return self._channel_fn() + + async def replace_channel(self, new_channel: aio.Channel, grace_period: float | None) -> aio.Channel: + old_channel = self._channel + self._channel = new_channel + # give old_channel a chance to complete existing rpcs + if CrossSync.is_async: + await old_channel.close(grace_period) + else: + if grace_period: + self._is_closed.wait(grace_period) # type: ignore + old_channel.close() # type: ignore + return old_channel + + @property + def _unary_unary_interceptors(self): + # return empty list for compatibility with gapic layer + return [] \ No newline at end of file From 5e107fcbce205bccbabe2aca353269ee9eb91da1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Jul 2025 18:55:11 -0700 Subject: [PATCH 02/60] got unit tests working --- google/cloud/bigtable/data/_async/client.py | 25 +++++++------- .../data/_async/replaceable_channel.py | 20 ++++++----- .../bigtable/transports/grpc_asyncio.py | 1 - tests/system/data/test_system_async.py | 7 ++-- tests/unit/data/_async/test_client.py | 34 ++++++++----------- 5 files changed, 42 insertions(+), 45 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 2c6cd207d..6635acdf1 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -238,19 +238,15 @@ def __init__( stacklevel=2, ) - def _build_grpc_channel(self, *args, **kwargs): + def _build_grpc_channel(self, *args, **kwargs) -> _AsyncReplaceableChannel: if self._emulator_host is not None: # emulators use insecure channel - return insecure_channel(self._emulator_host) - create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) - # interceptors are handled differently between async and sync, because of differences in the grpc and gapic layers + create_channel_fn = partial(insecure_channel, self._emulator_host) + else: + create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) if CrossSync.is_async: - # for async, add interceptors to the creation function - create_channel_fn = partial(create_channel_fn, interceptors=[_LoggingClientAIOInterceptor()]) return _AsyncReplaceableChannel(create_channel_fn) else: - # for sync, chain interceptors using 
grpc.channel.intercept - # LoggingClientInterceptor not needed, since it is chained in the gapic layer return TransportType.create_channel(*args, **kwargs) @@ -377,13 +373,18 @@ async def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ + channel = self.transport.grpc_channel + if not isinstance(self.transport.grpc_channel, _AsyncReplaceableChannel): + warnings.warn("Channel does not support auto-refresh.") + return + channel: _AsyncReplaceableChannel = self.transport.grpc_channel first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: # warm the current channel immediately - await self._ping_and_warm_instances(channel=self.transport.grpc_channel) + await self._ping_and_warm_instances(channel=channel) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): await CrossSync.event_wait( @@ -396,10 +397,10 @@ async def _manage_channel( break start_timestamp = time.monotonic() # prepare new channel for use - new_channel = self.transport.grpc_channel.create_channel() - await self._ping_and_warm_instances(channel=new_channel) + new_sub_channel = channel.create_channel() + await self._ping_and_warm_instances(channel=new_sub_channel) # cycle channel out of use, with long grace window before closure - await self.transport.grpc_channel.replace_channel(new_channel, grace_period) + await channel.replace_wrapped_channel(new_sub_channel, grace_period) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/google/cloud/bigtable/data/_async/replaceable_channel.py b/google/cloud/bigtable/data/_async/replaceable_channel.py index 2105cfd72..81e9364a1 100644 --- a/google/cloud/bigtable/data/_async/replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/replaceable_channel.py @@ -62,15 +62,19 @@ def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: async def wait_for_state_change(self, last_observed_state): return await self._channel.wait_for_state_change(last_observed_state) - @property - def wrapped_channel(self): - return self._channel - def create_channel(self) -> aio.Channel: return self._channel_fn() - async def replace_channel(self, new_channel: aio.Channel, grace_period: float | None) -> aio.Channel: + async def replace_wrapped_channel(self, new_channel: aio.Channel, grace_period: float | None, copy_async_interceptors: bool=True) -> aio.Channel: old_channel = self._channel + if CrossSync.is_async and copy_async_interceptors: + # copy over interceptors + # this is needed because of how gapic attaches the LoggingClientAIOInterceptor + # sync channels add interceptors by wrapping, so this step isn't needed + new_channel._unary_unary_interceptors = old_channel._unary_unary_interceptors + new_channel._unary_stream_interceptors = old_channel._unary_stream_interceptors + new_channel._stream_unary_interceptors = old_channel._stream_unary_interceptors + new_channel._stream_stream_interceptors = old_channel._stream_stream_interceptors self._channel = new_channel # give old_channel a chance to complete existing rpcs if CrossSync.is_async: @@ -81,7 +85,5 @@ async def replace_channel(self, new_channel: aio.Channel, grace_period: float | old_channel.close() # type: ignore return 
old_channel - @property - def _unary_unary_interceptors(self): - # return empty list for compatibility with gapic layer - return [] \ No newline at end of file + def __getattr__(self, name): + return getattr(self._channel, name) \ No newline at end of file diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py index cebee0208..5572b2105 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py @@ -290,7 +290,6 @@ def __init__( always_use_jwt_access=always_use_jwt_access, api_audience=api_audience, ) - if not self._grpc_channel: # initialize with the provided callable or the default channel channel_init = channel or type(self).create_channel diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index b59131414..9f4fa7abb 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -252,13 +252,14 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows): await CrossSync.yield_to_event_loop() async with client.get_table(instance_id, table_id) as table: rows = await table.read_rows({}) - first_channel = client.transport.grpc_channel + channel_wrapper = client.transport.grpc_channel + first_channel = client.transport.grpc_channel._channel assert len(rows) == 2 await CrossSync.sleep(2) rows_after_refresh = await table.read_rows({}) assert len(rows_after_refresh) == 2 - assert client.transport.grpc_channel is not first_channel - print(table) + assert client.transport.grpc_channel is channel_wrapper + assert client.transport.grpc_channel._channel is not first_channel finally: await client.close() diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 5e7302d75..17d3b7ac0 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -223,7 +223,7 @@ async def test__start_background_channel_refresh_task_exists(self): @CrossSync.pytest async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) with mock.patch.object( client, "_ping_and_warm_instances", CrossSync.Mock() ) as ping_and_warm: @@ -366,7 +366,7 @@ async def test__manage_channel_first_sleep( with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) client._channel_init_time = -wait_time await client._manage_channel(refresh_interval, refresh_interval) except asyncio.CancelledError: @@ -395,32 +395,29 @@ async def test__manage_channel_ping_and_warm(self): _LoggingClientInterceptor as Interceptor, ) - client_mock = mock.Mock() - client_mock.transport._interceptor = Interceptor() - client_mock._is_closed.is_set.return_value = False - client_mock._channel_init_time = time.monotonic() - orig_channel = client_mock.transport.grpc_channel + client = self._make_client(project="project-id", use_emulator=True) + orig_channel = client.transport.grpc_channel # should ping an warm all new channels, and old channels if sleeping sleep_tuple = ( (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") ) - with mock.patch.object(*sleep_tuple): - 
# stop process after close is called - orig_channel.close.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = CrossSync.Mock() + with mock.patch.object(*sleep_tuple) as sleep_mock: + # stop process after loop + sleep_mock.side_effect = [None, asyncio.CancelledError] + ping_and_warm = client._ping_and_warm_instances = CrossSync.Mock() # should ping and warm old channel then new if sleep > 0 try: - await self._get_target_class()._manage_channel(client_mock, 10) + await client._manage_channel(10) except asyncio.CancelledError: pass # should have called at loop start, and after replacement assert ping_and_warm.call_count == 2 # should have replaced channel once - assert client_mock.transport._grpc_channel != orig_channel + assert client.transport.grpc_channel._channel != orig_channel # make sure new and old channels were warmed called_with = [call[1]["channel"] for call in ping_and_warm.call_args_list] assert orig_channel in called_with - assert client_mock.transport.grpc_channel in called_with + assert client.transport.grpc_channel._channel in called_with @CrossSync.pytest @pytest.mark.parametrize( @@ -438,8 +435,6 @@ async def test__manage_channel_sleeps( import time import random - channel = mock.Mock() - channel.close = CrossSync.Mock() with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: @@ -448,8 +443,7 @@ async def test__manage_channel_sleeps( sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] - client = self._make_client(project="project-id") - client.transport._grpc_channel = channel + client = self._make_client(project="project-id", use_emulator=True) with mock.patch.object( client.transport, "create_channel", CrossSync.Mock ): @@ -478,7 +472,7 @@ async def test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -512,7 +506,7 @@ async def test__manage_channel_refresh(self, num_cycles): CrossSync.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) create_channel.reset_mock() try: await client._manage_channel( From c4a97e1803120eb45646a15aada8f5811e89b471 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Jul 2025 20:12:33 -0700 Subject: [PATCH 03/60] put back in cache invalidation --- google/cloud/bigtable/data/_async/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6635acdf1..4ca55ed8b 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -401,6 +401,9 @@ async def _manage_channel( await self._ping_and_warm_instances(channel=new_sub_channel) # cycle channel out of use, with long grace window before closure await channel.replace_wrapped_channel(new_sub_channel, grace_period) + # invalidate caches + self.transport._stubs = {} + self.transport._prep_wrapped_messages(self.client_info) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) 
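        # (e.g. if next_refresh comes out to 2400s and warming plus swapping
        # took 3s, the loop sleeps max(2400 - 3, 0) = 2397s, keeping the
        # refresh cadence roughly constant regardless of swap time)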
next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) From e71b1d581cad59b11f5f781d9a2624e853b5dd30 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Jul 2025 20:12:49 -0700 Subject: [PATCH 04/60] added wrapped multicallables to avoid cache invalidation --- google/cloud/bigtable/data/_async/client.py | 3 - .../data/_async/replaceable_channel.py | 96 +++++++++++++++---- 2 files changed, 78 insertions(+), 21 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 4ca55ed8b..6635acdf1 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -401,9 +401,6 @@ async def _manage_channel( await self._ping_and_warm_instances(channel=new_sub_channel) # cycle channel out of use, with long grace window before closure await channel.replace_wrapped_channel(new_sub_channel, grace_period) - # invalidate caches - self.transport._stubs = {} - self.transport._prep_wrapped_messages(self.client_info) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/google/cloud/bigtable/data/_async/replaceable_channel.py b/google/cloud/bigtable/data/_async/replaceable_channel.py index 81e9364a1..30eed6017 100644 --- a/google/cloud/bigtable/data/_async/replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/replaceable_channel.py @@ -21,27 +21,80 @@ from google.cloud.bigtable.data._cross_sync import CrossSync -class _AsyncReplaceableChannel(aio.Channel): +class _WrappedMultiCallable: """ - A wrapper around a gRPC channel. All methods are passed - through to the underlying channel. + Wrapper class that implements the grpc MultiCallable interface. + Allows generic functions that return calls to pass checks for + MultiCallable objects. """ - def __init__(self, channel_fn: Callable[[], aio.Channel]): - self._channel_fn = channel_fn - self._channel = channel_fn() + def __init__(self, call_factory: Callable[..., aio.Call]): + self._call_factory = call_factory + + def __call__(self, *args, **kwargs) -> aio.Call: + return self._call_factory(*args, **kwargs) + - def unary_unary(self, *args, **kwargs): - return self._channel.unary_unary(*args, **kwargs) +class WrappedUnaryUnaryMultiCallable( + _WrappedMultiCallable, aio.UnaryUnaryMultiCallable +): + pass - def unary_stream(self, *args, **kwargs): - return self._channel.unary_stream(*args, **kwargs) - def stream_unary(self, *args, **kwargs): - return self._channel.stream_unary(*args, **kwargs) +class WrappedUnaryStreamMultiCallable( + _WrappedMultiCallable, aio.UnaryStreamMultiCallable +): + pass - def stream_stream(self, *args, **kwargs): - return self._channel.stream_stream(*args, **kwargs) + +class WrappedStreamUnaryMultiCallable( + _WrappedMultiCallable, aio.StreamUnaryMultiCallable +): + pass + + +class WrappedStreamStreamMultiCallable( + _WrappedMultiCallable, aio.StreamStreamMultiCallable +): + pass + + +class _AsyncWrappedChannel(aio.Channel): + """ + A wrapper around a gRPC channel. All methods are passed + through to the underlying channel. 
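+    Serves as the base for _AsyncReplaceableChannel below, which adds the
+    ability to swap the wrapped channel while the client is running.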
+ """ + + def __init__(self, channel: aio.Channel): + self._channel = channel + + def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable: + return WrappedUnaryUnaryMultiCallable( + lambda *call_args, **call_kwargs: self._channel.unary_unary( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable: + return WrappedUnaryStreamMultiCallable( + lambda *call_args, **call_kwargs: self._channel.unary_stream( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: + return WrappedStreamUnaryMultiCallable( + lambda *call_args, **call_kwargs: self._channel.stream_unary( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable: + return WrappedStreamStreamMultiCallable( + lambda *call_args, **call_kwargs: self._channel.stream_stream( + *args, **kwargs + )(*call_args, **call_kwargs) + ) async def close(self, grace=None): return await self._channel.close(grace=grace) @@ -62,6 +115,16 @@ def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: async def wait_for_state_change(self, last_observed_state): return await self._channel.wait_for_state_change(last_observed_state) + def __getattr__(self, name): + return getattr(self._channel, name) + + +class _AsyncReplaceableChannel(_AsyncWrappedChannel): + + def __init__(self, channel_fn: Callable[[], aio.Channel]): + self._channel_fn = channel_fn + self._channel = channel_fn() + def create_channel(self) -> aio.Channel: return self._channel_fn() @@ -83,7 +146,4 @@ async def replace_wrapped_channel(self, new_channel: aio.Channel, grace_period: if grace_period: self._is_closed.wait(grace_period) # type: ignore old_channel.close() # type: ignore - return old_channel - - def __getattr__(self, name): - return getattr(self._channel, name) \ No newline at end of file + return old_channel \ No newline at end of file From b81a9bec90de34f0203ad8d6b9d0bcddd718bc1b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 10:46:42 -0700 Subject: [PATCH 05/60] added crosssync, moved close logic back to client --- ...ble_channel.py => _replaceable_channel.py} | 84 +++++++++++-------- google/cloud/bigtable/data/_async/client.py | 28 ++++--- 2 files changed, 65 insertions(+), 47 deletions(-) rename google/cloud/bigtable/data/_async/{replaceable_channel.py => _replaceable_channel.py} (58%) diff --git a/google/cloud/bigtable/data/_async/replaceable_channel.py b/google/cloud/bigtable/data/_async/_replaceable_channel.py similarity index 58% rename from google/cloud/bigtable/data/_async/replaceable_channel.py rename to google/cloud/bigtable/data/_async/_replaceable_channel.py index 30eed6017..2a816919c 100644 --- a/google/cloud/bigtable/data/_async/replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_replaceable_channel.py @@ -16,11 +16,28 @@ from typing import Callable -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - from google.cloud.bigtable.data._cross_sync import CrossSync +from grpc import ChannelConnectivity + +if CrossSync.is_async: + from grpc.aio import Call + from grpc.aio import Channel + from grpc.aio import UnaryUnaryMultiCallable + from grpc.aio import UnaryStreamMultiCallable + from grpc.aio import StreamUnaryMultiCallable + from grpc.aio import StreamStreamMultiCallable +else: + from grpc import Call + from grpc import Channel + from grpc 
import UnaryUnaryMultiCallable + from grpc import UnaryStreamMultiCallable + from grpc import StreamUnaryMultiCallable + from grpc import StreamStreamMultiCallable + +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._replaceable_channel" + +@CrossSync.convert_class class _WrappedMultiCallable: """ Wrapper class that implements the grpc MultiCallable interface. @@ -28,68 +45,69 @@ class _WrappedMultiCallable: MultiCallable objects. """ - def __init__(self, call_factory: Callable[..., aio.Call]): + def __init__(self, call_factory: Callable[..., Call]): self._call_factory = call_factory - def __call__(self, *args, **kwargs) -> aio.Call: + def __call__(self, *args, **kwargs) -> Call: return self._call_factory(*args, **kwargs) class WrappedUnaryUnaryMultiCallable( - _WrappedMultiCallable, aio.UnaryUnaryMultiCallable + _WrappedMultiCallable, UnaryUnaryMultiCallable ): pass class WrappedUnaryStreamMultiCallable( - _WrappedMultiCallable, aio.UnaryStreamMultiCallable + _WrappedMultiCallable, UnaryStreamMultiCallable ): pass class WrappedStreamUnaryMultiCallable( - _WrappedMultiCallable, aio.StreamUnaryMultiCallable + _WrappedMultiCallable, StreamUnaryMultiCallable ): pass class WrappedStreamStreamMultiCallable( - _WrappedMultiCallable, aio.StreamStreamMultiCallable + _WrappedMultiCallable, StreamStreamMultiCallable ): pass -class _AsyncWrappedChannel(aio.Channel): +@CrossSync.convert_class(sync_name="_WrappedChannel", rm_aio=True) +class _AsyncWrappedChannel(Channel): """ A wrapper around a gRPC channel. All methods are passed through to the underlying channel. """ - def __init__(self, channel: aio.Channel): + def __init__(self, channel: Channel): self._channel = channel - def unary_unary(self, *args, **kwargs) -> grpc.aio.UnaryUnaryMultiCallable: + def unary_unary(self, *args, **kwargs) -> UnaryUnaryMultiCallable: return WrappedUnaryUnaryMultiCallable( lambda *call_args, **call_kwargs: self._channel.unary_unary( *args, **kwargs )(*call_args, **call_kwargs) ) - def unary_stream(self, *args, **kwargs) -> grpc.aio.UnaryStreamMultiCallable: + def unary_stream(self, *args, **kwargs) -> UnaryStreamMultiCallable: return WrappedUnaryStreamMultiCallable( lambda *call_args, **call_kwargs: self._channel.unary_stream( *args, **kwargs )(*call_args, **call_kwargs) ) - def stream_unary(self, *args, **kwargs) -> grpc.aio.StreamUnaryMultiCallable: + def stream_unary(self, *args, **kwargs) -> StreamUnaryMultiCallable: return WrappedStreamUnaryMultiCallable( lambda *call_args, **call_kwargs: self._channel.stream_unary( *args, **kwargs )(*call_args, **call_kwargs) ) - def stream_stream(self, *args, **kwargs) -> grpc.aio.StreamStreamMultiCallable: + def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable: return WrappedStreamStreamMultiCallable( lambda *call_args, **call_kwargs: self._channel.stream_stream( *args, **kwargs @@ -102,14 +120,16 @@ async def close(self, grace=None): async def channel_ready(self): return await self._channel.channel_ready() + @CrossSync.convert(sync_name="__enter__", replace_symbols={"__aenter__": "__enter__"}) async def __aenter__(self): await self._channel.__aenter__() return self + @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): return await self._channel.__aexit__(exc_type, exc_val, exc_tb) - def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + def get_state(self, try_to_connect: bool = False) -> ChannelConnectivity: return 
self._channel.get_state(try_to_connect=try_to_connect) async def wait_for_state_change(self, last_observed_state): @@ -118,32 +138,26 @@ async def wait_for_state_change(self, last_observed_state): def __getattr__(self, name): return getattr(self._channel, name) - +@CrossSync.convert_class(sync_name="_ReplaceableChannel", replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"}) class _AsyncReplaceableChannel(_AsyncWrappedChannel): - def __init__(self, channel_fn: Callable[[], aio.Channel]): + def __init__(self, channel_fn: Callable[[], Channel]): self._channel_fn = channel_fn self._channel = channel_fn() - def create_channel(self) -> aio.Channel: - return self._channel_fn() - - async def replace_wrapped_channel(self, new_channel: aio.Channel, grace_period: float | None, copy_async_interceptors: bool=True) -> aio.Channel: - old_channel = self._channel - if CrossSync.is_async and copy_async_interceptors: + def create_channel(self) -> Channel: + new_channel = self._channel_fn() + if CrossSync.is_async: # copy over interceptors # this is needed because of how gapic attaches the LoggingClientAIOInterceptor # sync channels add interceptors by wrapping, so this step isn't needed - new_channel._unary_unary_interceptors = old_channel._unary_unary_interceptors - new_channel._unary_stream_interceptors = old_channel._unary_stream_interceptors - new_channel._stream_unary_interceptors = old_channel._stream_unary_interceptors - new_channel._stream_stream_interceptors = old_channel._stream_stream_interceptors + new_channel._unary_unary_interceptors = self._channel._unary_unary_interceptors + new_channel._unary_stream_interceptors = self._channel._unary_stream_interceptors + new_channel._stream_unary_interceptors = self._channel._stream_unary_interceptors + new_channel._stream_stream_interceptors = self._channel._stream_stream_interceptors + return new_channel + + def replace_wrapped_channel(self, new_channel: Channel) -> Channel: + old_channel = self._channel self._channel = new_channel - # give old_channel a chance to complete existing rpcs - if CrossSync.is_async: - await old_channel.close(grace_period) - else: - if grace_period: - self._is_closed.wait(grace_period) # type: ignore - old_channel.close() # type: ignore return old_channel \ No newline at end of file diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6635acdf1..1134bd4da 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -96,13 +96,14 @@ _LoggingClientAIOInterceptor ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._async.replaceable_channel import _AsyncReplaceableChannel + from google.cloud.bigtable.data._async._replaceable_channel import _AsyncReplaceableChannel else: from typing import Iterable # noqa: F401 from grpc import insecure_channel from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE + from google.cloud.bigtable.data._async._replaceable_channel import _ReplaceableChannel if TYPE_CHECKING: @@ -238,17 +239,14 @@ def __init__( stacklevel=2, ) + @CrossSync.convert(replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"}) def _build_grpc_channel(self, *args, **kwargs) -> _AsyncReplaceableChannel: if self._emulator_host is not None: # emulators use insecure channel 
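            # (the emulator endpoint speaks plaintext gRPC with anonymous
            # credentials, so no TLS channel credentials are attached)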
create_channel_fn = partial(insecure_channel, self._emulator_host) else: create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) - if CrossSync.is_async: - return _AsyncReplaceableChannel(create_channel_fn) - else: - return TransportType.create_channel(*args, **kwargs) - + return _AsyncReplaceableChannel(create_channel_fn) @staticmethod @@ -373,18 +371,17 @@ async def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ - channel = self.transport.grpc_channel if not isinstance(self.transport.grpc_channel, _AsyncReplaceableChannel): warnings.warn("Channel does not support auto-refresh.") return - channel: _AsyncReplaceableChannel = self.transport.grpc_channel + super_channel: _AsyncReplaceableChannel = self.transport.grpc_channel first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: # warm the current channel immediately - await self._ping_and_warm_instances(channel=channel) + await self._ping_and_warm_instances(channel=super_channel) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): await CrossSync.event_wait( @@ -397,10 +394,17 @@ async def _manage_channel( break start_timestamp = time.monotonic() # prepare new channel for use - new_sub_channel = channel.create_channel() - await self._ping_and_warm_instances(channel=new_sub_channel) + new_channel = super_channel.create_channel() + await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure - await channel.replace_wrapped_channel(new_sub_channel, grace_period) + old_channel = super_channel.replace_wrapped_channel(new_channel, grace_period) + # give old_channel a chance to complete existing rpcs + if CrossSync.is_async: + await old_channel.close(grace_period) + else: + if grace_period: + self._is_closed.wait(grace_period) # type: ignore + old_channel.close() # type: ignore # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) From a1dffb559de55af8704c7f2c623b60ab28996aab Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 10:48:18 -0700 Subject: [PATCH 06/60] generated sync code --- .../_sync_autogen/_replaceable_channel.py | 133 ++++++++++++++++++ .../bigtable/data/_sync_autogen/client.py | 30 ++-- tests/system/data/test_system_autogen.py | 7 +- tests/unit/data/_sync_autogen/test_client.py | 36 ++--- 4 files changed, 168 insertions(+), 38 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py diff --git a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py new file mode 100644 index 000000000..a936bf598 --- /dev/null +++ b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py @@ -0,0 +1,133 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +# + +# This file is automatically generated by CrossSync. Do not edit manually. + +from __future__ import annotations +from typing import Callable +from grpc import ChannelConnectivity +from grpc import Call +from grpc import Channel +from grpc import UnaryUnaryMultiCallable +from grpc import UnaryStreamMultiCallable +from grpc import StreamUnaryMultiCallable +from grpc import StreamStreamMultiCallable + + +class _WrappedMultiCallable: + """ + Wrapper class that implements the grpc MultiCallable interface. + Allows generic functions that return calls to pass checks for + MultiCallable objects. + """ + + def __init__(self, call_factory: Callable[..., Call]): + self._call_factory = call_factory + + def __call__(self, *args, **kwargs) -> Call: + return self._call_factory(*args, **kwargs) + + +class WrappedUnaryUnaryMultiCallable(_WrappedMultiCallable, UnaryUnaryMultiCallable): + pass + + +class WrappedUnaryStreamMultiCallable(_WrappedMultiCallable, UnaryStreamMultiCallable): + pass + + +class WrappedStreamUnaryMultiCallable(_WrappedMultiCallable, StreamUnaryMultiCallable): + pass + + +class WrappedStreamStreamMultiCallable( + _WrappedMultiCallable, StreamStreamMultiCallable +): + pass + + +class _WrappedChannel(Channel): + """ + A wrapper around a gRPC channel. All methods are passed + through to the underlying channel. + """ + + def __init__(self, channel: Channel): + self._channel = channel + + def unary_unary(self, *args, **kwargs) -> UnaryUnaryMultiCallable: + return WrappedUnaryUnaryMultiCallable( + lambda *call_args, **call_kwargs: self._channel.unary_unary( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def unary_stream(self, *args, **kwargs) -> UnaryStreamMultiCallable: + return WrappedUnaryStreamMultiCallable( + lambda *call_args, **call_kwargs: self._channel.unary_stream( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def stream_unary(self, *args, **kwargs) -> StreamUnaryMultiCallable: + return WrappedStreamUnaryMultiCallable( + lambda *call_args, **call_kwargs: self._channel.stream_unary( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable: + return WrappedStreamStreamMultiCallable( + lambda *call_args, **call_kwargs: self._channel.stream_stream( + *args, **kwargs + )(*call_args, **call_kwargs) + ) + + def close(self, grace=None): + return self._channel.close(grace=grace) + + def channel_ready(self): + return self._channel.channel_ready() + + def __enter__(self): + self._channel.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._channel.__exit__(exc_type, exc_val, exc_tb) + + def get_state(self, try_to_connect: bool = False) -> ChannelConnectivity: + return self._channel.get_state(try_to_connect=try_to_connect) + + def wait_for_state_change(self, last_observed_state): + return self._channel.wait_for_state_change(last_observed_state) + + def __getattr__(self, name): + return getattr(self._channel, name) + + +class _ReplaceableChannel(_WrappedChannel): + def __init__(self, channel_fn: Callable[[], Channel]): + self._channel_fn = channel_fn + 
self._channel = channel_fn() + + def create_channel(self) -> Channel: + new_channel = self._channel_fn() + return new_channel + + def replace_wrapped_channel(self, new_channel: Channel) -> Channel: + old_channel = self._channel + self._channel = new_channel + return old_channel diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index b36bf359a..d4fdf5d95 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -75,11 +75,11 @@ from google.cloud.bigtable.data._cross_sync import CrossSync from typing import Iterable from grpc import insecure_channel -from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcTransport as TransportType, ) from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE +from google.cloud.bigtable.data._async._replaceable_channel import _ReplaceableChannel if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -131,7 +131,6 @@ def __init__( client_options = cast( Optional[client_options_lib.ClientOptions], client_options ) - custom_channel = None self._emulator_host = os.getenv(BIGTABLE_EMULATOR) if self._emulator_host is not None: warnings.warn( @@ -139,7 +138,6 @@ def __init__( RuntimeWarning, stacklevel=2, ) - custom_channel = insecure_channel(self._emulator_host) if credentials is None: credentials = google.auth.credentials.AnonymousCredentials() if project is None: @@ -155,7 +153,7 @@ def __init__( client_options=client_options, client_info=self.client_info, transport=lambda *args, **kwargs: TransportType( - *args, **kwargs, channel=custom_channel + *args, **kwargs, channel=self._build_grpc_channel ), ) self._is_closed = CrossSync._Sync_Impl.Event() @@ -179,6 +177,13 @@ def __init__( stacklevel=2, ) + def _build_grpc_channel(self, *args, **kwargs) -> _ReplaceableChannel: + if self._emulator_host is not None: + create_channel_fn = partial(insecure_channel, self._emulator_host) + else: + create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) + return _ReplaceableChannel(create_channel_fn) + @staticmethod def _client_version() -> str: """Helper function to return the client version string for this client""" @@ -277,12 +282,16 @@ def _manage_channel( between `refresh_interval_min` and `refresh_interval_max` grace_period: time to allow previous channel to serve existing requests before closing, in seconds""" + if not isinstance(self.transport.grpc_channel, _AsyncReplaceableChannel): + warnings.warn("Channel does not support auto-refresh.") + return + super_channel: _AsyncReplaceableChannel = self.transport.grpc_channel first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: - self._ping_and_warm_instances(channel=self.transport.grpc_channel) + self._ping_and_warm_instances(channel=super_channel) while not self._is_closed.is_set(): CrossSync._Sync_Impl.event_wait( self._is_closed, next_sleep, async_break_early=False @@ -290,14 +299,11 @@ def _manage_channel( if self._is_closed.is_set(): break start_timestamp = time.monotonic() - old_channel = self.transport.grpc_channel - new_channel = self.transport.create_channel() - new_channel = intercept_channel(new_channel, self.transport._interceptor) + new_channel = super_channel.create_channel() self._ping_and_warm_instances(channel=new_channel) - 
self.transport._grpc_channel = new_channel - self.transport._logged_channel = new_channel - self.transport._stubs = {} - self.transport._prep_wrapped_messages(self.client_info) + old_channel = super_channel.replace_wrapped_channel( + new_channel, grace_period + ) if grace_period: self._is_closed.wait(grace_period) old_channel.close() diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 6b2006d7b..7d99805b7 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -201,13 +201,14 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows): CrossSync._Sync_Impl.yield_to_event_loop() with client.get_table(instance_id, table_id) as table: rows = table.read_rows({}) - first_channel = client.transport.grpc_channel + channel_wrapper = client.transport.grpc_channel + first_channel = client.transport.grpc_channel._channel assert len(rows) == 2 CrossSync._Sync_Impl.sleep(2) rows_after_refresh = table.read_rows({}) assert len(rows_after_refresh) == 2 - assert client.transport.grpc_channel is not first_channel - print(table) + assert client.transport.grpc_channel is channel_wrapper + assert client.transport.grpc_channel._channel is not first_channel finally: client.close() diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index 38866c9dd..04ce7dd5b 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -176,7 +176,7 @@ def test__start_background_channel_refresh_task_exists(self): client.close() def test__start_background_channel_refresh(self): - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) with mock.patch.object( client, "_ping_and_warm_instances", CrossSync._Sync_Impl.Mock() ) as ping_and_warm: @@ -282,7 +282,7 @@ def test__manage_channel_first_sleep( with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) client._channel_init_time = -wait_time client._manage_channel(refresh_interval, refresh_interval) except asyncio.CancelledError: @@ -296,36 +296,29 @@ def test__manage_channel_first_sleep( def test__manage_channel_ping_and_warm(self): """_manage channel should call ping and warm internally""" - import time import threading - from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( - _LoggingClientInterceptor as Interceptor, - ) - client_mock = mock.Mock() - client_mock.transport._interceptor = Interceptor() - client_mock._is_closed.is_set.return_value = False - client_mock._channel_init_time = time.monotonic() - orig_channel = client_mock.transport.grpc_channel + client = self._make_client(project="project-id", use_emulator=True) + orig_channel = client.transport.grpc_channel sleep_tuple = ( (asyncio, "sleep") if CrossSync._Sync_Impl.is_async else (threading.Event, "wait") ) - with mock.patch.object(*sleep_tuple): - orig_channel.close.side_effect = asyncio.CancelledError + with mock.patch.object(*sleep_tuple) as sleep_mock: + sleep_mock.side_effect = [None, asyncio.CancelledError] ping_and_warm = ( - client_mock._ping_and_warm_instances + client._ping_and_warm_instances ) = CrossSync._Sync_Impl.Mock() try: - self._get_target_class()._manage_channel(client_mock, 10) + client._manage_channel(10) except 
asyncio.CancelledError: pass assert ping_and_warm.call_count == 2 - assert client_mock.transport._grpc_channel != orig_channel + assert client.transport.grpc_channel._channel != orig_channel called_with = [call[1]["channel"] for call in ping_and_warm.call_args_list] assert orig_channel in called_with - assert client_mock.transport.grpc_channel in called_with + assert client.transport.grpc_channel._channel in called_with @pytest.mark.parametrize( "refresh_interval, num_cycles, expected_sleep", @@ -335,8 +328,6 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle import time import random - channel = mock.Mock() - channel.close = CrossSync._Sync_Impl.Mock() with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: @@ -345,8 +336,7 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] - client = self._make_client(project="project-id") - client.transport._grpc_channel = channel + client = self._make_client(project="project-id", use_emulator=True) with mock.patch.object( client.transport, "create_channel", CrossSync._Sync_Impl.Mock ): @@ -374,7 +364,7 @@ def test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -405,7 +395,7 @@ def test__manage_channel_refresh(self, num_cycles): CrossSync._Sync_Impl.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_client(project="project-id") + client = self._make_client(project="project-id", use_emulator=False) create_channel.reset_mock() try: client._manage_channel( From e3ec02bd6972405643b0b7858a71bb7d91b6cae2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 11:25:45 -0700 Subject: [PATCH 07/60] got tests running --- .../data/_async/_replaceable_channel.py | 24 ++++++++++++++++++- google/cloud/bigtable/data/_async/client.py | 6 ++--- .../_sync_autogen/_replaceable_channel.py | 19 +++++++++++---- .../bigtable/data/_sync_autogen/client.py | 12 +++++----- 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_replaceable_channel.py b/google/cloud/bigtable/data/_async/_replaceable_channel.py index 2a816919c..2cece13b1 100644 --- a/google/cloud/bigtable/data/_async/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_replaceable_channel.py @@ -55,7 +55,15 @@ def __call__(self, *args, **kwargs) -> Call: class WrappedUnaryUnaryMultiCallable( _WrappedMultiCallable, UnaryUnaryMultiCallable ): - pass + if not CrossSync.is_async: + # add missing functions for sync unary callable + + def with_call(self, *args, **kwargs): + call = self.__call__(self, *args, **kwargs) + return call(), call + + def future(self, *args, **kwargs): + raise NotImplementedError class WrappedUnaryStreamMultiCallable( @@ -114,6 +122,8 @@ def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable: )(*call_args, **call_kwargs) ) + # grace not supported by sync version + @CrossSync.drop async def close(self, grace=None): return await self._channel.close(grace=grace) @@ -138,6 +148,18 @@ async def wait_for_state_change(self, last_observed_state): def 
__getattr__(self, name): return getattr(self._channel, name) + if not CrossSync.is_async: + + def close(self): + return self._channel.close() + + def subscribe(self, callback, try_to_connect=False): + return self._channel.subscribe(callback, try_to_connect) + + def unsubscribe(self, callback): + return self._channel.unsubscribe(callback) + + @CrossSync.convert_class(sync_name="_ReplaceableChannel", replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"}) class _AsyncReplaceableChannel(_AsyncWrappedChannel): diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1134bd4da..eecda25ec 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -103,7 +103,7 @@ from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._async._replaceable_channel import _ReplaceableChannel + from google.cloud.bigtable.data._sync_autogen._replaceable_channel import _ReplaceableChannel if TYPE_CHECKING: @@ -346,7 +346,7 @@ async def _ping_and_warm_instances( ) return [r or None for r in result_list] - @CrossSync.convert + @CrossSync.convert(replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"}) async def _manage_channel( self, refresh_interval_min: float = 60 * 35, @@ -397,7 +397,7 @@ async def _manage_channel( new_channel = super_channel.create_channel() await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure - old_channel = super_channel.replace_wrapped_channel(new_channel, grace_period) + old_channel = super_channel.replace_wrapped_channel(new_channel) # give old_channel a chance to complete existing rpcs if CrossSync.is_async: await old_channel.close(grace_period) diff --git a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py index a936bf598..523f180b0 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py @@ -41,7 +41,12 @@ def __call__(self, *args, **kwargs) -> Call: class WrappedUnaryUnaryMultiCallable(_WrappedMultiCallable, UnaryUnaryMultiCallable): - pass + def with_call(self, *args, **kwargs): + call = self.__call__(self, *args, **kwargs) + return (call(), call) + + def future(self, *args, **kwargs): + raise NotImplementedError class WrappedUnaryStreamMultiCallable(_WrappedMultiCallable, UnaryStreamMultiCallable): @@ -95,9 +100,6 @@ def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable: )(*call_args, **call_kwargs) ) - def close(self, grace=None): - return self._channel.close(grace=grace) - def channel_ready(self): return self._channel.channel_ready() @@ -117,6 +119,15 @@ def wait_for_state_change(self, last_observed_state): def __getattr__(self, name): return getattr(self._channel, name) + def close(self): + return self._channel.close() + + def subscribe(self, callback, try_to_connect=False): + return self._channel.subscribe(callback, try_to_connect) + + def unsubscribe(self, callback): + return self._channel.unsubscribe(callback) + class _ReplaceableChannel(_WrappedChannel): def __init__(self, channel_fn: Callable[[], Channel]): diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py 
b/google/cloud/bigtable/data/_sync_autogen/client.py index d4fdf5d95..6be313720 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -79,7 +79,9 @@ BigtableGrpcTransport as TransportType, ) from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE -from google.cloud.bigtable.data._async._replaceable_channel import _ReplaceableChannel +from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( + _ReplaceableChannel, +) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -282,10 +284,10 @@ def _manage_channel( between `refresh_interval_min` and `refresh_interval_max` grace_period: time to allow previous channel to serve existing requests before closing, in seconds""" - if not isinstance(self.transport.grpc_channel, _AsyncReplaceableChannel): + if not isinstance(self.transport.grpc_channel, _ReplaceableChannel): warnings.warn("Channel does not support auto-refresh.") return - super_channel: _AsyncReplaceableChannel = self.transport.grpc_channel + super_channel: _ReplaceableChannel = self.transport.grpc_channel first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -301,9 +303,7 @@ def _manage_channel( start_timestamp = time.monotonic() new_channel = super_channel.create_channel() self._ping_and_warm_instances(channel=new_channel) - old_channel = super_channel.replace_wrapped_channel( - new_channel, grace_period - ) + old_channel = super_channel.replace_wrapped_channel(new_channel) if grace_period: self._is_closed.wait(grace_period) old_channel.close() From 4e1378359fc4db5774d3512cd7db22b0ab90ac19 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 12:47:40 -0700 Subject: [PATCH 08/60] fixed tests --- .../data/_async/_replaceable_channel.py | 2 +- google/cloud/bigtable/data/_async/client.py | 2 +- .../_sync_autogen/_replaceable_channel.py | 2 +- tests/unit/data/_async/test_client.py | 46 +++++++++-------- tests/unit/data/_sync_autogen/test_client.py | 50 +++++++++++-------- 5 files changed, 59 insertions(+), 43 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_replaceable_channel.py b/google/cloud/bigtable/data/_async/_replaceable_channel.py index 2cece13b1..e1b404581 100644 --- a/google/cloud/bigtable/data/_async/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_replaceable_channel.py @@ -10,7 +10,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License +# limitations under the License. 
# from __future__ import annotations diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index eecda25ec..a39a2ff85 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -404,7 +404,7 @@ async def _manage_channel( else: if grace_period: self._is_closed.wait(grace_period) # type: ignore - old_channel.close() # type: ignore + old_channel.close() # type: ignore # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py index 523f180b0..c150f4f7f 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py @@ -10,7 +10,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License +# limitations under the License. # # This file is automatically generated by CrossSync. Do not edit manually. diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 17d3b7ac0..23abeb633 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -51,13 +51,17 @@ if CrossSync.is_async: from google.api_core import grpc_helpers_async from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async._replaceable_channel import _AsyncReplaceableChannel CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) + CrossSync.add_mapping("ReplaceableChannel", _AsyncReplaceableChannel) else: from google.api_core import grpc_helpers from google.cloud.bigtable.data._sync_autogen.client import Table # noqa: F401 + from google.cloud.bigtable.data._sync_autogen._replaceable_channel import _ReplaceableChannel CrossSync.add_mapping("grpc_helpers", grpc_helpers) + CrossSync.add_mapping("ReplaceableChannel", _ReplaceableChannel) __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_client" @@ -223,11 +227,12 @@ async def test__start_background_channel_refresh_task_exists(self): @CrossSync.pytest async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") with mock.patch.object( client, "_ping_and_warm_instances", CrossSync.Mock() ) as ping_and_warm: client._emulator_host = None + client.transport._grpc_channel = CrossSync.ReplaceableChannel(mock.Mock) client._start_background_channel_refresh() assert client._channel_refresh_task is not None assert isinstance(client._channel_refresh_task, CrossSync.Task) @@ -366,7 +371,7 @@ async def test__manage_channel_first_sleep( with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") client._channel_init_time = -wait_time await client._manage_channel(refresh_interval, refresh_interval) except asyncio.CancelledError: @@ -472,7 +477,7 @@ async def 
test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -499,26 +504,27 @@ async def test__manage_channel_refresh(self, num_cycles): expected_refresh = 0.5 grpc_lib = grpc.aio if CrossSync.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") + create_channel_mock = mock.Mock() + create_channel_mock.return_value = new_channel + refreshable_channel = CrossSync.ReplaceableChannel(create_channel_mock) with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] - with mock.patch.object( - CrossSync.grpc_helpers, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - client = self._make_client(project="project-id", use_emulator=False) - create_channel.reset_mock() - try: - await client._manage_channel( - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=0, - ) - except RuntimeError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - await client.close() + client = self._make_client(project="project-id") + client.transport._grpc_channel = refreshable_channel + create_channel_mock.reset_mock() + sleep.reset_mock() + try: + await client._manage_channel( + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=0, + ) + except RuntimeError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel_mock.call_count == num_cycles + await client.close() @CrossSync.pytest async def test__register_instance(self): diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index 04ce7dd5b..f5cf34ab6 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -45,8 +45,12 @@ str_val, ) from google.api_core import grpc_helpers +from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( + _ReplaceableChannel, +) CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) +CrossSync._Sync_Impl.add_mapping("ReplaceableChannel", _ReplaceableChannel) @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") @@ -176,11 +180,14 @@ def test__start_background_channel_refresh_task_exists(self): client.close() def test__start_background_channel_refresh(self): - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") with mock.patch.object( client, "_ping_and_warm_instances", CrossSync._Sync_Impl.Mock() ) as ping_and_warm: client._emulator_host = None + client.transport._grpc_channel = CrossSync._Sync_Impl.ReplaceableChannel( + mock.Mock + ) client._start_background_channel_refresh() assert client._channel_refresh_task is not None assert isinstance(client._channel_refresh_task, CrossSync._Sync_Impl.Task) @@ -282,7 +289,7 @@ def test__manage_channel_first_sleep( with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") client._channel_init_time = -wait_time client._manage_channel(refresh_interval, refresh_interval) except 
asyncio.CancelledError: @@ -364,7 +371,7 @@ def test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -389,25 +396,28 @@ def test__manage_channel_refresh(self, num_cycles): expected_refresh = 0.5 grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") + create_channel_mock = mock.Mock() + create_channel_mock.return_value = new_channel + refreshable_channel = CrossSync._Sync_Impl.ReplaceableChannel( + create_channel_mock + ) with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] - with mock.patch.object( - CrossSync._Sync_Impl.grpc_helpers, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - client = self._make_client(project="project-id", use_emulator=False) - create_channel.reset_mock() - try: - client._manage_channel( - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=0, - ) - except RuntimeError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - client.close() + client = self._make_client(project="project-id") + client.transport._grpc_channel = refreshable_channel + create_channel_mock.reset_mock() + sleep.reset_mock() + try: + client._manage_channel( + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=0, + ) + except RuntimeError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel_mock.call_count == num_cycles + client.close() def test__register_instance(self): """test instance registration""" From 7d90a04815d8e831679ffc24e287c6c9a1e722e7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 13:25:32 -0700 Subject: [PATCH 09/60] remove extra wrapper; added invalidate_stubs helper --- .../data/_async/_replaceable_channel.py | 95 +++---------------- google/cloud/bigtable/data/_async/client.py | 6 ++ .../_sync_autogen/_replaceable_channel.py | 80 +++------------- .../bigtable/data/_sync_autogen/client.py | 6 ++ 4 files changed, 34 insertions(+), 153 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_replaceable_channel.py b/google/cloud/bigtable/data/_async/_replaceable_channel.py index e1b404581..869867c5d 100644 --- a/google/cloud/bigtable/data/_async/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_replaceable_channel.py @@ -21,69 +21,12 @@ from grpc import ChannelConnectivity if CrossSync.is_async: - from grpc.aio import Call from grpc.aio import Channel - from grpc.aio import UnaryUnaryMultiCallable - from grpc.aio import UnaryStreamMultiCallable - from grpc.aio import StreamUnaryMultiCallable - from grpc.aio import StreamStreamMultiCallable else: - from grpc import Call from grpc import Channel - from grpc import UnaryUnaryMultiCallable - from grpc import UnaryStreamMultiCallable - from grpc import StreamUnaryMultiCallable - from grpc import StreamStreamMultiCallable __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._replaceable_channel" -@CrossSync.convert_class -class _WrappedMultiCallable: - """ - Wrapper class that implements the grpc MultiCallable interface. 
- Allows generic functions that return calls to pass checks for - MultiCallable objects. - """ - - def __init__(self, call_factory: Callable[..., Call]): - self._call_factory = call_factory - - def __call__(self, *args, **kwargs) -> Call: - return self._call_factory(*args, **kwargs) - - -class WrappedUnaryUnaryMultiCallable( - _WrappedMultiCallable, UnaryUnaryMultiCallable -): - if not CrossSync.is_async: - # add missing functions for sync unary callable - - def with_call(self, *args, **kwargs): - call = self.__call__(self, *args, **kwargs) - return call(), call - - def future(self, *args, **kwargs): - raise NotImplementedError - - -class WrappedUnaryStreamMultiCallable( - _WrappedMultiCallable, UnaryStreamMultiCallable -): - pass - - -class WrappedStreamUnaryMultiCallable( - _WrappedMultiCallable, StreamUnaryMultiCallable -): - pass - - -class WrappedStreamStreamMultiCallable( - _WrappedMultiCallable, StreamStreamMultiCallable -): - pass - - @CrossSync.convert_class(sync_name="_WrappedChannel", rm_aio=True) class _AsyncWrappedChannel(Channel): """ @@ -94,33 +37,17 @@ class _AsyncWrappedChannel(Channel): def __init__(self, channel: Channel): self._channel = channel - def unary_unary(self, *args, **kwargs) -> UnaryUnaryMultiCallable: - return WrappedUnaryUnaryMultiCallable( - lambda *call_args, **call_kwargs: self._channel.unary_unary( - *args, **kwargs - )(*call_args, **call_kwargs) - ) - - def unary_stream(self, *args, **kwargs) -> UnaryStreamMultiCallable: - return WrappedUnaryStreamMultiCallable( - lambda *call_args, **call_kwargs: self._channel.unary_stream( - *args, **kwargs - )(*call_args, **call_kwargs) - ) - - def stream_unary(self, *args, **kwargs) -> StreamUnaryMultiCallable: - return WrappedStreamUnaryMultiCallable( - lambda *call_args, **call_kwargs: self._channel.stream_unary( - *args, **kwargs - )(*call_args, **call_kwargs) - ) - - def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable: - return WrappedStreamStreamMultiCallable( - lambda *call_args, **call_kwargs: self._channel.stream_stream( - *args, **kwargs - )(*call_args, **call_kwargs) - ) + def unary_unary(self, *args, **kwargs): + return self._channel.unary_unary(*args, **kwargs) + + def unary_stream(self, *args, **kwargs): + return self._channel.unary_stream(*args, **kwargs) + + def stream_unary(self, *args, **kwargs): + return self._channel.stream_unary(*args, **kwargs) + + def stream_stream(self, *args, **kwargs): + return self._channel.stream_stream(*args, **kwargs) # grace not supported by sync version @CrossSync.drop diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index a39a2ff85..6d0c27ce7 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -346,6 +346,11 @@ async def _ping_and_warm_instances( ) return [r or None for r in result_list] + def _invalidate_channel_stubs(self): + """Helper to reset the cached stubs. 
Needed when changing out the grpc channel""" + self.transport._stubs = {} + self.transport._prep_wrapped_messages(self.client_info) + @CrossSync.convert(replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"}) async def _manage_channel( self, @@ -398,6 +403,7 @@ async def _manage_channel( await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure old_channel = super_channel.replace_wrapped_channel(new_channel) + self._invalidate_channel_stubs() # give old_channel a chance to complete existing rpcs if CrossSync.is_async: await old_channel.close(grace_period) diff --git a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py index c150f4f7f..3e71c3f6d 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py @@ -18,49 +18,7 @@ from __future__ import annotations from typing import Callable from grpc import ChannelConnectivity -from grpc import Call from grpc import Channel -from grpc import UnaryUnaryMultiCallable -from grpc import UnaryStreamMultiCallable -from grpc import StreamUnaryMultiCallable -from grpc import StreamStreamMultiCallable - - -class _WrappedMultiCallable: - """ - Wrapper class that implements the grpc MultiCallable interface. - Allows generic functions that return calls to pass checks for - MultiCallable objects. - """ - - def __init__(self, call_factory: Callable[..., Call]): - self._call_factory = call_factory - - def __call__(self, *args, **kwargs) -> Call: - return self._call_factory(*args, **kwargs) - - -class WrappedUnaryUnaryMultiCallable(_WrappedMultiCallable, UnaryUnaryMultiCallable): - def with_call(self, *args, **kwargs): - call = self.__call__(self, *args, **kwargs) - return (call(), call) - - def future(self, *args, **kwargs): - raise NotImplementedError - - -class WrappedUnaryStreamMultiCallable(_WrappedMultiCallable, UnaryStreamMultiCallable): - pass - - -class WrappedStreamUnaryMultiCallable(_WrappedMultiCallable, StreamUnaryMultiCallable): - pass - - -class WrappedStreamStreamMultiCallable( - _WrappedMultiCallable, StreamStreamMultiCallable -): - pass class _WrappedChannel(Channel): @@ -72,33 +30,17 @@ class _WrappedChannel(Channel): def __init__(self, channel: Channel): self._channel = channel - def unary_unary(self, *args, **kwargs) -> UnaryUnaryMultiCallable: - return WrappedUnaryUnaryMultiCallable( - lambda *call_args, **call_kwargs: self._channel.unary_unary( - *args, **kwargs - )(*call_args, **call_kwargs) - ) - - def unary_stream(self, *args, **kwargs) -> UnaryStreamMultiCallable: - return WrappedUnaryStreamMultiCallable( - lambda *call_args, **call_kwargs: self._channel.unary_stream( - *args, **kwargs - )(*call_args, **call_kwargs) - ) - - def stream_unary(self, *args, **kwargs) -> StreamUnaryMultiCallable: - return WrappedStreamUnaryMultiCallable( - lambda *call_args, **call_kwargs: self._channel.stream_unary( - *args, **kwargs - )(*call_args, **call_kwargs) - ) - - def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable: - return WrappedStreamStreamMultiCallable( - lambda *call_args, **call_kwargs: self._channel.stream_stream( - *args, **kwargs - )(*call_args, **call_kwargs) - ) + def unary_unary(self, *args, **kwargs): + return self._channel.unary_unary(*args, **kwargs) + + def unary_stream(self, *args, **kwargs): + return self._channel.unary_stream(*args, **kwargs) + + def 
stream_unary(self, *args, **kwargs): + return self._channel.stream_unary(*args, **kwargs) + + def stream_stream(self, *args, **kwargs): + return self._channel.stream_stream(*args, **kwargs) def channel_ready(self): return self._channel.channel_ready() diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 6be313720..29c6159f0 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -262,6 +262,11 @@ def _ping_and_warm_instances( ) return [r or None for r in result_list] + def _invalidate_channel_stubs(self): + """Helper to reset the cached stubs. Needed when changing out the grpc channel""" + self.transport._stubs = {} + self.transport._prep_wrapped_messages(self.client_info) + def _manage_channel( self, refresh_interval_min: float = 60 * 35, @@ -304,6 +309,7 @@ def _manage_channel( new_channel = super_channel.create_channel() self._ping_and_warm_instances(channel=new_channel) old_channel = super_channel.replace_wrapped_channel(new_channel) + self._invalidate_channel_stubs() if grace_period: self._is_closed.wait(grace_period) old_channel.close() From 26cd6019281c407606b42d29b944b6eb1c4a9ee2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 13:53:19 -0700 Subject: [PATCH 10/60] fixed lint --- .../data/_async/_replaceable_channel.py | 42 ++++++++++++------- google/cloud/bigtable/data/_async/client.py | 21 ++++++---- tests/unit/data/_async/test_client.py | 18 +++----- 3 files changed, 46 insertions(+), 35 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_replaceable_channel.py b/google/cloud/bigtable/data/_async/_replaceable_channel.py index 869867c5d..e64052aa5 100644 --- a/google/cloud/bigtable/data/_async/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_replaceable_channel.py @@ -27,6 +27,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._replaceable_channel" + @CrossSync.convert_class(sync_name="_WrappedChannel", rm_aio=True) class _AsyncWrappedChannel(Channel): """ @@ -49,15 +50,12 @@ def stream_unary(self, *args, **kwargs): def stream_stream(self, *args, **kwargs): return self._channel.stream_stream(*args, **kwargs) - # grace not supported by sync version - @CrossSync.drop - async def close(self, grace=None): - return await self._channel.close(grace=grace) - async def channel_ready(self): return await self._channel.channel_ready() - @CrossSync.convert(sync_name="__enter__", replace_symbols={"__aenter__": "__enter__"}) + @CrossSync.convert( + sync_name="__enter__", replace_symbols={"__aenter__": "__enter__"} + ) async def __aenter__(self): await self._channel.__aenter__() return self @@ -75,7 +73,13 @@ async def wait_for_state_change(self, last_observed_state): def __getattr__(self, name): return getattr(self._channel, name) - if not CrossSync.is_async: + if CrossSync.is_async: + # grace not supported by sync version + async def close(self, grace=None): + return await self._channel.close(grace=grace) + + else: + # add required sync methods def close(self): return self._channel.close() @@ -87,9 +91,11 @@ def unsubscribe(self, callback): return self._channel.unsubscribe(callback) -@CrossSync.convert_class(sync_name="_ReplaceableChannel", replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"}) +@CrossSync.convert_class( + sync_name="_ReplaceableChannel", + replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"}, +) class _AsyncReplaceableChannel(_AsyncWrappedChannel): - def 
__init__(self, channel_fn: Callable[[], Channel]): self._channel_fn = channel_fn self._channel = channel_fn() @@ -100,13 +106,21 @@ def create_channel(self) -> Channel: # copy over interceptors # this is needed because of how gapic attaches the LoggingClientAIOInterceptor # sync channels add interceptors by wrapping, so this step isn't needed - new_channel._unary_unary_interceptors = self._channel._unary_unary_interceptors - new_channel._unary_stream_interceptors = self._channel._unary_stream_interceptors - new_channel._stream_unary_interceptors = self._channel._stream_unary_interceptors - new_channel._stream_stream_interceptors = self._channel._stream_stream_interceptors + new_channel._unary_unary_interceptors = ( + self._channel._unary_unary_interceptors + ) + new_channel._unary_stream_interceptors = ( + self._channel._unary_stream_interceptors + ) + new_channel._stream_unary_interceptors = ( + self._channel._stream_unary_interceptors + ) + new_channel._stream_stream_interceptors = ( + self._channel._stream_stream_interceptors + ) return new_channel def replace_wrapped_channel(self, new_channel: Channel) -> Channel: old_channel = self._channel self._channel = new_channel - return old_channel \ No newline at end of file + return old_channel diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6d0c27ce7..6360a0c52 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -92,18 +92,18 @@ from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcAsyncIOTransport as TransportType, ) - from google.cloud.bigtable_v2.services.bigtable.transports.grpc_asyncio import ( - _LoggingClientAIOInterceptor - ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._async._replaceable_channel import _AsyncReplaceableChannel + from google.cloud.bigtable.data._async._replaceable_channel import ( + _AsyncReplaceableChannel, + ) else: from typing import Iterable # noqa: F401 from grpc import insecure_channel - from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._sync_autogen._replaceable_channel import _ReplaceableChannel + from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( # noqa: F401 + _ReplaceableChannel, + ) if TYPE_CHECKING: @@ -239,7 +239,9 @@ def __init__( stacklevel=2, ) - @CrossSync.convert(replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"}) + @CrossSync.convert( + replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"} + ) def _build_grpc_channel(self, *args, **kwargs) -> _AsyncReplaceableChannel: if self._emulator_host is not None: # emulators use insecure channel @@ -248,7 +250,6 @@ def _build_grpc_channel(self, *args, **kwargs) -> _AsyncReplaceableChannel: create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) return _AsyncReplaceableChannel(create_channel_fn) - @staticmethod def _client_version() -> str: """ @@ -351,7 +352,9 @@ def _invalidate_channel_stubs(self): self.transport._stubs = {} self.transport._prep_wrapped_messages(self.client_info) - @CrossSync.convert(replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"}) + @CrossSync.convert( + replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"} + ) 
async def _manage_channel( self, refresh_interval_min: float = 60 * 35, diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 23abeb633..8a57614f3 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -51,14 +51,18 @@ if CrossSync.is_async: from google.api_core import grpc_helpers_async from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async._replaceable_channel import _AsyncReplaceableChannel + from google.cloud.bigtable.data._async._replaceable_channel import ( + _AsyncReplaceableChannel, + ) CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) CrossSync.add_mapping("ReplaceableChannel", _AsyncReplaceableChannel) else: from google.api_core import grpc_helpers from google.cloud.bigtable.data._sync_autogen.client import Table # noqa: F401 - from google.cloud.bigtable.data._sync_autogen._replaceable_channel import _ReplaceableChannel + from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( + _ReplaceableChannel, + ) CrossSync.add_mapping("grpc_helpers", grpc_helpers) CrossSync.add_mapping("ReplaceableChannel", _ReplaceableChannel) @@ -388,18 +392,8 @@ async def test__manage_channel_ping_and_warm(self): """ _manage channel should call ping and warm internally """ - import time import threading - if CrossSync.is_async: - from google.cloud.bigtable_v2.services.bigtable.transports.grpc_asyncio import ( - _LoggingClientAIOInterceptor as Interceptor, - ) - else: - from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( - _LoggingClientInterceptor as Interceptor, - ) - client = self._make_client(project="project-id", use_emulator=True) orig_channel = client.transport.grpc_channel # should ping and warm all new channels, and old channels if sleeping From 375332fa5fec2d78885433c309b86d2379f24c46 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 14:32:05 -0700 Subject: [PATCH 11/60] fixed lint --- .../data/_async/_replaceable_channel.py | 13 ++++--- google/cloud/bigtable/data/_async/client.py | 36 ++++++++++++------- .../_sync_autogen/_replaceable_channel.py | 2 +- .../bigtable/data/_sync_autogen/client.py | 33 ++++++++++------- .../_async/execute_query_iterator.py | 4 +-- .../_sync_autogen/execute_query_iterator.py | 4 +-- 6 files changed, 56 insertions(+), 36 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_replaceable_channel.py b/google/cloud/bigtable/data/_async/_replaceable_channel.py index e64052aa5..b1e2c06ce 100644 --- a/google/cloud/bigtable/data/_async/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_replaceable_channel.py @@ -73,17 +73,16 @@ async def wait_for_state_change(self, last_observed_state): def __getattr__(self, name): return getattr(self._channel, name) - if CrossSync.is_async: - # grace not supported by sync version - async def close(self, grace=None): + async def close(self, grace=None): + if CrossSync.is_async: return await self._channel.close(grace=grace) + else: + # grace not supported by sync version + return self._channel.close() - else: + if not CrossSync.is_async: # add required sync methods - def close(self): - return self._channel.close() - def subscribe(self, callback, try_to_connect=False): return self._channel.subscribe(callback, try_to_connect) def unsubscribe(self, callback): return self._channel.unsubscribe(callback) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6360a0c52..6d1e9c789 100644 --- a/google/cloud/bigtable/data/_async/client.py +++
b/google/cloud/bigtable/data/_async/client.py @@ -92,6 +92,9 @@ from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcAsyncIOTransport as TransportType, ) + from google.cloud.bigtable_v2.services.bigtable import ( + BigtableAsyncClient as GapicClient, + ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._async._replaceable_channel import ( _AsyncReplaceableChannel, @@ -100,6 +103,7 @@ from typing import Iterable # noqa: F401 from grpc import insecure_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore + from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( # noqa: F401 _ReplaceableChannel, @@ -207,7 +211,7 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = CrossSync.GapicClient( + self._gapic_client = GapicClient( credentials=credentials, client_options=client_options, client_info=self.client_info, @@ -224,7 +228,7 @@ def __init__( self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_task: CrossSync.Task[None] | None = None - self._executor = ( + self._executor: concurrent.futures.ThreadPoolExecutor | None = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None ) if self._emulator_host is None: @@ -876,24 +880,32 @@ def __init__( self.table_name = self.client._gapic_client.table_path( self.client.project, instance_id, table_id ) - self.app_profile_id = app_profile_id + self.app_profile_id: str | None = app_profile_id - self.default_operation_timeout = default_operation_timeout - self.default_attempt_timeout = default_attempt_timeout - self.default_read_rows_operation_timeout = default_read_rows_operation_timeout - self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout - self.default_mutate_rows_operation_timeout = ( + self.default_operation_timeout: float = default_operation_timeout + self.default_attempt_timeout: float | None = default_attempt_timeout + self.default_read_rows_operation_timeout: float = ( + default_read_rows_operation_timeout + ) + self.default_read_rows_attempt_timeout: float | None = ( + default_read_rows_attempt_timeout + ) + self.default_mutate_rows_operation_timeout: float = ( default_mutate_rows_operation_timeout ) - self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + self.default_mutate_rows_attempt_timeout: float | None = ( + default_mutate_rows_attempt_timeout + ) - self.default_read_rows_retryable_errors = ( + self.default_read_rows_retryable_errors: Sequence[type[Exception]] = ( default_read_rows_retryable_errors or () ) - self.default_mutate_rows_retryable_errors = ( + self.default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( default_mutate_rows_retryable_errors or () ) - self.default_retryable_errors = default_retryable_errors or () + self.default_retryable_errors: Sequence[type[Exception]] = ( + default_retryable_errors or () + ) try: self._register_instance_future = CrossSync.create_task( diff --git a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py index 3e71c3f6d..473a19710 100644 --- 
a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py @@ -61,7 +61,7 @@ def wait_for_state_change(self, last_observed_state): def __getattr__(self, name): return getattr(self._channel, name) - def close(self): + def close(self, grace=None): return self._channel.close() def subscribe(self, callback, try_to_connect=False): diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 29c6159f0..decf91832 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -78,6 +78,7 @@ from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcTransport as TransportType, ) +from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( _ReplaceableChannel, @@ -150,7 +151,7 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = CrossSync._Sync_Impl.GapicClient( + self._gapic_client = GapicClient( credentials=credentials, client_options=client_options, client_info=self.client_info, @@ -164,7 +165,7 @@ def __init__( self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_task: CrossSync._Sync_Impl.Task[None] | None = None - self._executor = ( + self._executor: concurrent.futures.ThreadPoolExecutor | None = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync._Sync_Impl.is_async else None @@ -682,22 +683,30 @@ def __init__( self.table_name = self.client._gapic_client.table_path( self.client.project, instance_id, table_id ) - self.app_profile_id = app_profile_id - self.default_operation_timeout = default_operation_timeout - self.default_attempt_timeout = default_attempt_timeout - self.default_read_rows_operation_timeout = default_read_rows_operation_timeout - self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout - self.default_mutate_rows_operation_timeout = ( + self.app_profile_id: str | None = app_profile_id + self.default_operation_timeout: float = default_operation_timeout + self.default_attempt_timeout: float | None = default_attempt_timeout + self.default_read_rows_operation_timeout: float = ( + default_read_rows_operation_timeout + ) + self.default_read_rows_attempt_timeout: float | None = ( + default_read_rows_attempt_timeout + ) + self.default_mutate_rows_operation_timeout: float = ( default_mutate_rows_operation_timeout ) - self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout - self.default_read_rows_retryable_errors = ( + self.default_mutate_rows_attempt_timeout: float | None = ( + default_mutate_rows_attempt_timeout + ) + self.default_read_rows_retryable_errors: Sequence[type[Exception]] = ( default_read_rows_retryable_errors or () ) - self.default_mutate_rows_retryable_errors = ( + self.default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( default_mutate_rows_retryable_errors or () ) - self.default_retryable_errors = default_retryable_errors or () + self.default_retryable_errors: Sequence[type[Exception]] = ( + default_retryable_errors or () + ) try: self._register_instance_future = CrossSync._Sync_Impl.create_task( self.client._register_instance, diff --git 
a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index d3ca890b4..74f01c60c 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -115,8 +115,8 @@ def __init__( self._app_profile_id = app_profile_id self._client = client self._instance_id = instance_id - self._prepare_metadata = prepare_metadata - self._final_metadata = None + self._prepare_metadata: Metadata = prepare_metadata + self._final_metadata: Metadata | None = None self._byte_cursor = _ByteCursor() self._reader: _Reader[QueryResultRow] = _QueryResultRowReader() self.has_received_token = False diff --git a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py index 9c2d1c6d8..e819acda7 100644 --- a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py @@ -90,8 +90,8 @@ def __init__( self._app_profile_id = app_profile_id self._client = client self._instance_id = instance_id - self._prepare_metadata = prepare_metadata - self._final_metadata = None + self._prepare_metadata: Metadata = prepare_metadata + self._final_metadata: Metadata | None = None self._byte_cursor = _ByteCursor() self._reader: _Reader[QueryResultRow] = _QueryResultRowReader() self.has_received_token = False From 428d75a5cf5ebf77c87f2dbe833ab533bb71ff09 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 15:00:58 -0700 Subject: [PATCH 12/60] renamed replaceablechannel to swappablechannel --- ...eable_channel.py => _swappable_channel.py} | 8 +++---- google/cloud/bigtable/data/_async/client.py | 22 +++++++++---------- ...eable_channel.py => _swappable_channel.py} | 4 ++-- .../bigtable/data/_sync_autogen/client.py | 14 +++++------- tests/unit/data/_async/test_client.py | 16 +++++++------- tests/unit/data/_sync_autogen/test_client.py | 12 ++++------ 6 files changed, 35 insertions(+), 41 deletions(-) rename google/cloud/bigtable/data/_async/{_replaceable_channel.py => _swappable_channel.py} (95%) rename google/cloud/bigtable/data/_sync_autogen/{_replaceable_channel.py => _swappable_channel.py} (95%) diff --git a/google/cloud/bigtable/data/_async/_replaceable_channel.py b/google/cloud/bigtable/data/_async/_swappable_channel.py similarity index 95% rename from google/cloud/bigtable/data/_async/_replaceable_channel.py rename to google/cloud/bigtable/data/_async/_swappable_channel.py index b1e2c06ce..97fb855dd 100644 --- a/google/cloud/bigtable/data/_async/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_async/_swappable_channel.py @@ -25,7 +25,7 @@ else: from grpc import Channel -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._replaceable_channel" +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._swappable_channel" @CrossSync.convert_class(sync_name="_WrappedChannel", rm_aio=True) @@ -91,10 +91,10 @@ def unsubscribe(self, callback): @CrossSync.convert_class( - sync_name="_ReplaceableChannel", + sync_name="SwappableChannel", replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"}, ) -class _AsyncReplaceableChannel(_AsyncWrappedChannel): +class AsyncSwappableChannel(_AsyncWrappedChannel): def __init__(self, channel_fn: Callable[[], Channel]): self._channel_fn = channel_fn 
self._channel = channel_fn() @@ -119,7 +119,7 @@ def create_channel(self) -> Channel: ) return new_channel - def replace_wrapped_channel(self, new_channel: Channel) -> Channel: + def swap_channel(self, new_channel: Channel) -> Channel: old_channel = self._channel self._channel = new_channel return old_channel diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6d1e9c789..674680790 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -96,8 +96,8 @@ BigtableAsyncClient as GapicClient, ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._async._replaceable_channel import ( - _AsyncReplaceableChannel, + from google.cloud.bigtable.data._async._swappable_channel import ( + AsyncSwappableChannel, ) else: from typing import Iterable # noqa: F401 @@ -105,8 +105,8 @@ from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( # noqa: F401 - _ReplaceableChannel, + from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( # noqa: F401 + SwappableChannel, ) @@ -244,15 +244,15 @@ def __init__( ) @CrossSync.convert( - replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"} + replace_symbols={"AsyncSwappableChannel": "SwappableChannel"} ) - def _build_grpc_channel(self, *args, **kwargs) -> _AsyncReplaceableChannel: + def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel: if self._emulator_host is not None: # emulators use insecure channel create_channel_fn = partial(insecure_channel, self._emulator_host) else: create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) - return _AsyncReplaceableChannel(create_channel_fn) + return AsyncSwappableChannel(create_channel_fn) @staticmethod def _client_version() -> str: @@ -357,7 +357,7 @@ def _invalidate_channel_stubs(self): self.transport._prep_wrapped_messages(self.client_info) @CrossSync.convert( - replace_symbols={"_AsyncReplaceableChannel": "_ReplaceableChannel"} + replace_symbols={"AsyncSwappableChannel": "SwappableChannel"} ) async def _manage_channel( self, @@ -383,10 +383,10 @@ async def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ - if not isinstance(self.transport.grpc_channel, _AsyncReplaceableChannel): + if not isinstance(self.transport.grpc_channel, AsyncSwappableChannel): warnings.warn("Channel does not support auto-refresh.") return - super_channel: _AsyncReplaceableChannel = self.transport.grpc_channel + super_channel: AsyncSwappableChannel = self.transport.grpc_channel first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -409,7 +409,7 @@ async def _manage_channel( new_channel = super_channel.create_channel() await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure - old_channel = super_channel.replace_wrapped_channel(new_channel) + old_channel = super_channel.swap_channel(new_channel) self._invalidate_channel_stubs() # give old_channel a chance to complete existing rpcs if CrossSync.is_async: diff --git 
a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py similarity index 95% rename from google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py rename to google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py index 473a19710..306891776 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_replaceable_channel.py +++ b/google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py @@ -71,7 +71,7 @@ def unsubscribe(self, callback): return self._channel.unsubscribe(callback) -class _ReplaceableChannel(_WrappedChannel): +class SwappableChannel(_WrappedChannel): def __init__(self, channel_fn: Callable[[], Channel]): self._channel_fn = channel_fn self._channel = channel_fn() @@ -80,7 +80,7 @@ def create_channel(self) -> Channel: new_channel = self._channel_fn() return new_channel - def replace_wrapped_channel(self, new_channel: Channel) -> Channel: + def swap_channel(self, new_channel: Channel) -> Channel: old_channel = self._channel self._channel = new_channel return old_channel diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index decf91832..ec13406ee 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -80,9 +80,7 @@ ) from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE -from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( - _ReplaceableChannel, -) +from google.cloud.bigtable.data._sync_autogen._swappable_channel import SwappableChannel if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -180,12 +178,12 @@ def __init__( stacklevel=2, ) - def _build_grpc_channel(self, *args, **kwargs) -> _ReplaceableChannel: + def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel: if self._emulator_host is not None: create_channel_fn = partial(insecure_channel, self._emulator_host) else: create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) - return _ReplaceableChannel(create_channel_fn) + return SwappableChannel(create_channel_fn) @staticmethod def _client_version() -> str: @@ -290,10 +288,10 @@ def _manage_channel( between `refresh_interval_min` and `refresh_interval_max` grace_period: time to allow previous channel to serve existing requests before closing, in seconds""" - if not isinstance(self.transport.grpc_channel, _ReplaceableChannel): + if not isinstance(self.transport.grpc_channel, SwappableChannel): warnings.warn("Channel does not support auto-refresh.") return - super_channel: _ReplaceableChannel = self.transport.grpc_channel + super_channel: SwappableChannel = self.transport.grpc_channel first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -309,7 +307,7 @@ def _manage_channel( start_timestamp = time.monotonic() new_channel = super_channel.create_channel() self._ping_and_warm_instances(channel=new_channel) - old_channel = super_channel.replace_wrapped_channel(new_channel) + old_channel = super_channel.swap_channel(new_channel) self._invalidate_channel_stubs() if grace_period: self._is_closed.wait(grace_period) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 8a57614f3..be3d149d0 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py 
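For orientation before the test updates that follow: the renamed surface is meant to be driven roughly as sketched below. `make_channel` is a hypothetical factory standing in for the partial over `TransportType.create_channel`; only `SwappableChannel`, `create_channel`, and `swap_channel` come from the module renamed above:

    import grpc

    from google.cloud.bigtable.data._sync_autogen._swappable_channel import (
        SwappableChannel,
    )

    def make_channel():
        # hypothetical stand-in for the transport's channel factory
        return grpc.insecure_channel("localhost:8080")

    wrapper = SwappableChannel(make_channel)  # builds and wraps the first channel
    fresh = wrapper.create_channel()          # second channel from the same factory
    old = wrapper.swap_channel(fresh)         # wrapper now forwards rpcs to fresh
    old.close()                               # retire the previous channel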
@@ -51,21 +51,21 @@ if CrossSync.is_async: from google.api_core import grpc_helpers_async from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async._replaceable_channel import ( - _AsyncReplaceableChannel, + from google.cloud.bigtable.data._async._swappable_channel import ( + AsyncSwappableChannel, ) CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) - CrossSync.add_mapping("ReplaceableChannel", _AsyncReplaceableChannel) + CrossSync.add_mapping("SwappableChannel", AsyncSwappableChannel) else: from google.api_core import grpc_helpers from google.cloud.bigtable.data._sync_autogen.client import Table # noqa: F401 - from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( - _ReplaceableChannel, + from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( + SwappableChannel, ) CrossSync.add_mapping("grpc_helpers", grpc_helpers) - CrossSync.add_mapping("ReplaceableChannel", _ReplaceableChannel) + CrossSync.add_mapping("SwappableChannel", SwappableChannel) __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_client" @@ -236,7 +236,7 @@ async def test__start_background_channel_refresh(self): client, "_ping_and_warm_instances", CrossSync.Mock() ) as ping_and_warm: client._emulator_host = None - client.transport._grpc_channel = CrossSync.ReplaceableChannel(mock.Mock) + client.transport._grpc_channel = CrossSync.SwappableChannel(mock.Mock) client._start_background_channel_refresh() assert client._channel_refresh_task is not None assert isinstance(client._channel_refresh_task, CrossSync.Task) @@ -500,7 +500,7 @@ async def test__manage_channel_refresh(self, num_cycles): new_channel = grpc_lib.insecure_channel("localhost:8080") create_channel_mock = mock.Mock() create_channel_mock.return_value = new_channel - refreshable_channel = CrossSync.ReplaceableChannel(create_channel_mock) + refreshable_channel = CrossSync.SwappableChannel(create_channel_mock) with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index f5cf34ab6..36146d6ee 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -45,12 +45,10 @@ str_val, ) from google.api_core import grpc_helpers -from google.cloud.bigtable.data._sync_autogen._replaceable_channel import ( - _ReplaceableChannel, -) +from google.cloud.bigtable.data._sync_autogen._swappable_channel import SwappableChannel CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) -CrossSync._Sync_Impl.add_mapping("ReplaceableChannel", _ReplaceableChannel) +CrossSync._Sync_Impl.add_mapping("SwappableChannel", SwappableChannel) @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") @@ -185,7 +183,7 @@ def test__start_background_channel_refresh(self): client, "_ping_and_warm_instances", CrossSync._Sync_Impl.Mock() ) as ping_and_warm: client._emulator_host = None - client.transport._grpc_channel = CrossSync._Sync_Impl.ReplaceableChannel( + client.transport._grpc_channel = CrossSync._Sync_Impl.SwappableChannel( mock.Mock ) client._start_background_channel_refresh() @@ -398,9 +396,7 @@ def test__manage_channel_refresh(self, num_cycles): new_channel = grpc_lib.insecure_channel("localhost:8080") create_channel_mock = mock.Mock() create_channel_mock.return_value = new_channel - refreshable_channel = CrossSync._Sync_Impl.ReplaceableChannel( - 
create_channel_mock - ) + refreshable_channel = CrossSync._Sync_Impl.SwappableChannel(create_channel_mock) with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] client = self._make_client(project="project-id") From 4b39bc5b005faffb4940b591fcb6963ba18b43c0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 15:39:58 -0700 Subject: [PATCH 13/60] added tests --- google/cloud/bigtable/data/_async/client.py | 8 +- tests/system/data/test_system_async.py | 18 +++ tests/system/data/test_system_autogen.py | 6 + .../data/_async/test__swappable_channel.py | 135 ++++++++++++++++++ .../_sync_autogen/test__swappable_channel.py | 100 +++++++++++++ 5 files changed, 261 insertions(+), 6 deletions(-) create mode 100644 tests/unit/data/_async/test__swappable_channel.py create mode 100644 tests/unit/data/_sync_autogen/test__swappable_channel.py diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 674680790..e80c37d7d 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -243,9 +243,7 @@ def __init__( stacklevel=2, ) - @CrossSync.convert( - replace_symbols={"AsyncSwappableChannel": "SwappableChannel"} - ) + @CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"}) def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel: if self._emulator_host is not None: # emulators use insecure channel @@ -356,9 +354,7 @@ def _invalidate_channel_stubs(self): self.transport._stubs = {} self.transport._prep_wrapped_messages(self.client_info) - @CrossSync.convert( - replace_symbols={"AsyncSwappableChannel": "SwappableChannel"} - ) + @CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"}) async def _manage_channel( self, refresh_interval_min: float = 60 * 35, diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 9f4fa7abb..b4350d0ab 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -29,6 +29,14 @@ from . import TEST_FAMILY, TEST_FAMILY_2 +if CrossSync.is_async: + from google.cloud.bigtable_v2.services.bigtable.transports.grpc_asyncio import ( + _LoggingClientAIOInterceptor as GapicInterceptor, + ) +else: + from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( + _LoggingClientInterceptor as GapicInterceptor, + ) __CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system_autogen" @@ -260,6 +268,16 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows): assert len(rows_after_refresh) == 2 assert client.transport.grpc_channel is channel_wrapper assert client.transport.grpc_channel._channel is not first_channel + # ensure gapic's logging interceptor is still active + if CrossSync.is_async: + interceptors = ( + client.transport.grpc_channel._channel._unary_unary_interceptors + ) + assert GapicInterceptor in [type(i) for i in interceptors] + else: + assert isinstance( + client.transport._logged_channel._interceptor, GapicInterceptor + ) finally: await client.close() diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 7d99805b7..ad77c7d3e 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -27,6 +27,9 @@ from google.type import date_pb2 from google.cloud.bigtable.data._cross_sync import CrossSync from . 
import TEST_FAMILY, TEST_FAMILY_2 +from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( + _LoggingClientInterceptor as GapicInterceptor, +) TARGETS = ["table"] if not os.environ.get(BIGTABLE_EMULATOR): @@ -209,6 +212,9 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows): assert len(rows_after_refresh) == 2 assert client.transport.grpc_channel is channel_wrapper assert client.transport.grpc_channel._channel is not first_channel + assert isinstance( + client.transport._logged_channel._interceptor, GapicInterceptor + ) finally: client.close() diff --git a/tests/unit/data/_async/test__swappable_channel.py b/tests/unit/data/_async/test__swappable_channel.py new file mode 100644 index 000000000..14fef2c85 --- /dev/null +++ b/tests/unit/data/_async/test__swappable_channel.py @@ -0,0 +1,135 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore + +import pytest +from grpc import ChannelConnectivity + +from google.cloud.bigtable.data._cross_sync import CrossSync + +if CrossSync.is_async: + from google.cloud.bigtable.data._async._swappable_channel import ( + AsyncSwappableChannel as TargetType, + ) +else: + from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( + SwappableChannel as TargetType, + ) + + +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test__swappable_channel" + + +@CrossSync.convert_class(sync_name="TestSwappableChannel") +class TestAsyncSwappableChannel: + @staticmethod + @CrossSync.convert + def _get_target_class(): + return TargetType + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + assert instance._channel_fn == channel_fn + channel_fn.assert_called_once_with() + assert instance._channel == channel_fn.return_value + + def test_swap_channel(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + old_channel = instance._channel + new_channel = object() + result = instance.swap_channel(new_channel) + assert result == old_channel + assert instance._channel == new_channel + + def test_create_channel(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + # reset mock from ctor call + channel_fn.reset_mock() + new_channel = instance.create_channel() + channel_fn.assert_called_once_with() + assert new_channel == channel_fn.return_value + + @CrossSync.drop + def test_create_channel_async_interceptors_copied(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + # reset mock from ctor call + channel_fn.reset_mock() + # mock out interceptors on original channel + instance._channel._unary_unary_interceptors = ["unary_unary"] + instance._channel._unary_stream_interceptors = ["unary_stream"] + instance._channel._stream_unary_interceptors = ["stream_unary"] + 
instance._channel._stream_stream_interceptors = ["stream_stream"] + + new_channel = instance.create_channel() + channel_fn.assert_called_once_with() + assert new_channel == channel_fn.return_value + assert new_channel._unary_unary_interceptors == ["unary_unary"] + assert new_channel._unary_stream_interceptors == ["unary_stream"] + assert new_channel._stream_unary_interceptors == ["stream_unary"] + assert new_channel._stream_stream_interceptors == ["stream_stream"] + + @pytest.mark.parametrize( + "method_name,args,kwargs", + [ + ("unary_unary", (1,), {"kw": 2}), + ("unary_stream", (3,), {"kw": 4}), + ("stream_unary", (5,), {"kw": 6}), + ("stream_stream", (7,), {"kw": 8}), + ("get_state", (), {"try_to_connect": True}), + ], + ) + def test_forwarded_methods(self, method_name, args, kwargs): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + method = getattr(instance, method_name) + result = method(*args, **kwargs) + mock_method = getattr(channel_fn.return_value, method_name) + mock_method.assert_called_once_with(*args, **kwargs) + assert result == mock_method.return_value + + @pytest.mark.parametrize( + "method_name,args,kwargs", + [ + ("channel_ready", (), {}), + ("wait_for_state_change", (ChannelConnectivity.READY,), {}), + ], + ) + @CrossSync.pytest + async def test_forwarded_async_methods(self, method_name, args, kwargs): + async def dummy_coro(*a, **k): + return mock.sentinel.result + + channel = mock.Mock() + mock_method = getattr(channel, method_name) + mock_method.side_effect = dummy_coro + + channel_fn = mock.Mock(return_value=channel) + instance = self._make_one(channel_fn) + method = getattr(instance, method_name) + result = await method(*args, **kwargs) + + mock_method.assert_called_once_with(*args, **kwargs) + assert result == mock.sentinel.result diff --git a/tests/unit/data/_sync_autogen/test__swappable_channel.py b/tests/unit/data/_sync_autogen/test__swappable_channel.py new file mode 100644 index 000000000..04f3f61c8 --- /dev/null +++ b/tests/unit/data/_sync_autogen/test__swappable_channel.py @@ -0,0 +1,100 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# try/except added for compatibility with python < 3.8 + +# This file is automatically generated by CrossSync. Do not edit manually. 
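A note on why test_create_channel_async_interceptors_copied above is async-only: grpc.aio channels receive client interceptors at construction time and keep them on internal per-kind lists, so a replacement channel built later must have them carried over, which is what AsyncSwappableChannel.create_channel does. A minimal illustration of the construction-time attachment, with a no-op interceptor as a stand-in:

    from grpc import aio

    class PassthroughInterceptor(aio.UnaryUnaryClientInterceptor):
        async def intercept_unary_unary(self, continuation, client_call_details, request):
            # forward the rpc unchanged
            return await continuation(client_call_details, request)

    channel = aio.insecure_channel(
        "localhost:8080", interceptors=[PassthroughInterceptor()]
    )

Sync channels instead layer interceptors by wrapping with grpc.intercept_channel, which is why no copy step is needed on that side.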
+ +try: + from unittest import mock +except ImportError: + import mock +import pytest +from grpc import ChannelConnectivity +from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( + SwappableChannel as TargetType, +) + + +class TestSwappableChannel: + @staticmethod + def _get_target_class(): + return TargetType + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + assert instance._channel_fn == channel_fn + channel_fn.assert_called_once_with() + assert instance._channel == channel_fn.return_value + + def test_swap_channel(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + old_channel = instance._channel + new_channel = object() + result = instance.swap_channel(new_channel) + assert result == old_channel + assert instance._channel == new_channel + + def test_create_channel(self): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + channel_fn.reset_mock() + new_channel = instance.create_channel() + channel_fn.assert_called_once_with() + assert new_channel == channel_fn.return_value + + @pytest.mark.parametrize( + "method_name,args,kwargs", + [ + ("unary_unary", (1,), {"kw": 2}), + ("unary_stream", (3,), {"kw": 4}), + ("stream_unary", (5,), {"kw": 6}), + ("stream_stream", (7,), {"kw": 8}), + ("get_state", (), {"try_to_connect": True}), + ], + ) + def test_forwarded_methods(self, method_name, args, kwargs): + channel_fn = mock.Mock() + instance = self._make_one(channel_fn) + method = getattr(instance, method_name) + result = method(*args, **kwargs) + mock_method = getattr(channel_fn.return_value, method_name) + mock_method.assert_called_once_with(*args, **kwargs) + assert result == mock_method.return_value + + @pytest.mark.parametrize( + "method_name,args,kwargs", + [ + ("channel_ready", (), {}), + ("wait_for_state_change", (ChannelConnectivity.READY,), {}), + ], + ) + def test_forwarded_async_methods(self, method_name, args, kwargs): + def dummy_coro(*a, **k): + return mock.sentinel.result + + channel = mock.Mock() + mock_method = getattr(channel, method_name) + mock_method.side_effect = dummy_coro + channel_fn = mock.Mock(return_value=channel) + instance = self._make_one(channel_fn) + method = getattr(instance, method_name) + result = method(*args, **kwargs) + mock_method.assert_called_once_with(*args, **kwargs) + assert result == mock.sentinel.result From 3f090c2896c33cbf94c0b21f8d1179be492e8ea9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Jul 2025 15:50:12 -0700 Subject: [PATCH 14/60] added docstrings --- .../bigtable/data/_async/_swappable_channel.py | 13 +++++++++++++ google/cloud/bigtable/data/_async/client.py | 14 ++++++++++++++ .../data/_sync_autogen/_swappable_channel.py | 10 ++++++++++ google/cloud/bigtable/data/_sync_autogen/client.py | 12 ++++++++++++ 4 files changed, 49 insertions(+) diff --git a/google/cloud/bigtable/data/_async/_swappable_channel.py b/google/cloud/bigtable/data/_async/_swappable_channel.py index 97fb855dd..e590e1234 100644 --- a/google/cloud/bigtable/data/_async/_swappable_channel.py +++ b/google/cloud/bigtable/data/_async/_swappable_channel.py @@ -95,11 +95,21 @@ def unsubscribe(self, callback): replace_symbols={"_AsyncWrappedChannel": "_WrappedChannel"}, ) class AsyncSwappableChannel(_AsyncWrappedChannel): + """ + Provides a grpc channel wrapper, that allows the internal channel to be swapped out + + Args: + - channel_fn: a nullary function that returns a 
new channel instance.
+            It should be a partial with all channel configuration arguments built-in
+    """
+
     def __init__(self, channel_fn: Callable[[], Channel]):
         self._channel_fn = channel_fn
         self._channel = channel_fn()
 
     def create_channel(self) -> Channel:
+        """
+        Create a fresh channel using the stored `channel_fn` partial
+        """
         new_channel = self._channel_fn()
         if CrossSync.is_async:
             # copy over interceptors
@@ -120,6 +130,9 @@ def create_channel(self) -> Channel:
         return new_channel
 
     def swap_channel(self, new_channel: Channel) -> Channel:
+        """
+        Replace the wrapped channel with a new instance. Typically created using `create_channel`
+        """
         old_channel = self._channel
         self._channel = new_channel
         return old_channel
diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index e80c37d7d..34c0c9892 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -245,6 +245,20 @@ def __init__(
 
     @CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
     def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
+        """
+        This method is called by the gapic transport to create a grpc channel.
+
+        The init arguments passed down are captured in a partial used by AsyncSwappableChannel
+        to create new channel instances in the future, as part of the channel refresh logic
+
+        Emulators always use an insecure channel
+
+        Args:
+            - *args: positional arguments passed by the gapic layer to create a new channel with
+            - **kwargs: keyword arguments passed by the gapic layer to create a new channel with
+        Returns:
+            a custom wrapped swappable channel
+        """
         if self._emulator_host is not None:
             # emulators use insecure channel
             create_channel_fn = partial(insecure_channel, self._emulator_host)
diff --git a/google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py b/google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py
index 306891776..78ba129d9 100644
--- a/google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py
+++ b/google/cloud/bigtable/data/_sync_autogen/_swappable_channel.py
@@ -72,15 +72,25 @@ def unsubscribe(self, callback):
 
 
 class SwappableChannel(_WrappedChannel):
+    """
+    Provides a grpc channel wrapper, that allows the internal channel to be swapped out
+
+    Args:
+        - channel_fn: a nullary function that returns a new channel instance.
+            It should be a partial with all channel configuration arguments built-in
+    """
+
     def __init__(self, channel_fn: Callable[[], Channel]):
         self._channel_fn = channel_fn
         self._channel = channel_fn()
 
     def create_channel(self) -> Channel:
+        """Create a fresh channel using the stored `channel_fn` partial"""
         new_channel = self._channel_fn()
         return new_channel
 
     def swap_channel(self, new_channel: Channel) -> Channel:
+        """Replace the wrapped channel with a new instance. Typically created using `create_channel`"""
         old_channel = self._channel
         self._channel = new_channel
         return old_channel
diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py
index ec13406ee..dc211e195 100644
--- a/google/cloud/bigtable/data/_sync_autogen/client.py
+++ b/google/cloud/bigtable/data/_sync_autogen/client.py
@@ -179,6 +179,18 @@ def __init__(
         )
 
     def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel:
+        """This method is called by the gapic transport to create a grpc channel.
+
+        The init arguments passed down are captured in a partial used by SwappableChannel
+        to create new channel instances in the future, as part of the channel refresh logic
+
+        Emulators always use an insecure channel
+
+        Args:
+            - *args: positional arguments passed by the gapic layer to create a new channel with
+            - **kwargs: keyword arguments passed by the gapic layer to create a new channel with
+        Returns:
+            a custom wrapped swappable channel"""
         if self._emulator_host is not None:
             create_channel_fn = partial(insecure_channel, self._emulator_host)
         else:

From 04c762a05f1d5067225bd4450fc0c9547e4fb149 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 29 Jul 2025 16:00:17 -0700
Subject: [PATCH 15/60] initial commit

---
 .../cloud/bigtable/data/_metrics/__init__.py  |  25 +
 .../bigtable/data/_metrics/data_model.py      | 446 +++++++++
 .../bigtable/data/_metrics/handlers/_base.py  |  35 +
 .../data/_metrics/metrics_controller.py       |  72 ++
 tests/unit/data/_metrics/__init__.py          |   0
 tests/unit/data/_metrics/test_data_model.py   | 880 ++++++++++++++++++
 .../data/_metrics/test_metrics_controller.py  |  98 ++
 7 files changed, 1556 insertions(+)
 create mode 100644 google/cloud/bigtable/data/_metrics/__init__.py
 create mode 100644 google/cloud/bigtable/data/_metrics/data_model.py
 create mode 100644 google/cloud/bigtable/data/_metrics/handlers/_base.py
 create mode 100644 google/cloud/bigtable/data/_metrics/metrics_controller.py
 create mode 100644 tests/unit/data/_metrics/__init__.py
 create mode 100644 tests/unit/data/_metrics/test_data_model.py
 create mode 100644 tests/unit/data/_metrics/test_metrics_controller.py
diff --git a/google/cloud/bigtable/data/_metrics/__init__.py b/google/cloud/bigtable/data/_metrics/__init__.py
new file mode 100644
index 000000000..43b8b6139
--- /dev/null
+++ b/google/cloud/bigtable/data/_metrics/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from google.cloud.bigtable.data._metrics.metrics_controller import (
+    BigtableClientSideMetricsController,
+)
+
+from google.cloud.bigtable.data._metrics.data_model import OperationType
+from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+
+__all__ = (
+    "BigtableClientSideMetricsController",
+    "OperationType",
+    "ActiveOperationMetric",
+)
diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py
new file mode 100644
index 000000000..b48686e10
--- /dev/null
+++ b/google/cloud/bigtable/data/_metrics/data_model.py
@@ -0,0 +1,446 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Callable, Any, Tuple, cast, TYPE_CHECKING
+
+import time
+import re
+import logging
+import uuid
+
+from enum import Enum
+from functools import lru_cache
+from dataclasses import dataclass
+from dataclasses import field
+from grpc import StatusCode
+
+import google.cloud.bigtable.data.exceptions as bt_exceptions
+from google.cloud.bigtable_v2.types.response_params import ResponseParams
+from google.protobuf.message import DecodeError
+
+if TYPE_CHECKING:
+    from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler
+    from google.cloud.bigtable.data._helpers import BackoffGenerator
+
+
+LOGGER = logging.getLogger(__name__)
+
+# default values for zone and cluster data, if not captured
+DEFAULT_ZONE = "global"
+DEFAULT_CLUSTER_ID = "unspecified"
+
+# keys for parsing metadata blobs
+BIGTABLE_METADATA_KEY = "x-goog-ext-425905942-bin"
+SERVER_TIMING_METADATA_KEY = "server-timing"
+SERVER_TIMING_REGEX = re.compile(r".*gfet4t7;\s*dur=(\d+\.?\d*).*")
+
+INVALID_STATE_ERROR = "Invalid state for {}: {}"
+
+OPERATION_INTERCEPTOR_METADATA_KEY = 'x-goog-operation-key'
+
+
+class OperationType(Enum):
+    """Enum for the type of operation being performed."""
+
+    READ_ROWS = "ReadRows"
+    SAMPLE_ROW_KEYS = "SampleRowKeys"
+    BULK_MUTATE_ROWS = "MutateRows"
+    MUTATE_ROW = "MutateRow"
+    CHECK_AND_MUTATE = "CheckAndMutateRow"
+    READ_MODIFY_WRITE = "ReadModifyWriteRow"
+
+
+class OperationState(Enum):
+    """Enum for the state of the active operation."""
+
+    CREATED = 0
+    ACTIVE_ATTEMPT = 1
+    BETWEEN_ATTEMPTS = 2
+    COMPLETED = 3
+
+
+@dataclass(frozen=True)
+class CompletedAttemptMetric:
+    """
+    An immutable dataclass representing the data associated with a
+    completed rpc attempt.
+
+    Operation-level fields (eg. type, cluster, zone) are stored on the
+    corresponding CompletedOperationMetric or ActiveOperationMetric object.
+    """
+
+    duration_ns: int
+    end_status: StatusCode
+    first_response_latency_ns: int | None = None
+    gfe_latency_ns: int | None = None
+    application_blocking_time_ns: int = 0
+    backoff_before_attempt_ns: int = 0
+    grpc_throttling_time_ns: int = 0
+
+
+@dataclass(frozen=True)
+class CompletedOperationMetric:
+    """
+    An immutable dataclass representing the data associated with a
+    completed rpc operation.
+
+    Attempt-level fields (eg. duration, latencies, etc) are stored on the
+    corresponding CompletedAttemptMetric object.
+    """
+
+    op_type: OperationType
+    uuid: str
+    duration_ns: int
+    completed_attempts: list[CompletedAttemptMetric]
+    final_status: StatusCode
+    cluster_id: str
+    zone: str
+    is_streaming: bool
+    flow_throttling_time_ns: int = 0
+
+
+@dataclass
+class ActiveAttemptMetric:
+    """
+    A dataclass representing the data associated with an rpc attempt that is
+    currently in progress. Fields are mutable and may be optional.
+    """
+
+    # keep monotonic timestamps for active attempts
+    start_time_ns: int = field(default_factory=time.monotonic_ns)
+    # the time it takes to receive the first response from the server, in nanoseconds
+    # currently only tracked for ReadRows
+    first_response_latency_ns: int | None = None
+    # the time taken by the backend, in nanoseconds.
Taken from response header
+    gfe_latency_ns: int | None = None
+    # time waiting on user to process the response, in nanoseconds
+    # currently only relevant for ReadRows
+    application_blocking_time_ns: int = 0
+    # backoff time is added to application_blocking_time_ns
+    backoff_before_attempt_ns: int = 0
+    # time waiting on grpc channel, in nanoseconds
+    # TODO: capture grpc_throttling_time
+    grpc_throttling_time_ns: int = 0
+
+
+@dataclass
+class ActiveOperationMetric:
+    """
+    A dataclass representing the data associated with an rpc operation that is
+    currently in progress. Fields are mutable and may be optional.
+    """
+
+    op_type: OperationType
+    # use a default factory, so that each operation instance gets its own unique id
+    uuid: str = field(default_factory=lambda: str(uuid.uuid4()))
+    backoff_generator: BackoffGenerator | None = None
+    # keep monotonic timestamps for active operations
+    start_time_ns: int = field(default_factory=time.monotonic_ns)
+    active_attempt: ActiveAttemptMetric | None = None
+    cluster_id: str | None = None
+    zone: str | None = None
+    completed_attempts: list[CompletedAttemptMetric] = field(default_factory=list)
+    is_streaming: bool = False  # only True for read_rows operations
+    was_completed: bool = False
+    handlers: list[MetricsHandler] = field(default_factory=list)
+    # time waiting on flow control, in nanoseconds
+    flow_throttling_time_ns: int = 0
+
+    @property
+    def interceptor_metadata(self) -> tuple[str, str]:
+        return OPERATION_INTERCEPTOR_METADATA_KEY, self.uuid
+
+    @property
+    def state(self) -> OperationState:
+        if self.was_completed:
+            return OperationState.COMPLETED
+        elif self.active_attempt is None:
+            if self.completed_attempts:
+                return OperationState.BETWEEN_ATTEMPTS
+            else:
+                return OperationState.CREATED
+        else:
+            return OperationState.ACTIVE_ATTEMPT
+
+    def start(self) -> None:
+        """
+        Optionally called to mark the start of the operation. If not called,
+        the operation will be started at initialization.
+
+        Assumes operation is in CREATED state.
+        """
+        if self.state != OperationState.CREATED:
+            return self._handle_error(INVALID_STATE_ERROR.format("start", self.state))
+        self.start_time_ns = time.monotonic_ns()
+
+    def start_attempt(self) -> None:
+        """
+        Called to initiate a new attempt for the operation.
+
+        Assumes operation is in either CREATED or BETWEEN_ATTEMPTS states
+        """
+        if (
+            self.state != OperationState.BETWEEN_ATTEMPTS
+            and self.state != OperationState.CREATED
+        ):
+            return self._handle_error(
+                INVALID_STATE_ERROR.format("start_attempt", self.state)
+            )
+
+        # find backoff value
+        if self.backoff_generator and len(self.completed_attempts) > 0:
+            # find the attempt's backoff by sending attempt number to generator
+            # generator will return the backoff time in seconds, so convert to nanoseconds
+            backoff = self.backoff_generator.get_attempt_backoff(
+                len(self.completed_attempts) - 1
+            )
+            backoff_ns = int(backoff * 1e9)
+        else:
+            backoff_ns = 0
+
+        self.active_attempt = ActiveAttemptMetric(backoff_before_attempt_ns=backoff_ns)
+
+    def add_response_metadata(self, metadata: dict[str, bytes | str]) -> None:
+        """
+        Attach trailing metadata to the active attempt.
+
+        If not called, default values for the metadata will be used.
+
+        Assumes operation is in ACTIVE_ATTEMPT state.
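+
+        Two keys are inspected: BIGTABLE_METADATA_KEY, which carries a
+        binary-encoded ResponseParams proto with cluster and zone info, and
+        SERVER_TIMING_METADATA_KEY, which carries the gfe latency as a
+        server-timing header (eg. "gfet4t7; dur=123.4").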
+ + Args: + - metadata: the metadata as extracted from the grpc call + """ + if self.state != OperationState.ACTIVE_ATTEMPT: + return self._handle_error( + INVALID_STATE_ERROR.format("add_response_metadata", self.state) + ) + if self.cluster_id is None or self.zone is None: + # BIGTABLE_METADATA_KEY should give a binary-encoded ResponseParams proto + blob = cast(bytes, metadata.get(BIGTABLE_METADATA_KEY)) + if blob: + parse_result = self._parse_response_metadata_blob(blob) + if parse_result is not None: + cluster, zone = parse_result + if cluster: + self.cluster_id = cluster + if zone: + self.zone = zone + else: + self._handle_error( + f"Failed to decode {BIGTABLE_METADATA_KEY} metadata: {blob!r}" + ) + # SERVER_TIMING_METADATA_KEY should give a string with the server-latency headers + timing_header = cast(str, metadata.get(SERVER_TIMING_METADATA_KEY)) + if timing_header: + timing_data = SERVER_TIMING_REGEX.match(timing_header) + if timing_data and self.active_attempt: + gfe_latency_ms = float(timing_data.group(1)) + self.active_attempt.gfe_latency_ns = int(gfe_latency_ms * 1e6) + + @staticmethod + @lru_cache(maxsize=32) + def _parse_response_metadata_blob(blob: bytes) -> Tuple[str, str] | None: + """ + Parse the response metadata blob and return a tuple of cluster and zone. + + Function is cached to avoid parsing the same blob multiple times. + + Args: + - blob: the metadata blob as extracted from the grpc call + Returns: + - a tuple of cluster_id and zone, or None if parsing failed + """ + try: + proto = ResponseParams.pb().FromString(blob) + return proto.cluster_id, proto.zone_id + except (DecodeError, TypeError): + # failed to parse metadata + return None + + def attempt_first_response(self) -> None: + """ + Called to mark the timestamp of the first completed response for the + active attempt. + + Assumes operation is in ACTIVE_ATTEMPT state. + """ + if self.state != OperationState.ACTIVE_ATTEMPT or self.active_attempt is None: + return self._handle_error( + INVALID_STATE_ERROR.format("attempt_first_response", self.state) + ) + if self.active_attempt.first_response_latency_ns is not None: + return self._handle_error("Attempt already received first response") + self.active_attempt.first_response_latency_ns = ( + time.monotonic_ns() - self.active_attempt.start_time_ns + ) + + def end_attempt_with_status(self, status: StatusCode | Exception) -> None: + """ + Called to mark the end of an attempt for the operation. + + Typically, this is used to mark a retryable error. If a retry will not + be attempted, `end_with_status` or `end_with_success` should be used + to finalize the operation along with the attempt. + + Assumes operation is in ACTIVE_ATTEMPT state. + + Args: + - status: The status of the attempt. 
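+              May be a grpc StatusCode, or an Exception, which will be
+              converted to a status code via _exc_to_status.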
+        """
+        if self.state != OperationState.ACTIVE_ATTEMPT or self.active_attempt is None:
+            return self._handle_error(
+                INVALID_STATE_ERROR.format("end_attempt_with_status", self.state)
+            )
+        if isinstance(status, Exception):
+            status = self._exc_to_status(status)
+        complete_attempt = CompletedAttemptMetric(
+            first_response_latency_ns=self.active_attempt.first_response_latency_ns,
+            duration_ns=time.monotonic_ns() - self.active_attempt.start_time_ns,
+            end_status=status,
+            gfe_latency_ns=self.active_attempt.gfe_latency_ns,
+            application_blocking_time_ns=self.active_attempt.application_blocking_time_ns,
+            backoff_before_attempt_ns=self.active_attempt.backoff_before_attempt_ns,
+            grpc_throttling_time_ns=self.active_attempt.grpc_throttling_time_ns,
+        )
+        self.completed_attempts.append(complete_attempt)
+        self.active_attempt = None
+        for handler in self.handlers:
+            handler.on_attempt_complete(complete_attempt, self)
+
+    def end_with_status(self, status: StatusCode | Exception) -> None:
+        """
+        Called to mark the end of the operation. If there is an active attempt,
+        end_attempt_with_status will be called with the same status.
+
+        Assumes operation is not already in COMPLETED state.
+
+        Causes on_operation_completed to be called for each registered handler.
+
+        Args:
+            - status: The status of the operation.
+        """
+        if self.state == OperationState.COMPLETED:
+            return self._handle_error(
+                INVALID_STATE_ERROR.format("end_with_status", self.state)
+            )
+        final_status = (
+            self._exc_to_status(status) if isinstance(status, Exception) else status
+        )
+        if self.state == OperationState.ACTIVE_ATTEMPT:
+            self.end_attempt_with_status(final_status)
+        self.was_completed = True
+        finalized = CompletedOperationMetric(
+            op_type=self.op_type,
+            uuid=self.uuid,
+            completed_attempts=self.completed_attempts,
+            duration_ns=time.monotonic_ns() - self.start_time_ns,
+            final_status=final_status,
+            cluster_id=self.cluster_id or DEFAULT_CLUSTER_ID,
+            zone=self.zone or DEFAULT_ZONE,
+            is_streaming=self.is_streaming,
+            flow_throttling_time_ns=self.flow_throttling_time_ns,
+        )
+        for handler in self.handlers:
+            handler.on_operation_complete(finalized)
+
+    def end_with_success(self):
+        """
+        Called to mark the end of the operation with a successful status.
+
+        Assumes operation is not already in COMPLETED state.
+
+        Causes on_operation_completed to be called for each registered handler.
+        """
+        return self.end_with_status(StatusCode.OK)
+
+    def build_wrapped_predicate(
+        self, inner_predicate: Callable[[Exception], bool]
+    ) -> Callable[[Exception], bool]:
+        """
+        Wraps a predicate to include metrics tracking. Any call to the resulting predicate
+        is assumed to be an rpc failure, and will either mark the end of the active attempt
+        or the end of the operation.
+
+        Args:
+            - inner_predicate: The predicate to wrap.
+        """
+
+        def wrapped_predicate(exc: Exception) -> bool:
+            inner_result = inner_predicate(exc)
+            if inner_result:
+                self.end_attempt_with_status(exc)
+            else:
+                self.end_with_status(exc)
+            return inner_result
+
+        return wrapped_predicate
+
+    @staticmethod
+    def _exc_to_status(exc: Exception) -> StatusCode:
+        """
+        Extracts the grpc status code from an exception.
+
+        Exception groups and wrappers will be parsed to find the underlying
+        grpc Exception.
+
+        If the exception is not a grpc exception, will return StatusCode.UNKNOWN.
+
+        Args:
+            - exc: The exception to extract the status code from.
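+        Returns:
+            - the grpc status code extracted from the exception or its cause,
+              or StatusCode.UNKNOWN if none is found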
+        """
+        if isinstance(exc, bt_exceptions._BigtableExceptionGroup):
+            exc = exc.exceptions[-1]
+        if hasattr(exc, "grpc_status_code") and exc.grpc_status_code is not None:
+            return exc.grpc_status_code
+        if (
+            exc.__cause__
+            and hasattr(exc.__cause__, "grpc_status_code")
+            and exc.__cause__.grpc_status_code is not None
+        ):
+            return exc.__cause__.grpc_status_code
+        return StatusCode.UNKNOWN
+
+    @staticmethod
+    def _handle_error(message: str) -> None:
+        """
+        Log a metrics system error message
+
+        Args:
+            - message: The message to include in the exception or warning.
+        """
+        full_message = f"Error in Bigtable Metrics: {message}"
+        LOGGER.warning(full_message)
+
+    async def __aenter__(self):
+        """
+        Implements the async context manager protocol for wrapping unary calls
+
+        Using the operation's context manager provides assurances that the operation
+        is always closed when complete, with the proper status code automatically
+        detected when an exception is raised.
+        """
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """
+        Implements the async context manager protocol for wrapping unary calls
+
+        The operation is automatically ended on exit, with the status determined
+        by the exception type and value.
+        """
+        if exc_val is None:
+            self.end_with_success()
+        else:
+            self.end_with_status(exc_val)
\ No newline at end of file
diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py
new file mode 100644
index 000000000..72f5aa550
--- /dev/null
+++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py
@@ -0,0 +1,35 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+from google.cloud.bigtable.data._metrics.data_model import CompletedAttemptMetric
+from google.cloud.bigtable.data._metrics.data_model import CompletedOperationMetric
+
+
+class MetricsHandler:
+    """
+    Base class for all metrics handlers. Metrics handlers will receive callbacks
+    when operations and attempts are completed, and can use this information to
+    update some external metrics system.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def on_operation_complete(self, op: CompletedOperationMetric) -> None:
+        pass
+
+    def on_attempt_complete(
+        self, attempt: CompletedAttemptMetric, op: ActiveOperationMetric
+    ) -> None:
+        pass
diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py
new file mode 100644
index 000000000..52d669227
--- /dev/null
+++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py
@@ -0,0 +1,72 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric +from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler +from google.cloud.bigtable.data._metrics.data_model import OperationType + +if TYPE_CHECKING: + from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor + +class BigtableClientSideMetricsController: + """ + BigtableClientSideMetricsController is responsible for managing the + lifecycle of the metrics system. The Bigtable client library will + use this class to create new operations. Each operation will be + registered with the handlers associated with this controller. + """ + + def __init__(self, + interceptor: AsyncBigtableMetricsInterceptor | BigtableMetricsInterceptor, + handlers: list[MetricsHandler] | None = None, + **kwargs + ): + """ + Initializes the metrics controller. + + Args: + - interceptor: A metrics interceptor to use for triggering Operation lifecycle events + - handlers: A list of MetricsHandler objects to subscribe to metrics events. + - **kwargs: Optional arguments to pass to the metrics handlers. + """ + self.interceptor = interceptor + self.handlers: list[MetricsHandler] = handlers or [] + if handlers is None: + # handlers not given. Use default handlers. + # TODO: add default handlers + pass + + def add_handler(self, handler: MetricsHandler) -> None: + """ + Add a new handler to the list of handlers. + + Args: + - handler: A MetricsHandler object to add to the list of subscribed handlers. + """ + self.handlers.append(handler) + + def create_operation( + self, op_type: OperationType, **kwargs + ) -> ActiveOperationMetric: + """ + Creates a new operation and registers it with the subscribed handlers. + """ + handlers = self.handlers + kwargs.pop("handlers", []) + new_op = ActiveOperationMetric(op_type, **kwargs, handlers=handlers) + self.interceptor.register_operation(new_op) + return new_op \ No newline at end of file diff --git a/tests/unit/data/_metrics/__init__.py b/tests/unit/data/_metrics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py new file mode 100644 index 000000000..0a136075f --- /dev/null +++ b/tests/unit/data/_metrics/test_data_model.py @@ -0,0 +1,880 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import time
+import pytest
+import mock
+
+from google.cloud.bigtable.data._metrics.data_model import OperationState as State
+from google.cloud.bigtable_v2.types import ResponseParams
+
+
+class TestActiveOperationMetric:
+    def _make_one(self, *args, **kwargs):
+        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+
+        return ActiveOperationMetric(*args, **kwargs)
+
+    def test_ctor_defaults(self):
+        """
+        create an instance with default values
+        """
+        mock_type = mock.Mock()
+        metric = self._make_one(mock_type)
+        assert metric.op_type == mock_type
+        assert abs(metric.start_time_ns - time.monotonic_ns()) < 1e6  # 1ms buffer
+        assert metric.active_attempt is None
+        assert metric.cluster_id is None
+        assert metric.zone is None
+        assert len(metric.completed_attempts) == 0
+        assert metric.was_completed is False
+        assert len(metric.handlers) == 0
+        assert metric.is_streaming is False
+        assert metric.flow_throttling_time_ns == 0
+
+    def test_ctor_explicit(self):
+        """
+        test with explicit arguments
+        """
+        expected_type = mock.Mock()
+        expected_start_time_ns = 7
+        expected_active_attempt = mock.Mock()
+        expected_cluster_id = "cluster"
+        expected_zone = "zone"
+        expected_completed_attempts = [mock.Mock()]
+        expected_was_completed = True
+        expected_handlers = [mock.Mock()]
+        expected_is_streaming = True
+        expected_flow_throttling = 12
+        metric = self._make_one(
+            op_type=expected_type,
+            start_time_ns=expected_start_time_ns,
+            active_attempt=expected_active_attempt,
+            cluster_id=expected_cluster_id,
+            zone=expected_zone,
+            completed_attempts=expected_completed_attempts,
+            was_completed=expected_was_completed,
+            handlers=expected_handlers,
+            is_streaming=expected_is_streaming,
+            flow_throttling_time_ns=expected_flow_throttling,
+        )
+        assert metric.op_type == expected_type
+        assert metric.start_time_ns == expected_start_time_ns
+        assert metric.active_attempt == expected_active_attempt
+        assert metric.cluster_id == expected_cluster_id
+        assert metric.zone == expected_zone
+        assert metric.completed_attempts == expected_completed_attempts
+        assert metric.was_completed == expected_was_completed
+        assert metric.handlers == expected_handlers
+        assert metric.is_streaming == expected_is_streaming
+        assert metric.flow_throttling_time_ns == expected_flow_throttling
+
+    def test_state_machine_w_methods(self):
+        """
+        Exercise the state machine by calling methods to move between states
+        """
+        metric = self._make_one(mock.Mock())
+        assert metric.state == State.CREATED
+        metric.start()
+        assert metric.state == State.CREATED
+        metric.start_attempt()
+        assert metric.state == State.ACTIVE_ATTEMPT
+        metric.end_attempt_with_status(Exception())
+        assert metric.state == State.BETWEEN_ATTEMPTS
+        metric.start_attempt()
+        assert metric.state == State.ACTIVE_ATTEMPT
+        metric.end_with_success()
+        assert metric.state == State.COMPLETED
+
+    def test_state_machine_w_state(self):
+        """
+        Exercise state machine by directly manipulating state variables
+
+        relevant variables are: active_attempt, completed_attempts, was_completed
+        """
+        metric = self._make_one(mock.Mock())
+        for was_completed_value in [False, True]:
+            metric.was_completed = was_completed_value
+            for active_operation_value in [None, mock.Mock()]:
+                metric.active_attempt = active_operation_value
+                for completed_attempts_value in [[], [mock.Mock()]]:
+                    metric.completed_attempts = completed_attempts_value
+                    if was_completed_value:
+                        assert metric.state == State.COMPLETED
+                    elif active_operation_value is not None:
+                        assert metric.state ==
State.ACTIVE_ATTEMPT
+                    elif completed_attempts_value:
+                        assert metric.state == State.BETWEEN_ATTEMPTS
+                    else:
+                        assert metric.state == State.CREATED
+
+    @pytest.mark.parametrize(
+        "method,args,valid_states,error_method_name",
+        [
+            ("start", (), (State.CREATED,), None),
+            ("start_attempt", (), (State.CREATED, State.BETWEEN_ATTEMPTS), None),
+            ("add_response_metadata", ({},), (State.ACTIVE_ATTEMPT,), None),
+            ("attempt_first_response", (), (State.ACTIVE_ATTEMPT,), None),
+            ("end_attempt_with_status", (mock.Mock(),), (State.ACTIVE_ATTEMPT,), None),
+            (
+                "end_with_status",
+                (mock.Mock(),),
+                (
+                    State.CREATED,
+                    State.ACTIVE_ATTEMPT,
+                    State.BETWEEN_ATTEMPTS,
+                ),
+                None,
+            ),
+            (
+                "end_with_success",
+                (),
+                (
+                    State.CREATED,
+                    State.ACTIVE_ATTEMPT,
+                    State.BETWEEN_ATTEMPTS,
+                ),
+                "end_with_status",
+            ),
+        ],
+        ids=lambda x: x if isinstance(x, str) else "",
+    )
+    def test_error_invalid_states(self, method, args, valid_states, error_method_name):
+        """
+        each method only works for certain states. Make sure _handle_error is called for invalid states
+        """
+        cls = type(self._make_one(mock.Mock()))
+        invalid_states = set(State) - set(valid_states)
+        error_method_name = error_method_name or method
+        for state in invalid_states:
+            with mock.patch.object(cls, "_handle_error") as mock_handle_error:
+                mock_handle_error.return_value = None
+                metric = self._make_one(mock.Mock())
+                if state == State.ACTIVE_ATTEMPT:
+                    metric.active_attempt = mock.Mock()
+                elif state == State.BETWEEN_ATTEMPTS:
+                    metric.completed_attempts.append(mock.Mock())
+                elif state == State.COMPLETED:
+                    metric.was_completed = True
+                return_obj = getattr(metric, method)(*args)
+                assert return_obj is None
+                assert mock_handle_error.call_count == 1
+                assert (
+                    mock_handle_error.call_args[0][0]
+                    == f"Invalid state for {error_method_name}: {state}"
+                )
+
+    def test_start(self):
+        """
+        calling start on the operation should reset start_time
+        """
+        orig_time = 0
+        metric = self._make_one(mock.Mock(), start_time_ns=orig_time)
+        assert abs(metric.start_time_ns - time.monotonic_ns()) > 1e6  # 1ms buffer
+        metric.start()
+        assert metric.start_time_ns != orig_time
+        assert abs(metric.start_time_ns - time.monotonic_ns()) < 1e6  # 1ms buffer
+        # should remain in CREATED state after completing
+        assert metric.state == State.CREATED
+
+    def test_start_attempt(self):
+        """
+        calling start_attempt should create a new empty attempt metric
+        """
+        from google.cloud.bigtable.data._metrics.data_model import ActiveAttemptMetric
+
+        metric = self._make_one(mock.Mock())
+        assert metric.active_attempt is None
+        metric.start_attempt()
+        assert isinstance(metric.active_attempt, ActiveAttemptMetric)
+        # make sure it was initialized with the correct values
+        assert (
+            abs(time.monotonic_ns() - metric.active_attempt.start_time_ns) < 1e6
+        )  # 1ms buffer
+        assert metric.active_attempt.first_response_latency_ns is None
+        assert metric.active_attempt.gfe_latency_ns is None
+        assert metric.active_attempt.grpc_throttling_time_ns == 0
+        # should be in ACTIVE_ATTEMPT state after completing
+        assert metric.state == State.ACTIVE_ATTEMPT
+
+    def test_start_attempt_with_backoff_generator(self):
+        """
+        If operation has a backoff generator, it should be used to attach backoff
+        times to attempts
+        """
+        from google.cloud.bigtable.data._helpers import BackoffGenerator
+
+        generator = BackoffGenerator()
+        # pre-seed generator with expected values
+        generator.history = list(range(10))
+        metric = self._make_one(mock.Mock(), backoff_generator=generator)
+        # initialize generator
+        next(metric.backoff_generator)
+        metric.start_attempt()
+        assert len(metric.completed_attempts) == 0
+        # first attempt should always be 0
+        assert metric.active_attempt.backoff_before_attempt_ns == 0
+        # later attempts should have their attempt number as backoff time
+        for i in range(10):
+            metric.end_attempt_with_status(mock.Mock())
+            assert len(metric.completed_attempts) == i + 1
+            metric.start_attempt()
+            # expect the backoff to be converted from seconds to ns
+            assert metric.active_attempt.backoff_before_attempt_ns == (i * 1e9)
+
+    @pytest.mark.parametrize(
+        "start_cluster,start_zone,metadata_proto,end_cluster,end_zone",
+        [
+            (None, None, None, None, None),
+            ("orig_cluster", "orig_zone", None, "orig_cluster", "orig_zone"),
+            (None, None, ResponseParams(), None, None),
+            (
+                "orig_cluster",
+                "orig_zone",
+                ResponseParams(),
+                "orig_cluster",
+                "orig_zone",
+            ),
+            (
+                None,
+                None,
+                ResponseParams(cluster_id="test-cluster", zone_id="us-central1-b"),
+                "test-cluster",
+                "us-central1-b",
+            ),
+            (
+                None,
+                "filled",
+                ResponseParams(cluster_id="cluster", zone_id="zone"),
+                "cluster",
+                "zone",
+            ),
+            (None, "filled", ResponseParams(cluster_id="cluster"), "cluster", "filled"),
+            (None, "filled", ResponseParams(zone_id="zone"), None, "zone"),
+            (
+                "filled",
+                None,
+                ResponseParams(cluster_id="cluster", zone_id="zone"),
+                "cluster",
+                "zone",
+            ),
+            ("filled", None, ResponseParams(cluster_id="cluster"), "cluster", None),
+            ("filled", None, ResponseParams(zone_id="zone"), "filled", "zone"),
+        ],
+    )
+    def test_add_response_metadata_cbt_header(
+        self, start_cluster, start_zone, metadata_proto, end_cluster, end_zone
+    ):
+        """
+        calling add_response_metadata should update fields based on grpc response metadata
+        The x-goog-ext-425905942-bin field contains cluster and zone info
+        """
+        import grpc
+
+        cls = type(self._make_one(mock.Mock()))
+        with mock.patch.object(cls, "_handle_error") as mock_handle_error:
+            metric = self._make_one(
+                mock.Mock(), cluster_id=start_cluster, zone=start_zone
+            )
+            metric.active_attempt = mock.Mock()
+            metric.active_attempt.gfe_latency_ns = None
+            metadata = grpc.aio.Metadata()
+            if metadata_proto is not None:
+                metadata["x-goog-ext-425905942-bin"] = ResponseParams.serialize(
+                    metadata_proto
+                )
+            metric.add_response_metadata(metadata)
+            assert metric.cluster_id == end_cluster
+            assert metric.zone == end_zone
+            # should remain in ACTIVE_ATTEMPT state after completing
+            assert metric.state == State.ACTIVE_ATTEMPT
+            # no errors encountered
+            assert mock_handle_error.call_count == 0
+            # gfe latency should not be touched
+            assert metric.active_attempt.gfe_latency_ns is None
+
+    @pytest.mark.parametrize(
+        "metadata_field",
+        [
+            b"cluster",
+            "cluster zone",  # expect bytes
+        ],
+    )
+    def test_add_response_metadata_cbt_header_w_error(self, metadata_field):
+        """
+        If the x-goog-ext-425905942-bin field is present, but not structured properly,
+        _handle_error should be called
+
+        Extra fields should not result in a parsing error
+        """
+        import grpc
+
+        cls = type(self._make_one(mock.Mock()))
+        with mock.patch.object(cls, "_handle_error") as mock_handle_error:
+            metric = self._make_one(mock.Mock())
+            metric.cluster_id = None
+            metric.zone = None
+            metric.active_attempt = mock.Mock()
+            metadata = grpc.aio.Metadata()
+            metadata["x-goog-ext-425905942-bin"] = metadata_field
+            metric.add_response_metadata(metadata)
+            # should remain in ACTIVE_ATTEMPT state after completing
+            assert metric.state == State.ACTIVE_ATTEMPT
+            # the malformed metadata blob should be reported to _handle_error
+            assert
mock_handle_error.call_count == 1
+            assert (
+                "Failed to decode x-goog-ext-425905942-bin metadata:"
+                in mock_handle_error.call_args[0][0]
+            )
+            assert str(metadata_field) in mock_handle_error.call_args[0][0]
+
+    @pytest.mark.parametrize(
+        "metadata_field,expected_latency_ns",
+        [
+            (None, None),
+            ("gfet4t7; dur=1000", 1000e6),
+            ("gfet4t7; dur=1000.0", 1000e6),
+            ("gfet4t7; dur=1000.1", 1000.1e6),
+            ("gcp; dur=15, gfet4t7; dur=300", 300e6),
+            ("gfet4t7;dur=350,gcp;dur=12", 350e6),
+            ("ignore_megfet4t7;dur=90ignore_me", 90e6),
+            ("gfet4t7;dur=2000", 2000e6),
+            ("gfet4t7; dur=0.001", 1000),
+            ("gfet4t7; dur=0.000001", 1),
+            ("gfet4t7; dur=0.0000001", 0),  # below recording resolution
+            ("gfet4t7; dur=0", 0),
+            ("gfet4t7; dur=empty", None),
+            ("gfet4t7;", None),
+            ("", None),
+        ],
+    )
+    def test_add_response_metadata_server_timing_header(
+        self, metadata_field, expected_latency_ns
+    ):
+        """
+        calling add_response_metadata should update fields based on grpc response metadata
+        The server-timing field contains gfe latency info
+        """
+        import grpc
+
+        cls = type(self._make_one(mock.Mock()))
+        with mock.patch.object(cls, "_handle_error") as mock_handle_error:
+            metric = self._make_one(mock.Mock())
+            metric.active_attempt = mock.Mock()
+            metric.active_attempt.gfe_latency_ns = None
+            metadata = grpc.aio.Metadata()
+            if metadata_field:
+                metadata["server-timing"] = metadata_field
+            metric.add_response_metadata(metadata)
+            if metric.active_attempt.gfe_latency_ns is None:
+                assert expected_latency_ns is None
+            else:
+                assert metric.active_attempt.gfe_latency_ns == int(expected_latency_ns)
+            # should remain in ACTIVE_ATTEMPT state after completing
+            assert metric.state == State.ACTIVE_ATTEMPT
+            # no errors encountered
+            assert mock_handle_error.call_count == 0
+            # cluster and zone should not be touched
+            assert metric.cluster_id is None
+            assert metric.zone is None
+
+    def test_attempt_first_response(self):
+        cls = type(self._make_one(mock.Mock()))
+        with mock.patch.object(cls, "_handle_error") as mock_handle_error:
+            metric = self._make_one(mock.Mock())
+            metric.start_attempt()
+            metric.active_attempt.start_time_ns = 0
+            metric.attempt_first_response()
+            got_latency_ns = metric.active_attempt.first_response_latency_ns
+            # latency should be equal to current time
+            assert abs(got_latency_ns - time.monotonic_ns()) < 1e6  # 1ms
+            # should remain in ACTIVE_ATTEMPT state after completing
+            assert metric.state == State.ACTIVE_ATTEMPT
+            # no errors encountered
+            assert mock_handle_error.call_count == 0
+            # calling it again should cause an error
+            metric.attempt_first_response()
+            assert mock_handle_error.call_count == 1
+            assert (
+                mock_handle_error.call_args[0][0]
+                == "Attempt already received first response"
+            )
+            # value should not be changed
+            assert metric.active_attempt.first_response_latency_ns == got_latency_ns
+
+    def test_end_attempt_with_status(self):
+        """
+        ending the attempt should:
+            - add one to completed_attempts
+            - reset active_attempt to None
+            - update state
+        """
+        expected_latency_ns = 9
+        expected_start_time = 1
+        expected_status = object()
+        expected_gfe_latency_ns = 5
+        expected_app_blocking = 12
+        expected_backoff = 2
+        expected_grpc_throttle = 3
+
+        metric = self._make_one(mock.Mock())
+        assert metric.active_attempt is None
+        assert len(metric.completed_attempts) == 0
+        metric.start_attempt()
+        metric.active_attempt.start_time_ns = expected_start_time
+        metric.active_attempt.gfe_latency_ns = expected_gfe_latency_ns
+        metric.active_attempt.first_response_latency_ns =
expected_latency_ns + metric.active_attempt.application_blocking_time_ns = expected_app_blocking + metric.active_attempt.backoff_before_attempt_ns = expected_backoff + metric.active_attempt.grpc_throttling_time_ns = expected_grpc_throttle + metric.end_attempt_with_status(expected_status) + assert len(metric.completed_attempts) == 1 + got_attempt = metric.completed_attempts[0] + expected_duration = time.monotonic_ns() - expected_start_time + assert abs(got_attempt.duration_ns - expected_duration) < 10e6 # within 10ms + assert got_attempt.first_response_latency_ns == expected_latency_ns + assert got_attempt.grpc_throttling_time_ns == expected_grpc_throttle + assert got_attempt.end_status == expected_status + assert got_attempt.gfe_latency_ns == expected_gfe_latency_ns + assert got_attempt.application_blocking_time_ns == expected_app_blocking + assert got_attempt.backoff_before_attempt_ns == expected_backoff + # state should be changed to BETWEEN_ATTEMPTS + assert metric.state == State.BETWEEN_ATTEMPTS + + def test_end_attempt_with_status_w_exception(self): + """ + exception inputs should be converted to grpc status objects + """ + input_status = ValueError("test") + expected_status = object() + + metric = self._make_one(mock.Mock()) + metric.start_attempt() + with mock.patch.object( + metric, "_exc_to_status", return_value=expected_status + ) as mock_exc_to_status: + metric.end_attempt_with_status(input_status) + assert mock_exc_to_status.call_count == 1 + assert mock_exc_to_status.call_args[0][0] == input_status + assert metric.completed_attempts[0].end_status == expected_status + + def test_end_with_status(self): + """ + ending the operation should: + - end active attempt + - mark operation as completed + - update handlers + """ + from google.cloud.bigtable.data._metrics.data_model import ActiveAttemptMetric + + expected_attempt_start_time = 0 + expected_attempt_first_response_latency_ns = 9 + expected_attempt_gfe_latency_ns = 5 + expected_flow_time = 16 + + expected_status = object() + expected_type = object() + expected_start_time = 1 + expected_cluster = object() + expected_zone = object() + is_streaming = object() + + handlers = [mock.Mock(), mock.Mock()] + metric = self._make_one( + expected_type, handlers=handlers, start_time_ns=expected_start_time + ) + metric.cluster_id = expected_cluster + metric.zone = expected_zone + metric.is_streaming = is_streaming + metric.flow_throttling_time_ns = expected_flow_time + attempt = ActiveAttemptMetric( + start_time_ns=expected_attempt_start_time, + first_response_latency_ns=expected_attempt_first_response_latency_ns, + gfe_latency_ns=expected_attempt_gfe_latency_ns, + ) + metric.active_attempt = attempt + metric.end_with_status(expected_status) + # test that ActiveOperation was updated to terminal state + assert metric.state == State.COMPLETED + assert metric.was_completed is True + assert metric.active_attempt is None + assert len(metric.completed_attempts) == 1 + # check that finalized operation was passed to handlers + for h in handlers: + assert h.on_operation_complete.call_count == 1 + assert len(h.on_operation_complete.call_args[0]) == 1 + called_with = h.on_operation_complete.call_args[0][0] + assert called_with.op_type == expected_type + expected_duration = time.monotonic_ns() - expected_start_time + assert ( + abs(called_with.duration_ns - expected_duration) < 10e6 + ) # within 10ms + assert called_with.final_status == expected_status + assert called_with.cluster_id == expected_cluster + assert called_with.zone == expected_zone + 
assert called_with.is_streaming == is_streaming + assert called_with.flow_throttling_time_ns == expected_flow_time + # check the attempt + assert len(called_with.completed_attempts) == 1 + final_attempt = called_with.completed_attempts[0] + assert ( + final_attempt.first_response_latency_ns + == expected_attempt_first_response_latency_ns + ) + assert final_attempt.gfe_latency_ns == expected_attempt_gfe_latency_ns + assert final_attempt.end_status == expected_status + expected_duration = time.monotonic_ns() - expected_attempt_start_time + assert ( + abs(final_attempt.duration_ns - expected_duration) < 10e6 + ) # within 10ms + + def test_end_with_status_w_exception(self): + """ + exception inputs should be converted to grpc status objects + """ + input_status = ValueError("test") + expected_status = object() + handlers = [mock.Mock()] + + metric = self._make_one(mock.Mock(), handlers=handlers) + metric.start_attempt() + with mock.patch.object( + metric, "_exc_to_status", return_value=expected_status + ) as mock_exc_to_status: + metric.end_with_status(input_status) + assert mock_exc_to_status.call_count == 1 + assert mock_exc_to_status.call_args[0][0] == input_status + assert metric.completed_attempts[0].end_status == expected_status + final_op = handlers[0].on_operation_complete.call_args[0][0] + assert final_op.final_status == expected_status + + def test_end_with_success(self): + """ + end with success should be a pass-through helper for end_with_status + """ + from grpc import StatusCode + + inner_result = object() + + metric = self._make_one(mock.Mock()) + with mock.patch.object(metric, "end_with_status") as mock_end_with_status: + mock_end_with_status.return_value = inner_result + got_result = metric.end_with_success() + assert mock_end_with_status.call_count == 1 + assert mock_end_with_status.call_args[0][0] == StatusCode.OK + assert got_result is inner_result + + def test_end_on_empty_operation(self): + """ + Should be able to end an operation without any attempts + """ + from grpc import StatusCode + + handlers = [mock.Mock()] + metric = self._make_one(mock.Mock(), handlers=handlers) + metric.end_with_success() + assert metric.state == State.COMPLETED + assert metric.was_completed is True + final_op = handlers[0].on_operation_complete.call_args[0][0] + assert final_op.final_status == StatusCode.OK + assert final_op.completed_attempts == [] + + def test_build_wrapped_predicate(self): + """ + predicate generated by object should terminate attempt or operation + based on passed in predicate + """ + input_exc = ValueError("test") + cls = type(self._make_one(object())) + # ensure predicate is called with the exception + mock_predicate = mock.Mock() + cls.build_wrapped_predicate(mock.Mock(), mock_predicate)(input_exc) + assert mock_predicate.call_count == 1 + assert mock_predicate.call_args[0][0] == input_exc + assert len(mock_predicate.call_args[0]) == 1 + # if predicate is true, end the attempt + mock_instance = mock.Mock() + cls.build_wrapped_predicate(mock_instance, lambda x: True)(input_exc) + assert mock_instance.end_attempt_with_status.call_count == 1 + assert mock_instance.end_attempt_with_status.call_args[0][0] == input_exc + assert len(mock_instance.end_attempt_with_status.call_args[0]) == 1 + # if predicate is false, end the operation + mock_instance = mock.Mock() + cls.build_wrapped_predicate(mock_instance, lambda x: False)(input_exc) + assert mock_instance.end_with_status.call_count == 1 + assert mock_instance.end_with_status.call_args[0][0] == input_exc + assert 
len(mock_instance.end_with_status.call_args[0]) == 1 + + def test__exc_to_status(self): + """ + Should return grpc_status_code if grpc error, otherwise UNKNOWN + + If BigtableExceptionGroup, use the most recent exception in the group + """ + from grpc import StatusCode + from google.api_core import exceptions as core_exc + from google.cloud.bigtable.data import exceptions as bt_exc + + cls = type(self._make_one(object())) + # unknown for non-grpc errors + assert cls._exc_to_status(ValueError()) == StatusCode.UNKNOWN + assert cls._exc_to_status(RuntimeError()) == StatusCode.UNKNOWN + # grpc status code for grpc errors + assert ( + cls._exc_to_status(core_exc.InvalidArgument("msg")) + == StatusCode.INVALID_ARGUMENT + ) + assert cls._exc_to_status(core_exc.NotFound("msg")) == StatusCode.NOT_FOUND + assert ( + cls._exc_to_status(core_exc.AlreadyExists("msg")) + == StatusCode.ALREADY_EXISTS + ) + assert ( + cls._exc_to_status(core_exc.PermissionDenied("msg")) + == StatusCode.PERMISSION_DENIED + ) + cause_exc = core_exc.AlreadyExists("msg") + w_cause = core_exc.DeadlineExceeded("msg") + w_cause.__cause__ = cause_exc + assert cls._exc_to_status(w_cause) == StatusCode.DEADLINE_EXCEEDED + # use cause if available + w_cause = ValueError("msg") + w_cause.__cause__ = cause_exc + cause_exc.grpc_status_code = object() + custom_excs = [ + bt_exc.FailedMutationEntryError(1, mock.Mock(), cause=cause_exc), + bt_exc.FailedQueryShardError(1, {}, cause=cause_exc), + w_cause, + ] + for exc in custom_excs: + assert cls._exc_to_status(exc) == cause_exc.grpc_status_code, exc + # extract most recent exception for bigtable exception groups + exc_groups = [ + bt_exc._BigtableExceptionGroup("", [ValueError(), cause_exc]), + bt_exc.RetryExceptionGroup([RuntimeError(), cause_exc]), + bt_exc.ShardedReadRowsExceptionGroup( + [bt_exc.FailedQueryShardError(1, {}, cause=cause_exc)], [], 2 + ), + bt_exc.MutationsExceptionGroup( + [bt_exc.FailedMutationEntryError(1, mock.Mock(), cause=cause_exc)], 2 + ), + ] + for exc in exc_groups: + assert cls._exc_to_status(exc) == cause_exc.grpc_status_code, exc + + def test__handle_error(self): + """ + handle_error should write log + """ + input_message = "test message" + expected_message = f"Error in Bigtable Metrics: {input_message}" + with mock.patch( + "google.cloud.bigtable.data._metrics.data_model.LOGGER" + ) as logger_mock: + type(self._make_one(object()))._handle_error(input_message) + assert logger_mock.warning.call_count == 1 + assert logger_mock.warning.call_args[0][0] == expected_message + assert len(logger_mock.warning.call_args[0]) == 1 + + @pytest.mark.asyncio + async def test_async_context_manager(self): + """ + Should implement context manager protocol + """ + metric = self._make_one(object()) + with mock.patch.object(metric, "end_with_success") as end_with_success_mock: + end_with_success_mock.side_effect = lambda: metric.end_with_status(object()) + async with metric as context: + assert isinstance(context, type(metric)._AsyncContextManager) + assert context.operation == metric + # inside context manager, still active + assert end_with_success_mock.call_count == 0 + assert metric.state == State.CREATED + # outside context manager, should be ended + assert end_with_success_mock.call_count == 1 + assert metric.state == State.COMPLETED + + @pytest.mark.asyncio + async def test_async_context_manager_exception(self): + """ + Exception within context manager causes end_with_status to be called with error + """ + expected_exc = ValueError("expected") + metric = 
self._make_one(object()) + with mock.patch.object(metric, "end_with_status") as end_with_status_mock: + try: + async with metric as context: + assert isinstance(context, type(metric)._AsyncContextManager) + assert context.operation == metric + # inside context manager, still active + assert end_with_status_mock.call_count == 0 + assert metric.state == State.CREATED + raise expected_exc + except ValueError as e: + assert e == expected_exc + # outside context manager, should be ended + assert end_with_status_mock.call_count == 1 + assert end_with_status_mock.call_args[0][0] == expected_exc + assert len(end_with_status_mock.call_args[0]) == 1 + + @pytest.mark.asyncio + async def test_metadata_passthrough(self): + """ + add_response_metadata in context manager should defer to wrapped operation + """ + inner_result = object() + fake_metadata = object() + + metric = self._make_one(mock.Mock()) + with mock.patch.object(metric, "add_response_metadata") as mock_add_metadata: + mock_add_metadata.return_value = inner_result + async with metric as context: + result = context.add_response_metadata(fake_metadata) + assert result == inner_result + assert mock_add_metadata.call_count == 1 + assert mock_add_metadata.call_args[0][0] == fake_metadata + assert len(mock_add_metadata.call_args[0]) == 1 + + @pytest.mark.asyncio + async def test_wrap_attempt_fn_success(self): + """ + Context manager's wrap_attempt_fn should wrap an arbitrary function + in operation instrumentation + + Test successful call + - should return the result of the wrapped function + - should call end_with_success + """ + from grpc import StatusCode + + metric = self._make_one(object()) + async with metric as context: + mock_call = mock.AsyncMock() + mock_args = (1, 2, 3) + mock_kwargs = {"a": 1, "b": 2} + inner_fn = lambda *args, **kwargs: mock_call(*args, **kwargs) # noqa + wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=False) + # make the wrapped call + result = await wrapped_fn(*mock_args, **mock_kwargs) + assert result == mock_call.return_value + assert mock_call.call_count == 1 + assert mock_call.call_args[0] == mock_args + assert mock_call.call_args[1] == mock_kwargs + assert mock_call.await_count == 1 + # operation should be still in progress after wrapped fn + # let context manager close it, in case we need to add metadata, etc + assert metric.state == State.ACTIVE_ATTEMPT + # make sure the operation is complete after exiting context manager + assert metric.state == State.COMPLETED + assert len(metric.completed_attempts) == 1 + assert metric.completed_attempts[0].end_status == StatusCode.OK + + @pytest.mark.asyncio + async def test_wrap_attempt_fn_failed_extract_call_metadata(self): + """ + When extract_call_metadata is True, should call add_response_metadata + on operation with output of wrapped function, even if failed + """ + mock_call = mock.AsyncMock() + mock_call.trailing_metadata.return_value = 3 + mock_call.initial_metadata.return_value = 4 + inner_fn = lambda *args, **kwargs: mock_call # noqa + metric = self._make_one(object()) + async with metric as context: + wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=True) + with mock.patch.object( + metric, "add_response_metadata" + ) as mock_add_metadata: + # make the wrapped call. 
expect exception when awaiting on mock_call
+                with pytest.raises(TypeError):
+                    await wrapped_fn()
+                assert mock_add_metadata.call_count == 1
+                assert mock_call.trailing_metadata.call_count == 1
+                assert mock_call.initial_metadata.call_count == 1
+                assert mock_add_metadata.call_args[0][0] == 3 + 4
+
+    @pytest.mark.asyncio
+    async def test_wrap_attempt_fn_failed_extract_call_metadata_no_mock(self):
+        """
+        Make sure the metadata is accessible after a failed attempt
+        """
+        import grpc
+
+        mock_call = mock.AsyncMock()
+        mock_call.trailing_metadata.return_value = grpc.aio.Metadata()
+        mock_call.initial_metadata.return_value = grpc.aio.Metadata(
+            ("server-timing", "gfet4t7; dur=5000")
+        )
+        inner_fn = lambda *args, **kwargs: mock_call  # noqa
+        metric = self._make_one(object())
+        async with metric as context:
+            wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=True)
+            with pytest.raises(TypeError):
+                await wrapped_fn()
+            assert metric.active_attempt is None
+            assert len(metric.completed_attempts) == 1
+            assert metric.completed_attempts[0].gfe_latency_ns == 5000e6  # ms to ns
+
+    @pytest.mark.asyncio
+    async def test_wrap_attempt_fn_failed_attempt(self):
+        """
+        failed attempts should call operation.end_attempt with error
+        """
+        from grpc import StatusCode
+
+        metric = self._make_one(object())
+        async with metric as context:
+            wrapped_fn = context.wrap_attempt_fn(
+                mock.Mock(), extract_call_metadata=False
+            )
+            # make the wrapped call. expect type error when awaiting response of mock
+            with pytest.raises(TypeError):
+                await wrapped_fn()
+            # should have one failed attempt, but operation still in progress
+            assert len(metric.completed_attempts) == 1
+            assert metric.state == State.BETWEEN_ATTEMPTS
+            assert metric.active_attempt is None
+            # unknown status from type error
+            assert metric.completed_attempts[0].end_status == StatusCode.UNKNOWN
+        # make sure operation is closed on end
+        assert metric.state == State.COMPLETED
+
+    @pytest.mark.asyncio
+    async def test_wrap_attempt_fn_with_retry(self):
+        """
+        wrap_attempt_fn is meant to be used with a retry object. Test using them together
+        """
+        from grpc import StatusCode
+        from google.api_core.retry import AsyncRetry
+        from google.api_core.exceptions import RetryError
+
+        metric = self._make_one(object())
+        with pytest.raises(RetryError):
+            # should eventually fail due to timeout
+            async with metric as context:
+                always_retry = lambda x: True  # noqa
+                retry_obj = AsyncRetry(
+                    predicate=always_retry, timeout=0.05, maximum=0.001
+                )
+                # mock.Mock will fail on await
+                double_wrapped_fn = retry_obj(
+                    context.wrap_attempt_fn(mock.Mock(), extract_call_metadata=False)
+                )
+                await double_wrapped_fn()
+        # make sure operation ended with expected state
+        assert metric.state == State.COMPLETED
+        # with a 0.001s max backoff we expect many retries (well over 5) in 0.05 seconds
+        assert len(metric.completed_attempts) > 5
+        # unknown error due to TypeError
+        assert metric.completed_attempts[-1].end_status == StatusCode.UNKNOWN
diff --git a/tests/unit/data/_metrics/test_metrics_controller.py b/tests/unit/data/_metrics/test_metrics_controller.py
new file mode 100644
index 000000000..12cd32c92
--- /dev/null
+++ b/tests/unit/data/_metrics/test_metrics_controller.py
@@ -0,0 +1,98 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+
+
+class TestBigtableClientSideMetricsController:
+    def _make_one(self, *args, **kwargs):
+        from google.cloud.bigtable.data._metrics import (
+            BigtableClientSideMetricsController,
+        )
+
+        return BigtableClientSideMetricsController(*args, **kwargs)
+
+    def test_ctor_defaults(self):
+        """
+        should create instance with no handlers by default
+        """
+        instance = self._make_one(
+            project_id="p", instance_id="i", table_id="t", app_profile_id="a"
+        )
+        assert len(instance.handlers) == 0
+
+    def test_ctor_custom_handlers(self):
+        """
+        if handlers are passed to init, use those instead
+        """
+        custom_handler = object()
+        controller = self._make_one(handlers=[custom_handler])
+        assert len(controller.handlers) == 1
+        assert controller.handlers[0] is custom_handler
+
+    def test_add_handler(self):
+        """
+        New handlers should be added to list
+        """
+        controller = self._make_one(handlers=[object()])
+        initial_handler_count = len(controller.handlers)
+        new_handler = object()
+        controller.add_handler(new_handler)
+        assert len(controller.handlers) == initial_handler_count + 1
+        assert controller.handlers[-1] is new_handler
+
+    def test_create_operation_mock(self):
+        """
+        All args should be passed through, as well as the handlers
+        """
+        from google.cloud.bigtable.data._metrics import ActiveOperationMetric
+
+        controller = self._make_one(handlers=[object()])
+        arg = object()
+        kwargs = {"a": 1, "b": 2}
+        with mock.patch(
+            "google.cloud.bigtable.data._metrics.ActiveOperationMetric.__init__"
+        ) as mock_op:
+            mock_op.return_value = None
+            op = controller.create_operation(arg, **kwargs)
+            assert isinstance(op, ActiveOperationMetric)
+            assert mock_op.call_count == 1
+            mock_op.assert_called_with(arg, **kwargs, handlers=controller.handlers)
+
+    def test_create_operation(self):
+        from google.cloud.bigtable.data._metrics import ActiveOperationMetric
+
+        handler = object()
+        expected_type = object()
+        expected_is_streaming = True
+        expected_zone = object()
+        controller = self._make_one(handlers=[handler])
+        op = controller.create_operation(
+            expected_type, is_streaming=expected_is_streaming, zone=expected_zone
+        )
+        assert isinstance(op, ActiveOperationMetric)
+        assert op.op_type is expected_type
+        assert op.is_streaming is expected_is_streaming
+        assert op.zone is expected_zone
+        assert len(op.handlers) == 1
+        assert op.handlers[0] is handler
+
+    def test_create_operation_multiple_handlers(self):
+        orig_handler = object()
+        new_handler = object()
+        controller = self._make_one(handlers=[orig_handler])
+        op = controller.create_operation(object(), handlers=[new_handler])
+        assert len(op.handlers) == 2
+        assert orig_handler in op.handlers
+        assert new_handler in op.handlers

From 29dff4dffee7a6edfc44d0db3c5e7b3315ef8d6e Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 29 Jul 2025 16:02:18 -0700
Subject: [PATCH 16/60] added back interceptor

---
 .../data/_async/metrics_interceptor.py        | 133 ++++++++++++++++++
 .../data/_sync_autogen/metrics_interceptor.py | 120 ++++++++++++++++
 2 files changed, 253 insertions(+)
 create mode 100644 
google/cloud/bigtable/data/_async/metrics_interceptor.py create mode 100644 google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py new file mode 100644 index 000000000..817782105 --- /dev/null +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -0,0 +1,133 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import time +from typing import Any, Callable +from functools import wraps +from google.cloud.bigtable.data._metrics.data_model import OPERATION_INTERCEPTOR_METADATA_KEY +from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric +from google.cloud.bigtable.data._metrics.data_model import OperationState + +from google.cloud.bigtable.data._cross_sync import CrossSync + +if CrossSync.is_async: + from grpc.aio import UnaryUnaryClientInterceptor + from grpc.aio import UnaryStreamClientInterceptor +else: + from grpc import UnaryUnaryClientInterceptor + from grpc import UnaryStreamClientInterceptor + + +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.metrics_interceptor" + + +def _with_operation_from_metadata(func): + """ + Decorator for interceptor methods to extract the active operation + from metadata and pass it to the decorated function. + """ + @wraps(func) + def wrapper(self, continuation, client_call_details, request): + key = next((m[1] for m in client_call_details.metadata if m[0] == OPERATION_INTERCEPTOR_METADATA_KEY), None) + operation: "ActiveOperationMetric" = self.operation_map.get(key) + if operation: + # start a new attempt if not started + if operation.state != OperationState.ACTIVE_ATTEMPT: + operation.start_attempt() + # wrap continuation in logic to process the operation + return func(self, operation, continuation, client_call_details, request) + else: + # if operation not found, return unwrapped continuation + return continuation(client_call_details, request) + return wrapper + + +@CrossSync.convert_class( + sync_name="BigtableMetricsInterceptor" +) +class AsyncBigtableMetricsInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor): + """ + An async gRPC interceptor to add client metadata and print server metadata. 
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.operation_map = {}
+
+    def register_operation(self, operation):
+        """
+        Register an operation object to be tracked by the interceptor
+
+        When registered, the operation will receive metadata updates:
+        - start_attempt if attempt not started when rpc is being sent
+        - add_response_metadata after call is complete
+        - end_attempt_with_status if attempt receives an error
+
+        The interceptor will register itself as a handler for the operation,
+        so it can unregister the operation when it is complete
+        """
+        self.operation_map[operation.uuid] = operation
+        operation.handlers.append(self)
+
+    def on_operation_complete(self, op):
+        del self.operation_map[op.uuid]
+
+    def on_attempt_complete(self, attempt, operation):
+        pass
+
+    @CrossSync.convert
+    @_with_operation_from_metadata
+    async def intercept_unary_unary(self, operation, continuation, client_call_details, request):
+        encountered_exc: Exception | None = None
+        call = None
+        try:
+            call = await continuation(client_call_details, request)
+            return call
+        except Exception as e:
+            encountered_exc = e
+            raise
+        finally:
+            if call is not None:
+                metadata = (
+                    await call.trailing_metadata()
+                    + await call.initial_metadata()
+                )
+                operation.add_response_metadata(metadata)
+            if encountered_exc is not None:
+                # end attempt. If it succeeded, let higher levels decide when to end operation
+                operation.end_attempt_with_status(encountered_exc)
+
+    @CrossSync.convert
+    @_with_operation_from_metadata
+    async def intercept_unary_stream(self, operation, continuation, client_call_details, request):
+        async def response_wrapper(call):
+            encountered_exc = None
+            try:
+                async for response in call:
+                    yield response
+
+            except Exception as e:
+                encountered_exc = e
+                raise
+            finally:
+                metadata = (
+                    await call.trailing_metadata()
+                    + await call.initial_metadata()
+                )
+                operation.add_response_metadata(metadata)
+                if encountered_exc is not None:
+                    # end attempt. If it succeeded, let higher levels decide when to end operation
+                    operation.end_attempt_with_status(encountered_exc)
+
+        return response_wrapper(await continuation(client_call_details, request))
\ No newline at end of file
diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
new file mode 100644
index 000000000..4f9d8aa21
--- /dev/null
+++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
@@ -0,0 +1,120 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+
+# This file is automatically generated by CrossSync. Do not edit manually.
+
+from functools import wraps
+from google.cloud.bigtable.data._metrics.data_model import (
+    OPERATION_INTERCEPTOR_METADATA_KEY,
+)
+from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+from google.cloud.bigtable.data._metrics.data_model import OperationState
+from grpc import UnaryUnaryClientInterceptor
+from grpc import UnaryStreamClientInterceptor
+
+
+def _with_operation_from_metadata(func):
+    """Decorator for interceptor methods to extract the active operation
+    from metadata and pass it to the decorated function."""
+
+    @wraps(func)
+    def wrapper(self, continuation, client_call_details, request):
+        key = next(
+            (
+                m[1]
+                for m in client_call_details.metadata
+                if m[0] == OPERATION_INTERCEPTOR_METADATA_KEY
+            ),
+            None,
+        )
+        operation: "ActiveOperationMetric" = self.operation_map.get(key)
+        if operation:
+            if operation.state != OperationState.ACTIVE_ATTEMPT:
+                operation.start_attempt()
+            return func(self, operation, continuation, client_call_details, request)
+        else:
+            return continuation(client_call_details, request)

+    return wrapper
+
+
+class BigtableMetricsInterceptor(
+    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
+):
+    """
+    An async gRPC interceptor to add client metadata and print server metadata.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.operation_map = {}
+
+    def register_operation(self, operation):
+        """Register an operation object to be tracked by the interceptor
+
+        When registered, the operation will receive metadata updates:
+        - start_attempt if attempt not started when rpc is being sent
+        - add_response_metadata after call is complete
+        - end_attempt_with_status if attempt receives an error
+
+        The interceptor will register itself as a handler for the operation,
+        so it can unregister the operation when it is complete"""
+        self.operation_map[operation.uuid] = operation
+        operation.handlers.append(self)
+
+    def on_operation_complete(self, op):
+        del self.operation_map[op.uuid]
+
+    def on_attempt_complete(self, attempt, operation):
+        pass
+
+    @_with_operation_from_metadata
+    def intercept_unary_unary(
+        self, operation, continuation, client_call_details, request
+    ):
+        encountered_exc: Exception | None = None
+        call = None
+        try:
+            call = continuation(client_call_details, request)
+            return call
+        except Exception as e:
+            encountered_exc = e
+            raise
+        finally:
+            if call is not None:
+                metadata = call.trailing_metadata() + call.initial_metadata()
+                operation.add_response_metadata(metadata)
+            if encountered_exc is not None:
+                operation.end_attempt_with_status(encountered_exc)
+
+    @_with_operation_from_metadata
+    def intercept_unary_stream(
+        self, operation, continuation, client_call_details, request
+    ):
+        def response_wrapper(call):
+            encountered_exc = None
+            try:
+                for response in call:
+                    yield response
+            except Exception as e:
+                encountered_exc = e
+                raise
+            finally:
+                metadata = call.trailing_metadata() + call.initial_metadata()
+                operation.add_response_metadata(metadata)
+                if encountered_exc is not None:
+                    operation.end_attempt_with_status(encountered_exc)
+
+        return response_wrapper(continuation(client_call_details, request))

From e4f82388f12c069944c22fd0ed49e4de9df0f030 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 29 Jul 2025 16:07:52 -0700
Subject: [PATCH 17/60] added metrics to client

---
 google/cloud/bigtable/data/_async/client.py   | 24 +++++++++++++++++++
 .../bigtable/data/_sync_autogen/client.py     | 15 ++++++++++++
 2 files changed, 39 insertions(+)

diff --git 
a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6ee21b554..9a4ca566c 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -84,6 +84,7 @@ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController from google.cloud.bigtable.data._cross_sync import CrossSync @@ -93,12 +94,14 @@ BigtableGrpcAsyncIOTransport as TransportType, ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE + from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor as MetricInterceptorType else: from typing import Iterable # noqa: F401 from grpc import insecure_channel from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor as MetricInterceptorType if TYPE_CHECKING: @@ -223,6 +226,7 @@ def __init__( self._executor = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None ) + self._interceptor = MetricInterceptorType() if self._emulator_host is None: # attempt to start background channel refresh tasks try: @@ -882,6 +886,26 @@ def __init__( ) self.default_retryable_errors = default_retryable_errors or () + self._metrics = BigtableClientSideMetricsController( + client._interceptor, + project_id=self.client.project, + instance_id=instance_id, + table_id=table_id, + app_profile_id=app_profile_id, + ) + # TODO: simplify interceptors + if CrossSync.is_async: + client.transport.grpc_channel._unary_unary_interceptors.append( + self._metrics.interceptor + ) + client.transport.grpc_channel._unary_stream_interceptors.append( + self._metrics.interceptor + ) + else: + client.transport.grpc_channel = intercept_channel( + self._metrics.interceptor, client.transport.grpc_channel + ) + try: self._register_instance_future = CrossSync.create_task( self.client._register_instance, diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index b36bf359a..36d50b889 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -72,6 +72,7 @@ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController from google.cloud.bigtable.data._cross_sync import CrossSync from typing import Iterable from grpc import insecure_channel @@ -80,6 +81,9 @@ BigtableGrpcTransport as TransportType, ) from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE +from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor as MetricInterceptorType, +) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -169,6 +173,7 @@ def __init__( if not CrossSync._Sync_Impl.is_async else None ) + self._interceptor = 
MetricInterceptorType() if self._emulator_host is None: try: self._start_background_channel_refresh() @@ -686,6 +691,16 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () + self._metrics = BigtableClientSideMetricsController( + client._interceptor, + project_id=self.client.project, + instance_id=instance_id, + table_id=table_id, + app_profile_id=app_profile_id, + ) + client.transport.grpc_channel = intercept_channel( + self._metrics.interceptor, client.transport.grpc_channel + ) try: self._register_instance_future = CrossSync._Sync_Impl.create_task( self.client._register_instance, From fcb062e4bda29732e7bcc5a733595c45b83eaa9c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Aug 2025 15:38:43 -0700 Subject: [PATCH 18/60] fixed lint --- google/cloud/bigtable/data/_async/_swappable_channel.py | 1 + tests/system/data/test_system_autogen.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_async/_swappable_channel.py b/google/cloud/bigtable/data/_async/_swappable_channel.py index e590e1234..bbc9a0d47 100644 --- a/google/cloud/bigtable/data/_async/_swappable_channel.py +++ b/google/cloud/bigtable/data/_async/_swappable_channel.py @@ -102,6 +102,7 @@ class AsyncSwappableChannel(_AsyncWrappedChannel): - channel_fn: a nullary function that returns a new channel instance. It should be a partial with all channel configuration arguments built-in """ + def __init__(self, channel_fn: Callable[[], Channel]): self._channel_fn = channel_fn self._channel = channel_fn() diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 406e73038..441114c09 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -26,10 +26,10 @@ from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.type import date_pb2 from google.cloud.bigtable.data._cross_sync import CrossSync +from . import TEST_FAMILY, TEST_FAMILY_2, TEST_AGGREGATE_FAMILY from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( _LoggingClientInterceptor as GapicInterceptor, ) -from . 
import TEST_FAMILY, TEST_FAMILY_2, TEST_AGGREGATE_FAMILY TARGETS = ["table"] if not os.environ.get(BIGTABLE_EMULATOR): From d155f8ae562cc6f49fdc474799c398e9a061bd46 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Aug 2025 17:07:53 -0700 Subject: [PATCH 19/60] set up channel interceptions --- google/cloud/bigtable/data/_async/client.py | 34 +++++++++---------- .../data/_async/metrics_interceptor.py | 1 + .../bigtable/data/_sync_autogen/client.py | 18 +++++----- .../data/_sync_autogen/metrics_interceptor.py | 2 +- tests/system/data/test_system_async.py | 21 ++++++++---- tests/system/data/test_system_autogen.py | 6 ++-- tests/unit/data/_async/test_client.py | 5 +++ tests/unit/data/_sync_autogen/test_client.py | 7 ++++ 8 files changed, 60 insertions(+), 34 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 5d8602e5d..2e83c138b 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -19,6 +19,7 @@ cast, Any, AsyncIterable, + Callable, Optional, Set, Sequence, @@ -104,6 +105,7 @@ else: from typing import Iterable # noqa: F401 from grpc import insecure_channel + from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE @@ -206,7 +208,7 @@ def __init__( credentials = google.auth.credentials.AnonymousCredentials() if project is None: project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT - + self._metrics_interceptor = MetricInterceptorType() # initialize client ClientWithProject.__init__( self, @@ -234,7 +236,6 @@ def __init__( self._executor: concurrent.futures.ThreadPoolExecutor | None = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None ) - self._interceptor = MetricInterceptorType() if self._emulator_host is None: # attempt to start background channel refresh tasks try: @@ -263,12 +264,23 @@ def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel: Returns: a custom wrapped swappable channel """ + create_channel_fn: Callable[[], Channel] if self._emulator_host is not None: # emulators use insecure channel create_channel_fn = partial(insecure_channel, self._emulator_host) - else: + elif CrossSync.is_async: create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) - return AsyncSwappableChannel(create_channel_fn) + else: + # attach sync interceptors in create_channel_fn + create_channel_fn = lambda: intercept_channel( + TransportType.create_channel(*args, **kwargs), self._metrics_interceptor + ) + new_channel = AsyncSwappableChannel(create_channel_fn) + if CrossSync.is_async: + # attach async interceptors + new_channel._unary_unary_interceptors.append(self._metrics_interceptor) + new_channel._unary_stream_interceptors.append(self._metrics_interceptor) + return new_channel @staticmethod def _client_version() -> str: @@ -922,24 +934,12 @@ def __init__( ) self._metrics = BigtableClientSideMetricsController( - client._interceptor, + client._metrics_interceptor, project_id=self.client.project, instance_id=instance_id, table_id=table_id, app_profile_id=app_profile_id, ) - # TODO: simplify interceptors - if CrossSync.is_async: - client.transport.grpc_channel._unary_unary_interceptors.append( - self._metrics.interceptor - ) - 
client.transport.grpc_channel._unary_stream_interceptors.append( - self._metrics.interceptor - ) - else: - client.transport.grpc_channel = intercept_channel( - self._metrics.interceptor, client.transport.grpc_channel - ) try: self._register_instance_future = CrossSync.create_task( diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 817782105..65a5d085b 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +from __future__ import annotations import time from typing import Any, Callable diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 4e3147c6f..0c85b3378 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -17,7 +17,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING +from typing import cast, Any, Callable, Optional, Set, Sequence, TYPE_CHECKING import abc import time import warnings @@ -76,6 +76,7 @@ from google.cloud.bigtable.data._cross_sync import CrossSync from typing import Iterable from grpc import insecure_channel +from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcTransport as TransportType, ) @@ -147,6 +148,7 @@ def __init__( credentials = google.auth.credentials.AnonymousCredentials() if project is None: project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + self._metrics_interceptor = MetricInterceptorType() ClientWithProject.__init__( self, credentials=credentials, @@ -172,7 +174,6 @@ def __init__( if not CrossSync._Sync_Impl.is_async else None ) - self._interceptor = MetricInterceptorType() if self._emulator_host is None: try: self._start_background_channel_refresh() @@ -196,11 +197,15 @@ def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel: - **kwargs: keyword arguments passed by the gapic layer to create a new channel with Returns: a custom wrapped swappable channel""" + create_channel_fn: Callable[[], Channel] if self._emulator_host is not None: create_channel_fn = partial(insecure_channel, self._emulator_host) else: - create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) - return SwappableChannel(create_channel_fn) + create_channel_fn = lambda: intercept_channel( + TransportType.create_channel(*args, **kwargs), self._metrics_interceptor + ) + new_channel = SwappableChannel(create_channel_fn) + return new_channel @staticmethod def _client_version() -> str: @@ -723,15 +728,12 @@ def __init__( default_retryable_errors or () ) self._metrics = BigtableClientSideMetricsController( - client._interceptor, + client._metrics_interceptor, project_id=self.client.project, instance_id=instance_id, table_id=table_id, app_profile_id=app_profile_id, ) - client.transport.grpc_channel = intercept_channel( - self._metrics.interceptor, client.transport.grpc_channel - ) try: self._register_instance_future = CrossSync._Sync_Impl.create_task( self.client._register_instance, diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py 
b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py index 4f9d8aa21..23a0870cd 100644 --- a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License - # This file is automatically generated by CrossSync. Do not edit manually. +from __future__ import annotations from functools import wraps from google.cloud.bigtable.data._metrics.data_model import ( OPERATION_INTERCEPTOR_METADATA_KEY, diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 99b739724..343d03269 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -292,23 +292,32 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows): async with client.get_table(instance_id, table_id) as table: rows = await table.read_rows({}) channel_wrapper = client.transport.grpc_channel - first_channel = client.transport.grpc_channel._channel + first_channel = channel_wrapper._channel assert len(rows) == 2 await CrossSync.sleep(2) rows_after_refresh = await table.read_rows({}) assert len(rows_after_refresh) == 2 assert client.transport.grpc_channel is channel_wrapper - assert client.transport.grpc_channel._channel is not first_channel - # ensure gapic's logging interceptor is still active + updated_channel = channel_wrapper._channel + assert channel_wrapper._channel is not first_channel + # ensure interceptors are kept (gapic's logging interceptor, and metric interceptor) if CrossSync.is_async: - interceptors = ( - client.transport.grpc_channel._channel._unary_unary_interceptors + unary_interceptors = ( + updated_channel._unary_unary_interceptors ) - assert GapicInterceptor in [type(i) for i in interceptors] + assert len(unary_interceptors) == 2 + assert GapicInterceptor in [type(i) for i in unary_interceptors] + assert client._metrics_interceptor in unary_interceptors + stream_interceptors = ( + updated_channel._unary_stream_interceptors + ) + assert len(stream_interceptors) == 1 + assert client._metrics_interceptor in stream_interceptors else: assert isinstance( client.transport._logged_channel._interceptor, GapicInterceptor ) + assert updated_channel._interceptor == client._metrics_interceptor finally: await client.close() diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 441114c09..9ff1b9864 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -235,16 +235,18 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows): with client.get_table(instance_id, table_id) as table: rows = table.read_rows({}) channel_wrapper = client.transport.grpc_channel - first_channel = client.transport.grpc_channel._channel + first_channel = channel_wrapper._channel assert len(rows) == 2 CrossSync._Sync_Impl.sleep(2) rows_after_refresh = table.read_rows({}) assert len(rows_after_refresh) == 2 assert client.transport.grpc_channel is channel_wrapper - assert client.transport.grpc_channel._channel is not first_channel + updated_channel = channel_wrapper._channel + assert channel_wrapper._channel is not first_channel assert isinstance( client.transport._logged_channel._interceptor, GapicInterceptor ) + assert updated_channel._interceptor == client._metrics_interceptor finally: client.close() diff --git a/tests/unit/data/_async/test_client.py 
b/tests/unit/data/_async/test_client.py index be3d149d0..13f23863d 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -54,18 +54,22 @@ from google.cloud.bigtable.data._async._swappable_channel import ( AsyncSwappableChannel, ) + from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) CrossSync.add_mapping("SwappableChannel", AsyncSwappableChannel) + CrossSync.add_mapping("MetricsInterceptor", AsyncBigtableMetricsInterceptor) else: from google.api_core import grpc_helpers from google.cloud.bigtable.data._sync_autogen.client import Table # noqa: F401 from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( SwappableChannel, ) + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor CrossSync.add_mapping("grpc_helpers", grpc_helpers) CrossSync.add_mapping("SwappableChannel", SwappableChannel) + CrossSync.add_mapping("MetricsInterceptor", BigtableMetricsInterceptor) __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_client" @@ -113,6 +117,7 @@ async def test_ctor(self): assert not client._active_instances assert client._channel_refresh_task is not None assert client.transport._credentials == expected_credentials + assert isinstance(client._metrics_interceptor, CrossSync.MetricsInterceptor) await client.close() @CrossSync.pytest diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index 36146d6ee..bfa4911d5 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -46,9 +46,13 @@ ) from google.api_core import grpc_helpers from google.cloud.bigtable.data._sync_autogen._swappable_channel import SwappableChannel +from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor, +) CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) CrossSync._Sync_Impl.add_mapping("SwappableChannel", SwappableChannel) +CrossSync._Sync_Impl.add_mapping("MetricsInterceptor", BigtableMetricsInterceptor) @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") @@ -84,6 +88,9 @@ def test_ctor(self): assert not client._active_instances assert client._channel_refresh_task is not None assert client.transport._credentials == expected_credentials + assert isinstance( + client._metrics_interceptor, CrossSync._Sync_Impl.MetricsInterceptor + ) client.close() def test_ctor_super_inits(self): From 9fece96ca3f2b6911b1a36e99cdf822a3e89ad3e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Aug 2025 17:32:12 -0700 Subject: [PATCH 20/60] added TrackedBackoffGenerator --- google/cloud/bigtable/data/_helpers.py | 30 +++++++++++ .../bigtable/data/_metrics/data_model.py | 4 +- tests/unit/data/test__helpers.py | 50 +++++++++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 424a34486..68ace0174 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -23,6 +23,7 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.api_core import exceptions as core_exceptions +from google.api_core.retry import exponential_sleep_generator from google.api_core.retry import RetryFailureReason from google.cloud.bigtable.data.exceptions import RetryExceptionGroup @@ -248,3 +249,32 @@ def 
_get_retryable_errors(
         call_codes = table.default_mutate_rows_retryable_errors
 
     return [_get_error_type(e) for e in call_codes]
+
+
+class TrackedBackoffGenerator:
+    """
+    Generator class for exponential backoff sleep times.
+    This implementation builds on top of api_core.retry.exponential_sleep_generator,
+    adding the ability to retrieve previous values using get_attempt_backoff(idx).
+    This is used by the Metrics class to track the sleep times used for each attempt.
+    """
+
+    def __init__(self, initial=0.01, maximum=60, multiplier=2):
+        self.history = []
+        self.subgenerator = exponential_sleep_generator(
+            initial=initial, maximum=maximum, multiplier=multiplier
+        )
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> float:
+        next_backoff = next(self.subgenerator)
+        self.history.append(next_backoff)
+        return next_backoff
+
+    def get_attempt_backoff(self, attempt_idx) -> float:
+        """
+        returns the backoff time for a specific attempt index, starting at 0.
+        """
+        return self.history[attempt_idx]
\ No newline at end of file
diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py
index b48686e10..fff3198ba 100644
--- a/google/cloud/bigtable/data/_metrics/data_model.py
+++ b/google/cloud/bigtable/data/_metrics/data_model.py
@@ -32,7 +32,7 @@
 
 if TYPE_CHECKING:
     from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler
-    from google.cloud.bigtable.data._helpers import BackoffGenerator
+    from google.cloud.bigtable.data._helpers import TrackedBackoffGenerator
 
 
 LOGGER = logging.getLogger(__name__)
@@ -144,7 +144,7 @@ class ActiveOperationMetric:
 
     op_type: OperationType
     uuid: str = str(uuid.uuid4())
-    backoff_generator: BackoffGenerator | None = None
+    backoff_generator: TrackedBackoffGenerator | None = None
     # keep monotonic timestamps for active operations
     start_time_ns: int = field(default_factory=time.monotonic_ns)
     active_attempt: ActiveAttemptMetric | None = None
diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py
index 96c726a20..3ea6f4aa7 100644
--- a/tests/unit/data/test__helpers.py
+++ b/tests/unit/data/test__helpers.py
@@ -266,3 +266,53 @@ def test_get_retryable_errors(self, input_codes, input_table, expected):
             setattr(fake_table, f"{key}_retryable_errors", input_table[key])
         result = _helpers._get_retryable_errors(input_codes, fake_table)
         assert result == expected
+
+
+class TestTrackedBackoffGenerator:
+
+    def test_tracked_backoff_generator_history(self):
+        """
+        Should be able to retrieve historical results from backoff generator
+        """
+        generator = _helpers.TrackedBackoffGenerator(initial=0, multiplier=2, maximum=10)
+        got_list = [next(generator) for _ in range(20)]
+
+        # check all values are correct
+        for i in range(19, 0, -1):
+            assert generator.get_attempt_backoff(i) == got_list[i]
+        # check a random value out of order
+        assert generator.get_attempt_backoff(5) == got_list[5]
+
+    @mock.patch("random.uniform", side_effect=lambda a, b: b)
+    def test_tracked_backoff_generator_defaults(self, mock_uniform):
+        """
+        Should generate values with default parameters
+
+        initial=0.01, multiplier=2, maximum=60
+        """
+        generator = _helpers.TrackedBackoffGenerator()
+        expected_values = [0.01, 0.02, 0.04, 0.08, 0.16]
+        for expected in expected_values:
+            assert next(generator) == pytest.approx(expected)
+
+    @mock.patch("random.uniform", side_effect=lambda a, b: b)
+    def test_tracked_backoff_generator_with_maximum(self, mock_uniform):
+        """
+        Should cap the backoff at the maximum value
+ """ + generator = _helpers.TrackedBackoffGenerator(initial=1, multiplier=2, maximum=5) + expected_values = [1, 2, 4, 5, 5, 5] + for expected in expected_values: + assert next(generator) == expected + + def test_get_attempt_backoff_out_of_bounds(self): + """ + get_attempt_backoff should raise IndexError for out of bounds index + """ + generator = _helpers.TrackedBackoffGenerator() + next(generator) + next(generator) + with pytest.raises(IndexError): + generator.get_attempt_backoff(2) + with pytest.raises(IndexError): + generator.get_attempt_backoff(-3) From aec2577638a8b53200e1682bc93b3c24a213865b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Aug 2025 17:39:21 -0700 Subject: [PATCH 21/60] fixed lint --- google/cloud/bigtable/data/_async/client.py | 17 +++++--- .../data/_async/metrics_interceptor.py | 41 ++++++++++++------- google/cloud/bigtable/data/_helpers.py | 2 +- .../bigtable/data/_metrics/data_model.py | 6 +-- .../data/_metrics/metrics_controller.py | 18 +++++--- .../bigtable/data/_sync_autogen/client.py | 10 +++-- tests/system/data/test_system_async.py | 8 +--- tests/unit/data/_async/test_client.py | 8 +++- tests/unit/data/test__helpers.py | 5 ++- 9 files changed, 72 insertions(+), 43 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 2e83c138b..d63909282 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -98,7 +98,9 @@ BigtableAsyncClient as GapicClient, ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor as MetricInterceptorType + from google.cloud.bigtable.data._async.metrics_interceptor import ( + AsyncBigtableMetricsInterceptor as MetricInterceptorType, + ) from google.cloud.bigtable.data._async._swappable_channel import ( AsyncSwappableChannel, ) @@ -109,7 +111,9 @@ from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor as MetricInterceptorType + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor as MetricInterceptorType, + ) from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( # noqa: F401 SwappableChannel, ) @@ -272,9 +276,12 @@ def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel: create_channel_fn = partial(TransportType.create_channel, *args, **kwargs) else: # attach sync interceptors in create_channel_fn - create_channel_fn = lambda: intercept_channel( - TransportType.create_channel(*args, **kwargs), self._metrics_interceptor - ) + def create_channel_fn(): + return intercept_channel( + TransportType.create_channel(*args, **kwargs), + self._metrics_interceptor, + ) + new_channel = AsyncSwappableChannel(create_channel_fn) if CrossSync.is_async: # attach async interceptors diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 65a5d085b..6b807f93d 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -13,10 +13,10 @@ # 
limitations under the License from __future__ import annotations -import time -from typing import Any, Callable from functools import wraps -from google.cloud.bigtable.data._metrics.data_model import OPERATION_INTERCEPTOR_METADATA_KEY +from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, +) from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState @@ -38,9 +38,17 @@ def _with_operation_from_metadata(func): Decorator for interceptor methods to extract the active operation from metadata and pass it to the decorated function. """ + @wraps(func) def wrapper(self, continuation, client_call_details, request): - key = next((m[1] for m in client_call_details.metadata if m[0] == OPERATION_INTERCEPTOR_METADATA_KEY), None) + key = next( + ( + m[1] + for m in client_call_details.metadata + if m[0] == OPERATION_INTERCEPTOR_METADATA_KEY + ), + None, + ) operation: "ActiveOperationMetric" = self.operation_map.get(key) if operation: # start a new attempt if not started @@ -51,13 +59,14 @@ def wrapper(self, continuation, client_call_details, request): else: # if operation not found, return unwrapped continuation return continuation(client_call_details, request) + return wrapper -@CrossSync.convert_class( - sync_name="BigtableMetricsInterceptor" -) -class AsyncBigtableMetricsInterceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor): +@CrossSync.convert_class(sync_name="BigtableMetricsInterceptor") +class AsyncBigtableMetricsInterceptor( + UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor +): """ An async gRPC interceptor to add client metadata and print server metadata. """ @@ -89,7 +98,9 @@ def on_attempt_complete(self, attempt, operation): @CrossSync.convert @_with_operation_from_metadata - async def intercept_unary_unary(self, operation, continuation, client_call_details, request): + async def intercept_unary_unary( + self, operation, continuation, client_call_details, request + ): encountered_exc: Exception | None = None call = None try: @@ -101,8 +112,7 @@ async def intercept_unary_unary(self, operation, continuation, client_call_detai finally: if call is not None: metadata = ( - await call.trailing_metadata() - + await call.initial_metadata() + await call.trailing_metadata() + await call.initial_metadata() ) operation.add_response_metadata(metadata) if encountered_exc is not None: @@ -111,7 +121,9 @@ async def intercept_unary_unary(self, operation, continuation, client_call_detai @CrossSync.convert @_with_operation_from_metadata - async def intercept_unary_stream(self, operation, continuation, client_call_details, request): + async def intercept_unary_stream( + self, operation, continuation, client_call_details, request + ): async def response_wrapper(call): encountered_exc = None try: @@ -123,12 +135,11 @@ async def response_wrapper(call): raise finally: metadata = ( - await call.trailing_metadata() - + await call.initial_metadata() + await call.trailing_metadata() + await call.initial_metadata() ) operation.add_response_metadata(metadata) if encountered_exc is not None: # end attempt. 
If it succeeded, let higher levels decide when to end operation operation.end_attempt_with_status(encountered_exc) - return response_wrapper(await continuation(client_call_details, request)) \ No newline at end of file + return response_wrapper(await continuation(client_call_details, request)) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 68ace0174..c9c3bd706 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -277,4 +277,4 @@ def get_attempt_backoff(self, attempt_idx) -> float: """ returns the backoff time for a specific attempt index, starting at 0. """ - return self.history[attempt_idx] \ No newline at end of file + return self.history[attempt_idx] diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index fff3198ba..c29ac7f43 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Callable, Any, Tuple, cast, TYPE_CHECKING +from typing import Callable, Tuple, cast, TYPE_CHECKING import time import re @@ -48,7 +48,7 @@ INVALID_STATE_ERROR = "Invalid state for {}: {}" -OPERATION_INTERCEPTOR_METADATA_KEY = 'x-goog-operation-key' +OPERATION_INTERCEPTOR_METADATA_KEY = "x-goog-operation-key" class OperationType(Enum): @@ -443,4 +443,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_val is None: self.end_with_success() else: - self.end_with_status(exc_val) \ No newline at end of file + self.end_with_status(exc_val) diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py index 52d669227..169109e28 100644 --- a/google/cloud/bigtable/data/_metrics/metrics_controller.py +++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py @@ -20,8 +20,13 @@ from google.cloud.bigtable.data._metrics.data_model import OperationType if TYPE_CHECKING: - from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor - from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor + from google.cloud.bigtable.data._async.metrics_interceptor import ( + AsyncBigtableMetricsInterceptor, + ) + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor, + ) + class BigtableClientSideMetricsController: """ @@ -31,10 +36,11 @@ class BigtableClientSideMetricsController: registered with the handlers associated with this controller. """ - def __init__(self, - interceptor: AsyncBigtableMetricsInterceptor | BigtableMetricsInterceptor, + def __init__( + self, + interceptor: AsyncBigtableMetricsInterceptor | BigtableMetricsInterceptor, handlers: list[MetricsHandler] | None = None, - **kwargs + **kwargs, ): """ Initializes the metrics controller. 
@@ -69,4 +75,4 @@ def create_operation( handlers = self.handlers + kwargs.pop("handlers", []) new_op = ActiveOperationMetric(op_type, **kwargs, handlers=handlers) self.interceptor.register_operation(new_op) - return new_op \ No newline at end of file + return new_op diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 0c85b3378..a7a9b16b4 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -201,9 +201,13 @@ def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel: if self._emulator_host is not None: create_channel_fn = partial(insecure_channel, self._emulator_host) else: - create_channel_fn = lambda: intercept_channel( - TransportType.create_channel(*args, **kwargs), self._metrics_interceptor - ) + + def create_channel_fn(): + return intercept_channel( + TransportType.create_channel(*args, **kwargs), + self._metrics_interceptor, + ) + new_channel = SwappableChannel(create_channel_fn) return new_channel diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 343d03269..b4e661e6c 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -302,15 +302,11 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows): assert channel_wrapper._channel is not first_channel # ensure interceptors are kept (gapic's logging interceptor, and metric interceptor) if CrossSync.is_async: - unary_interceptors = ( - updated_channel._unary_unary_interceptors - ) + unary_interceptors = updated_channel._unary_unary_interceptors assert len(unary_interceptors) == 2 assert GapicInterceptor in [type(i) for i in unary_interceptors] assert client._metrics_interceptor in unary_interceptors - stream_interceptors = ( - updated_channel._unary_stream_interceptors - ) + stream_interceptors = updated_channel._unary_stream_interceptors assert len(stream_interceptors) == 1 assert client._metrics_interceptor in stream_interceptors else: diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 13f23863d..08efb8db8 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -54,7 +54,9 @@ from google.cloud.bigtable.data._async._swappable_channel import ( AsyncSwappableChannel, ) - from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor + from google.cloud.bigtable.data._async.metrics_interceptor import ( + AsyncBigtableMetricsInterceptor, + ) CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) CrossSync.add_mapping("SwappableChannel", AsyncSwappableChannel) @@ -65,7 +67,9 @@ from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( SwappableChannel, ) - from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor, + ) CrossSync.add_mapping("grpc_helpers", grpc_helpers) CrossSync.add_mapping("SwappableChannel", SwappableChannel) diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 3ea6f4aa7..fda0b7686 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -269,12 +269,13 @@ def test_get_retryable_errors(self, input_codes, input_table, expected): class TestTrackedBackoffGenerator: - def test_tracked_backoff_generator_history(self): """ 
Should be able to retrieve historical results from backoff generator """ - generator = _helpers.TrackedBackoffGenerator(initial=0, multiplier=2, maximum=10) + generator = _helpers.TrackedBackoffGenerator( + initial=0, multiplier=2, maximum=10 + ) got_list = [next(generator) for _ in range(20)] # check all values are correct From ec4e847d5a1b54e73aaeaec384391f22271666ec Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 1 Aug 2025 17:43:25 -0700 Subject: [PATCH 22/60] fixed import --- tests/unit/data/_metrics/test_data_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py index 0a136075f..09526789b 100644 --- a/tests/unit/data/_metrics/test_data_model.py +++ b/tests/unit/data/_metrics/test_data_model.py @@ -213,9 +213,9 @@ def test_start_attempt_with_backoff_generator(self): If operation has a backoff generator, it should be used to attach backoff times to attempts """ - from google.cloud.bigtable.data._helpers import BackoffGenerator + from google.cloud.bigtable.data._helpers import TrackedBackoffGenerator - generator = BackoffGenerator() + generator = TrackedBackoffGenerator() # pre-seed generator with exepcted values generator.history = list(range(10)) metric = self._make_one(mock.Mock(), backoff_generator=generator) From 8f99e4e4d2ae90e09ecadf6980c10da1acaefcbd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 6 Aug 2025 16:15:29 -0700 Subject: [PATCH 23/60] added operation.cancel --- .../bigtable/data/_metrics/data_model.py | 23 ++++--------------- .../bigtable/data/_metrics/handlers/_base.py | 5 +++- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index c29ac7f43..6df25a8ef 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -365,27 +365,12 @@ def end_with_success(self): """ return self.end_with_status(StatusCode.OK) - def build_wrapped_predicate( - self, inner_predicate: Callable[[Exception], bool] - ) -> Callable[[Exception], bool]: + def cancel(self): """ - Wrapps a predicate to include metrics tracking. Any call to the resulting predicate - is assumed to be an rpc failure, and will either mark the end of the active attempt - or the end of the operation. - - Args: - - predicate: The predicate to wrap. + Called to cancel an operation without processing emitting it. 
""" - - def wrapped_predicate(exc: Exception) -> bool: - inner_result = inner_predicate(exc) - if inner_result: - self.end_attempt_with_status(exc) - else: - self.end_with_status(exc) - return inner_result - - return wrapped_predicate + for handler in self.handlers: + handler.on_operation_canceled(self) @staticmethod def _exc_to_status(exc: Exception) -> StatusCode: diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py index 72f5aa550..5c0e32201 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/_base.py +++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py @@ -29,7 +29,10 @@ def __init__(self, **kwargs): def on_operation_complete(self, op: CompletedOperationMetric) -> None: pass + def on_operation_canceled(self, op: ActiveOperationMetric) -> None: + pass + def on_attempt_complete( self, attempt: CompletedAttemptMetric, op: ActiveOperationMetric ) -> None: - pass + pass \ No newline at end of file From f8e6603d9f8e6e472adc1bcef4dcdbb3e8db179c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 6 Aug 2025 16:20:13 -0700 Subject: [PATCH 24/60] added operation cancelled to interceptor --- google/cloud/bigtable/data/_async/metrics_interceptor.py | 7 ++++--- google/cloud/bigtable/data/_metrics/data_model.py | 2 +- google/cloud/bigtable/data/_metrics/handlers/_base.py | 2 +- .../bigtable/data/_sync_autogen/metrics_interceptor.py | 7 ++++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 6b807f93d..dddca184b 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -19,6 +19,7 @@ ) from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState +from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler from google.cloud.bigtable.data._cross_sync import CrossSync @@ -65,7 +66,7 @@ def wrapper(self, continuation, client_call_details, request): @CrossSync.convert_class(sync_name="BigtableMetricsInterceptor") class AsyncBigtableMetricsInterceptor( - UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor + UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, MetricsHandler ): """ An async gRPC interceptor to add client metadata and print server metadata. @@ -93,8 +94,8 @@ def register_operation(self, operation): def on_operation_complete(self, op): del self.operation_map[op.uuid] - def on_attempt_complete(self, attempt, operation): - pass + def on_operation_cancelled(self, op): + self.on_operation_complete(op) @CrossSync.convert @_with_operation_from_metadata diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 6df25a8ef..9ff8b11ab 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -370,7 +370,7 @@ def cancel(self): Called to cancel an operation without processing emitting it. 
""" for handler in self.handlers: - handler.on_operation_canceled(self) + handler.on_operation_cancelled(self) @staticmethod def _exc_to_status(exc: Exception) -> StatusCode: diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py index 5c0e32201..05132d618 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/_base.py +++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): def on_operation_complete(self, op: CompletedOperationMetric) -> None: pass - def on_operation_canceled(self, op: ActiveOperationMetric) -> None: + def on_operation_cancelled(self, op: ActiveOperationMetric) -> None: pass def on_attempt_complete( diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py index 23a0870cd..21dd6752a 100644 --- a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py @@ -21,6 +21,7 @@ ) from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState +from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler from grpc import UnaryUnaryClientInterceptor from grpc import UnaryStreamClientInterceptor @@ -51,7 +52,7 @@ def wrapper(self, continuation, client_call_details, request): class BigtableMetricsInterceptor( - UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor + UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, MetricsHandler ): """ An async gRPC interceptor to add client metadata and print server metadata. @@ -77,8 +78,8 @@ def register_operation(self, operation): def on_operation_complete(self, op): del self.operation_map[op.uuid] - def on_attempt_complete(self, attempt, operation): - pass + def on_operation_cancelled(self, op): + self.on_operation_complete(op) @_with_operation_from_metadata def intercept_unary_unary( From f5e057e930498861b425350dbc6c08f3b8c80d83 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 6 Aug 2025 16:47:24 -0700 Subject: [PATCH 25/60] gave each operation a uuid --- google/cloud/bigtable/data/_async/metrics_interceptor.py | 3 ++- google/cloud/bigtable/data/_metrics/data_model.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index dddca184b..9ec5412b8 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -92,7 +92,8 @@ def register_operation(self, operation): operation.handlers.append(self) def on_operation_complete(self, op): - del self.operation_map[op.uuid] + if op.uuid in self.operation_map: + del self.operation_map[op.uuid] def on_operation_cancelled(self, op): self.on_operation_complete(op) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 9ff8b11ab..217572abf 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -143,7 +143,7 @@ class ActiveOperationMetric: """ op_type: OperationType - uuid: str = str(uuid.uuid4()) + uuid: str = field(default_factory=lambda: str(uuid.uuid4())) backoff_generator: TrackedBackoffGenerator | None = None # keep monotonic timestamps for active operations 
start_time_ns: int = field(default_factory=time.monotonic_ns)

From 8c397bb17308ed888a8399cecc6238593ca4ed15 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 7 Aug 2025 15:00:58 -0700
Subject: [PATCH 26/60] return attempt metric on new attempt

---
 google/cloud/bigtable/data/_metrics/__init__.py   | 6 ++++++
 google/cloud/bigtable/data/_metrics/data_model.py | 3 ++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/google/cloud/bigtable/data/_metrics/__init__.py b/google/cloud/bigtable/data/_metrics/__init__.py
index 43b8b6139..20d36d4c8 100644
--- a/google/cloud/bigtable/data/_metrics/__init__.py
+++ b/google/cloud/bigtable/data/_metrics/__init__.py
@@ -17,9 +17,15 @@
 
 from google.cloud.bigtable.data._metrics.data_model import OperationType
 from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+from google.cloud.bigtable.data._metrics.data_model import ActiveAttemptMetric
+from google.cloud.bigtable.data._metrics.data_model import CompletedOperationMetric
+from google.cloud.bigtable.data._metrics.data_model import CompletedAttemptMetric
 
 __all__ = (
     "BigtableClientSideMetricsController",
     "OperationType",
     "ActiveOperationMetric",
+    "ActiveAttemptMetric",
+    "CompletedOperationMetric",
+    "CompletedAttemptMetric",
 )
diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py
index 217572abf..cce141829 100644
--- a/google/cloud/bigtable/data/_metrics/data_model.py
+++ b/google/cloud/bigtable/data/_metrics/data_model.py
@@ -184,7 +184,7 @@ def start(self) -> None:
             return self._handle_error(INVALID_STATE_ERROR.format("start", self.state))
         self.start_time_ns = time.monotonic_ns()
 
-    def start_attempt(self) -> None:
+    def start_attempt(self) -> ActiveAttemptMetric:
         """
         Called to initiate a new attempt for the operation.
 
@@ -210,6 +210,7 @@
             backoff_ns = 0
 
         self.active_attempt = ActiveAttemptMetric(backoff_before_attempt_ns=backoff_ns)
+        return self.active_attempt
 
     def add_response_metadata(self, metadata: dict[str, bytes | str]) -> None:
         """

From 2c341989d14c289075abc25915fd74f0e2ea637e Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 7 Aug 2025 15:01:57 -0700
Subject: [PATCH 27/60] use standard context manager

---
 google/cloud/bigtable/data/_metrics/data_model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py
index cce141829..f5f6e739b 100644
--- a/google/cloud/bigtable/data/_metrics/data_model.py
+++ b/google/cloud/bigtable/data/_metrics/data_model.py
@@ -409,9 +409,9 @@ def _handle_error(message: str) -> None:
         full_message = f"Error in Bigtable Metrics: {message}"
         LOGGER.warning(full_message)
 
-    async def __aenter__(self):
+    def __aenter__(self):
         """
-        Implements the async context manager protocol for wrapping unary calls
+        Implements the context manager protocol
 
         Using the operation's context manager provides assurances that the operation
         is always closed when complete, with the proper status code automatically
         """
         return self
 
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
+    def __aexit__(self, exc_type, exc_val, exc_tb):
         """
-        Implements the async context manager protocol for wrapping unary calls
+        Implements the context manager protocol
 
         The operation is automatically ended on exit, with the status determined
         by the exception type and value.
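
A minimal usage sketch of the pattern the two patches above set up (not part of the
patch series itself): create_operation() hands back an ActiveOperationMetric wired to
the controller's handlers, start_attempt() now returns the new ActiveAttemptMetric,
and the context manager closes the operation automatically. The sketch assumes the
__enter__/__exit__ naming that a later patch in this series settles on; `controller`
and `rpc_fn` are hypothetical stand-ins, not names from this PR:

    from google.cloud.bigtable.data._metrics import OperationType

    def traced_call(controller, rpc_fn, request):
        # controller: a BigtableClientSideMetricsController built elsewhere (assumed)
        # rpc_fn: whatever callable actually performs the RPC (assumed)
        operation = controller.create_operation(OperationType.READ_ROWS)
        with operation:
            # start_attempt() returns the new ActiveAttemptMetric (see PATCH 26),
            # so callers can attach attempt-level data to it if needed
            attempt = operation.start_attempt()
            # a clean exit ends the operation with StatusCode.OK; a raised
            # exception ends it with a status derived from the exception
            return rpc_fn(request)
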
From 9bd1e075de2e01847f5935340be9c211044b45e8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 7 Aug 2025 15:06:46 -0700 Subject: [PATCH 28/60] use default backoff generator --- google/cloud/bigtable/data/_metrics/data_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index f5f6e739b..6c7722f4d 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -144,7 +144,10 @@ class ActiveOperationMetric: op_type: OperationType uuid: str = field(default_factory=lambda: str(uuid.uuid4())) - backoff_generator: TrackedBackoffGenerator | None = None + # create a default backoff generator, initialized with standard default backoff values + backoff_generator: TrackedBackoffGenerator | None = field( + default_factory=lambda: TrackedBackoffGenerator(initial=0.01, maximum=60, multiplier=2) + ) # keep monotonic timestamps for active operations start_time_ns: int = field(default_factory=time.monotonic_ns) active_attempt: ActiveAttemptMetric | None = None From 96d1355c8507898242ad55fb4e6d932380e1c526 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 7 Aug 2025 15:26:25 -0700 Subject: [PATCH 29/60] require backoff; refactor check --- google/cloud/bigtable/data/_metrics/data_model.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 6c7722f4d..8290a4e6e 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -145,7 +145,7 @@ class ActiveOperationMetric: op_type: OperationType uuid: str = field(default_factory=lambda: str(uuid.uuid4())) # create a default backoff generator, initialized with standard default backoff values - backoff_generator: TrackedBackoffGenerator | None = field( + backoff_generator: TrackedBackoffGenerator = field( default_factory=lambda: TrackedBackoffGenerator(initial=0.01, maximum=60, multiplier=2) ) # keep monotonic timestamps for active operations @@ -201,15 +201,16 @@ def start_attempt(self) -> ActiveAttemptMetric: INVALID_STATE_ERROR.format("start_attempt", self.state) ) - # find backoff value - if self.backoff_generator and len(self.completed_attempts) > 0: - # find the attempt's backoff by sending attempt number to generator - # generator will return the backoff time in seconds, so convert to nanoseconds + try: + # find backoff value before this attempt + prev_attempt_idx = len(self.completed_attempts) - 1 backoff = self.backoff_generator.get_attempt_backoff( - len(self.completed_attempts) - 1 + prev_attempt_idx ) + # generator will return the backoff time in seconds, so convert to nanoseconds backoff_ns = int(backoff * 1e9) - else: + except IndexError: + # backoff value not found backoff_ns = 0 self.active_attempt = ActiveAttemptMetric(backoff_before_attempt_ns=backoff_ns) From de5d07bd492af255acb987cd42bd86f99fd7296c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 7 Aug 2025 15:36:04 -0700 Subject: [PATCH 30/60] fixed context manager naming; lint --- .../cloud/bigtable/data/_metrics/data_model.py | 16 ++++++++-------- .../bigtable/data/_metrics/handlers/_base.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 8290a4e6e..fbc25c8aa 100644 --- 
a/google/cloud/bigtable/data/_metrics/data_model.py
+++ b/google/cloud/bigtable/data/_metrics/data_model.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Callable, Tuple, cast, TYPE_CHECKING
+from typing import Tuple, cast, TYPE_CHECKING
 
 import time
 import re
@@ -146,7 +146,9 @@ class ActiveOperationMetric:
     uuid: str = field(default_factory=lambda: str(uuid.uuid4()))
     # create a default backoff generator, initialized with standard default backoff values
     backoff_generator: TrackedBackoffGenerator = field(
-        default_factory=lambda: TrackedBackoffGenerator(initial=0.01, maximum=60, multiplier=2)
+        default_factory=lambda: TrackedBackoffGenerator(
+            initial=0.01, maximum=60, multiplier=2
+        )
     )
     # keep monotonic timestamps for active operations
     start_time_ns: int = field(default_factory=time.monotonic_ns)
@@ -187,7 +189,7 @@ def start(self) -> None:
             return self._handle_error(INVALID_STATE_ERROR.format("start", self.state))
         self.start_time_ns = time.monotonic_ns()
 
-    def start_attempt(self) -> ActiveAttemptMetric:
+    def start_attempt(self) -> ActiveAttemptMetric | None:
         """
         Called to initiate a new attempt for the operation.
 
@@ -204,9 +206,7 @@
         try:
             # find backoff value before this attempt
             prev_attempt_idx = len(self.completed_attempts) - 1
-            backoff = self.backoff_generator.get_attempt_backoff(
-                prev_attempt_idx
-            )
+            backoff = self.backoff_generator.get_attempt_backoff(prev_attempt_idx)
             # generator will return the backoff time in seconds, so convert to nanoseconds
             backoff_ns = int(backoff * 1e9)
         except IndexError:
@@ -413,7 +413,7 @@ def _handle_error(message: str) -> None:
         full_message = f"Error in Bigtable Metrics: {message}"
         LOGGER.warning(full_message)
 
-    def __aenter__(self):
+    def __enter__(self):
         """
         Implements the context manager protocol
 
@@ -423,7 +423,7 @@
         """
         return self
 
-    def __aexit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb):
         """
         Implements the context manager protocol
 
diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py
index 05132d618..64cc89b05 100644
--- a/google/cloud/bigtable/data/_metrics/handlers/_base.py
+++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py
@@ -35,4 +35,4 @@ def on_operation_cancelled(self, op: ActiveOperationMetric) -> None:
     def on_attempt_complete(
         self, attempt: CompletedAttemptMetric, op: ActiveOperationMetric
     ) -> None:
-        pass
\ No newline at end of file
+        pass

From d73f379718ed881abda2237bf7c1107576e4a750 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 7 Aug 2025 15:58:00 -0700
Subject: [PATCH 31/60] moved first_response_latency to operation

---
 .../data/_async/metrics_interceptor.py       |  9 ++++++
 .../bigtable/data/_metrics/data_model.py     | 28 ++++---------------
 2 files changed, 15 insertions(+), 22 deletions(-)

diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py
index 9ec5412b8..e5777ea0f 100644
--- a/google/cloud/bigtable/data/_async/metrics_interceptor.py
+++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py
@@ -13,6 +13,7 @@
 # limitations under the License
 from __future__ import annotations
 
+import time
 from functools import wraps
 from google.cloud.bigtable.data._metrics.data_model import (
     OPERATION_INTERCEPTOR_METADATA_KEY,
@@ -126,10 +127,18 @@ async def intercept_unary_unary(
     async def intercept_unary_stream(
         self,
operation, continuation, client_call_details, request
     ):
+        # TODO: benchmark
         async def response_wrapper(call):
+            has_first_response = operation.first_response_latency is not None
             encountered_exc = None
             try:
                 async for response in call:
+                    # record time to first response. Currently only used for READ_ROWS
+                    if not has_first_response:
+                        operation.first_response_latency_ns = (
+                            time.monotonic_ns() - operation.start_time_ns
+                        )
+                        has_first_response = True
                     yield response
 
             except Exception as e:
diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py
index fbc25c8aa..3a7754be3 100644
--- a/google/cloud/bigtable/data/_metrics/data_model.py
+++ b/google/cloud/bigtable/data/_metrics/data_model.py
@@ -83,7 +83,6 @@ class CompletedAttemptMetric:
 
     duration_ns: int
     end_status: StatusCode
-    first_response_latency_ns: int | None = None
     gfe_latency_ns: int | None = None
     application_blocking_time_ns: int = 0
     backoff_before_attempt_ns: int = 0
@@ -108,6 +107,7 @@ class CompletedOperationMetric:
     cluster_id: str
     zone: str
     is_streaming: bool
+    first_response_latency_ns: int | None = None
     flow_throttling_time_ns: int = 0
 
@@ -120,9 +120,6 @@ class ActiveAttemptMetric:
 
     # keep monotonic timestamps for active attempts
     start_time_ns: int = field(default_factory=time.monotonic_ns)
-    # the time it takes to recieve the first response from the server, in nanoseconds
-    # currently only tracked for ReadRows
-    first_response_latency_ns: int | None = None
     # the time taken by the backend, in nanoseconds. Taken from response header
     gfe_latency_ns: int | None = None
     # time waiting on user to process the response, in nanoseconds
@@ -159,6 +156,10 @@ class ActiveOperationMetric:
     is_streaming: bool = False  # only True for read_rows operations
     was_completed: bool = False
     handlers: list[MetricsHandler] = field(default_factory=list)
+    # the time it takes to receive the first response from the server, in nanoseconds
+    # attached by interceptor
+    # currently only tracked for ReadRows
+    first_response_latency_ns: int | None = None
     # time waiting on flow control, in nanoseconds
     flow_throttling_time_ns: int = 0
 
@@ -274,23 +275,6 @@ def _parse_response_metadata_blob(blob: bytes) -> Tuple[str, str] | None:
             # failed to parse metadata
             return None
 
-    def attempt_first_response(self) -> None:
-        """
-        Called to mark the timestamp of the first completed response for the
-        active attempt.
-
-        Assumes operation is in ACTIVE_ATTEMPT state.
-        """
-        if self.state != OperationState.ACTIVE_ATTEMPT or self.active_attempt is None:
-            return self._handle_error(
-                INVALID_STATE_ERROR.format("attempt_first_response", self.state)
-            )
-        if self.active_attempt.first_response_latency_ns is not None:
-            return self._handle_error("Attempt already received first response")
-        self.active_attempt.first_response_latency_ns = (
-            time.monotonic_ns() - self.active_attempt.start_time_ns
-        )
-
     def end_attempt_with_status(self, status: StatusCode | Exception) -> None:
         """
         Called to mark the end of an attempt for the operation.
@@ -311,7 +295,6 @@ def end_attempt_with_status(self, status: StatusCode | Exception) -> None: if isinstance(status, Exception): status = self._exc_to_status(status) complete_attempt = CompletedAttemptMetric( - first_response_latency_ns=self.active_attempt.first_response_latency_ns, duration_ns=time.monotonic_ns() - self.active_attempt.start_time_ns, end_status=status, gfe_latency_ns=self.active_attempt.gfe_latency_ns, @@ -355,6 +338,7 @@ def end_with_status(self, status: StatusCode | Exception) -> None: cluster_id=self.cluster_id or DEFAULT_CLUSTER_ID, zone=self.zone or DEFAULT_ZONE, is_streaming=self.is_streaming, + first_response_latency_ns=self.first_response_latency_ns, flow_throttling_time_ns=self.flow_throttling_time_ns, ) for handler in self.handlers: From 4a4f80a91369af45b40db6689aa4eaed237a7f31 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 8 Aug 2025 11:13:02 -0700 Subject: [PATCH 32/60] fixed import --- google/cloud/bigtable/data/_metrics/data_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 3a7754be3..2bcd40021 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -28,11 +28,11 @@ import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable_v2.types.response_params import ResponseParams +from google.cloud.bigtable.data._helpers import TrackedBackoffGenerator from google.protobuf.message import DecodeError if TYPE_CHECKING: from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler - from google.cloud.bigtable.data._helpers import TrackedBackoffGenerator LOGGER = logging.getLogger(__name__) From 708a35aa76bedc0c3cb6e47f9787d8a056dbd3c8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Aug 2025 15:19:30 -0700 Subject: [PATCH 33/60] fixed broken unit tests --- google/cloud/bigtable/data/_helpers.py | 2 + .../data/_metrics/metrics_controller.py | 3 +- tests/unit/data/_metrics/test_data_model.py | 236 +----------------- .../data/_metrics/test_metrics_controller.py | 39 ++- 4 files changed, 42 insertions(+), 238 deletions(-) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index c9c3bd706..ba64ffab4 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -277,4 +277,6 @@ def get_attempt_backoff(self, attempt_idx) -> float: """ returns the backoff time for a specific attempt index, starting at 0. """ + if attempt_idx < 0: + raise IndexError("received negative attempt number") return self.history[attempt_idx] diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py index 169109e28..f13590f7c 100644 --- a/google/cloud/bigtable/data/_metrics/metrics_controller.py +++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py @@ -72,7 +72,6 @@ def create_operation( """ Creates a new operation and registers it with the subscribed handlers. 
""" - handlers = self.handlers + kwargs.pop("handlers", []) - new_op = ActiveOperationMetric(op_type, **kwargs, handlers=handlers) + new_op = ActiveOperationMetric(op_type, **kwargs, handlers=self.handlers) self.interceptor.register_operation(new_op) return new_op diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py index 09526789b..b281ccec0 100644 --- a/tests/unit/data/_metrics/test_data_model.py +++ b/tests/unit/data/_metrics/test_data_model.py @@ -125,7 +125,6 @@ def test_state_machine_w_state(self): ("start", (), (State.CREATED,), None), ("start_attempt", (), (State.CREATED, State.BETWEEN_ATTEMPTS), None), ("add_response_metadata", ({},), (State.ACTIVE_ATTEMPT,), None), - ("attempt_first_response", (), (State.ACTIVE_ATTEMPT,), None), ("end_attempt_with_status", (mock.Mock(),), (State.ACTIVE_ATTEMPT,), None), ( "end_with_status", @@ -202,7 +201,6 @@ def test_start_attempt(self): assert ( abs(time.monotonic_ns() - metric.active_attempt.start_time_ns) < 1e6 ) # 1ms buffer - assert metric.active_attempt.first_response_latency_ns is None assert metric.active_attempt.gfe_latency_ns is None assert metric.active_attempt.grpc_throttling_time_ns == 0 # should be in ACTIVE_ATTEMPT state after completing @@ -219,8 +217,6 @@ def test_start_attempt_with_backoff_generator(self): # pre-seed generator with exepcted values generator.history = list(range(10)) metric = self._make_one(mock.Mock(), backoff_generator=generator) - # initialize generator - next(metric.backoff_generator) metric.start_attempt() assert len(metric.completed_attempts) == 0 # first attempt should always be 0 @@ -307,7 +303,7 @@ def test_add_response_metadata_cbt_header( @pytest.mark.parametrize( "metadata_field", [ - b"cluster", + b"bad-input", "cluster zone", # expect bytes ], ) @@ -389,30 +385,6 @@ def test_add_response_metadata_server_timing_header( assert metric.cluster_id is None assert metric.zone is None - def test_attempt_first_response(self): - cls = type(self._make_one(mock.Mock())) - with mock.patch.object(cls, "_handle_error") as mock_handle_error: - metric = self._make_one(mock.Mock()) - metric.start_attempt() - metric.active_attempt.start_time_ns = 0 - metric.attempt_first_response() - got_latency_ns = metric.active_attempt.first_response_latency_ns - # latency should be equal to current time - assert abs(got_latency_ns - time.monotonic_ns()) < 1e6 # 1ms - # should remain in ACTIVE_ATTEMPT state after completing - assert metric.state == State.ACTIVE_ATTEMPT - # no errors encountered - assert mock_handle_error.call_count == 0 - # calling it again should cause an error - metric.attempt_first_response() - assert mock_handle_error.call_count == 1 - assert ( - mock_handle_error.call_args[0][0] - == "Attempt already received first response" - ) - # value should not be changed - assert metric.active_attempt.first_response_latency_ns == got_latency_ns - def test_end_attempt_with_status(self): """ ending the attempt should: @@ -420,7 +392,6 @@ def test_end_attempt_with_status(self): - reset active_attempt to None - update state """ - expected_latency_ns = 9 expected_start_time = 1 expected_status = object() expected_gfe_latency_ns = 5 @@ -434,7 +405,6 @@ def test_end_attempt_with_status(self): metric.start_attempt() metric.active_attempt.start_time_ns = expected_start_time metric.active_attempt.gfe_latency_ns = expected_gfe_latency_ns - metric.active_attempt.first_response_latency_ns = expected_latency_ns metric.active_attempt.application_blocking_time_ns = 
expected_app_blocking metric.active_attempt.backoff_before_attempt_ns = expected_backoff metric.active_attempt.grpc_throttling_time_ns = expected_grpc_throttle @@ -443,7 +413,6 @@ def test_end_attempt_with_status(self): got_attempt = metric.completed_attempts[0] expected_duration = time.monotonic_ns() - expected_start_time assert abs(got_attempt.duration_ns - expected_duration) < 10e6 # within 10ms - assert got_attempt.first_response_latency_ns == expected_latency_ns assert got_attempt.grpc_throttling_time_ns == expected_grpc_throttle assert got_attempt.end_status == expected_status assert got_attempt.gfe_latency_ns == expected_gfe_latency_ns @@ -479,10 +448,10 @@ def test_end_with_status(self): from google.cloud.bigtable.data._metrics.data_model import ActiveAttemptMetric expected_attempt_start_time = 0 - expected_attempt_first_response_latency_ns = 9 expected_attempt_gfe_latency_ns = 5 expected_flow_time = 16 + expected_first_response_latency_ns = 9 expected_status = object() expected_type = object() expected_start_time = 1 @@ -498,9 +467,9 @@ def test_end_with_status(self): metric.zone = expected_zone metric.is_streaming = is_streaming metric.flow_throttling_time_ns = expected_flow_time + metric.first_response_latency_ns = expected_first_response_latency_ns attempt = ActiveAttemptMetric( start_time_ns=expected_attempt_start_time, - first_response_latency_ns=expected_attempt_first_response_latency_ns, gfe_latency_ns=expected_attempt_gfe_latency_ns, ) metric.active_attempt = attempt @@ -525,13 +494,13 @@ def test_end_with_status(self): assert called_with.zone == expected_zone assert called_with.is_streaming == is_streaming assert called_with.flow_throttling_time_ns == expected_flow_time + assert ( + called_with.first_response_latency_ns + == expected_first_response_latency_ns + ) # check the attempt assert len(called_with.completed_attempts) == 1 final_attempt = called_with.completed_attempts[0] - assert ( - final_attempt.first_response_latency_ns - == expected_attempt_first_response_latency_ns - ) assert final_attempt.gfe_latency_ns == expected_attempt_gfe_latency_ns assert final_attempt.end_status == expected_status expected_duration = time.monotonic_ns() - expected_attempt_start_time @@ -590,32 +559,6 @@ def test_end_on_empty_operation(self): assert final_op.final_status == StatusCode.OK assert final_op.completed_attempts == [] - def test_build_wrapped_predicate(self): - """ - predicate generated by object should terminate attempt or operation - based on passed in predicate - """ - input_exc = ValueError("test") - cls = type(self._make_one(object())) - # ensure predicate is called with the exception - mock_predicate = mock.Mock() - cls.build_wrapped_predicate(mock.Mock(), mock_predicate)(input_exc) - assert mock_predicate.call_count == 1 - assert mock_predicate.call_args[0][0] == input_exc - assert len(mock_predicate.call_args[0]) == 1 - # if predicate is true, end the attempt - mock_instance = mock.Mock() - cls.build_wrapped_predicate(mock_instance, lambda x: True)(input_exc) - assert mock_instance.end_attempt_with_status.call_count == 1 - assert mock_instance.end_attempt_with_status.call_args[0][0] == input_exc - assert len(mock_instance.end_attempt_with_status.call_args[0]) == 1 - # if predicate is false, end the operation - mock_instance = mock.Mock() - cls.build_wrapped_predicate(mock_instance, lambda x: False)(input_exc) - assert mock_instance.end_with_status.call_count == 1 - assert mock_instance.end_with_status.call_args[0][0] == input_exc - assert 
len(mock_instance.end_with_status.call_args[0]) == 1 - def test__exc_to_status(self): """ Should return grpc_status_code if grpc error, otherwise UNKNOWN @@ -688,16 +631,15 @@ def test__handle_error(self): assert len(logger_mock.warning.call_args[0]) == 1 @pytest.mark.asyncio - async def test_async_context_manager(self): + async def test_context_manager(self): """ Should implement context manager protocol """ metric = self._make_one(object()) with mock.patch.object(metric, "end_with_success") as end_with_success_mock: end_with_success_mock.side_effect = lambda: metric.end_with_status(object()) - async with metric as context: - assert isinstance(context, type(metric)._AsyncContextManager) - assert context.operation == metric + with metric as context: + assert context == metric # inside context manager, still active assert end_with_success_mock.call_count == 0 assert metric.state == State.CREATED @@ -706,7 +648,7 @@ async def test_async_context_manager(self): assert metric.state == State.COMPLETED @pytest.mark.asyncio - async def test_async_context_manager_exception(self): + async def test_context_manager_exception(self): """ Exception within context manager causes end_with_status to be called with error """ @@ -714,9 +656,7 @@ async def test_async_context_manager_exception(self): metric = self._make_one(object()) with mock.patch.object(metric, "end_with_status") as end_with_status_mock: try: - async with metric as context: - assert isinstance(context, type(metric)._AsyncContextManager) - assert context.operation == metric + with metric: # inside context manager, still active assert end_with_status_mock.call_count == 0 assert metric.state == State.CREATED @@ -726,155 +666,3 @@ async def test_async_context_manager_exception(self): # outside context manager, should be ended assert end_with_status_mock.call_count == 1 assert end_with_status_mock.call_args[0][0] == expected_exc - assert len(end_with_status_mock.call_args[0]) == 1 - - @pytest.mark.asyncio - async def test_metadata_passthrough(self): - """ - add_response_metadata in context manager should defer to wrapped operation - """ - inner_result = object() - fake_metadata = object() - - metric = self._make_one(mock.Mock()) - with mock.patch.object(metric, "add_response_metadata") as mock_add_metadata: - mock_add_metadata.return_value = inner_result - async with metric as context: - result = context.add_response_metadata(fake_metadata) - assert result == inner_result - assert mock_add_metadata.call_count == 1 - assert mock_add_metadata.call_args[0][0] == fake_metadata - assert len(mock_add_metadata.call_args[0]) == 1 - - @pytest.mark.asyncio - async def test_wrap_attempt_fn_success(self): - """ - Context manager's wrap_attempt_fn should wrap an arbitrary function - in operation instrumentation - - Test successful call - - should return the result of the wrapped function - - should call end_with_success - """ - from grpc import StatusCode - - metric = self._make_one(object()) - async with metric as context: - mock_call = mock.AsyncMock() - mock_args = (1, 2, 3) - mock_kwargs = {"a": 1, "b": 2} - inner_fn = lambda *args, **kwargs: mock_call(*args, **kwargs) # noqa - wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=False) - # make the wrapped call - result = await wrapped_fn(*mock_args, **mock_kwargs) - assert result == mock_call.return_value - assert mock_call.call_count == 1 - assert mock_call.call_args[0] == mock_args - assert mock_call.call_args[1] == mock_kwargs - assert mock_call.await_count == 1 - # operation should 
be still in progress after wrapped fn - # let context manager close it, in case we need to add metadata, etc - assert metric.state == State.ACTIVE_ATTEMPT - # make sure the operation is complete after exiting context manager - assert metric.state == State.COMPLETED - assert len(metric.completed_attempts) == 1 - assert metric.completed_attempts[0].end_status == StatusCode.OK - - @pytest.mark.asyncio - async def test_wrap_attempt_fn_failed_extract_call_metadata(self): - """ - When extract_call_metadata is True, should call add_response_metadata - on operation with output of wrapped function, even if failed - """ - mock_call = mock.AsyncMock() - mock_call.trailing_metadata.return_value = 3 - mock_call.initial_metadata.return_value = 4 - inner_fn = lambda *args, **kwargs: mock_call # noqa - metric = self._make_one(object()) - async with metric as context: - wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=True) - with mock.patch.object( - metric, "add_response_metadata" - ) as mock_add_metadata: - # make the wrapped call. expect exception when awaiting on mock_call - with pytest.raises(TypeError): - await wrapped_fn() - assert mock_add_metadata.call_count == 1 - assert mock_call.trailing_metadata.call_count == 1 - assert mock_call.initial_metadata.call_count == 1 - assert mock_add_metadata.call_args[0][0] == 3 + 4 - - @pytest.mark.asyncio - async def test_wrap_attempt_fn_failed_extract_call_metadata_no_mock(self): - """ - Make sure the metadata is accessible after a failed attempt - """ - import grpc - - mock_call = mock.AsyncMock() - mock_call.trailing_metadata.return_value = grpc.aio.Metadata() - mock_call.initial_metadata.return_value = grpc.aio.Metadata( - ("server-timing", "gfet4t7; dur=5000") - ) - inner_fn = lambda *args, **kwargs: mock_call # noqa - metric = self._make_one(object()) - async with metric as context: - wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=True) - with pytest.raises(TypeError): - await wrapped_fn() - assert metric.active_attempt is None - assert len(metric.completed_attempts) == 1 - assert metric.completed_attempts[0].gfe_latency_ns == 5000e6 # ms to ns - - @pytest.mark.asyncio - async def test_wrap_attempt_fn_failed_attempt(self): - """ - failed attempts should call operation.end_attempt with error - """ - from grpc import StatusCode - - metric = self._make_one(object()) - async with metric as context: - wrapped_fn = context.wrap_attempt_fn( - mock.Mock(), extract_call_metadata=False - ) - # make the wrapped call. expect type error when awaiting response of mock - with pytest.raises(TypeError): - await wrapped_fn() - # should have one failed attempt, but operation still in progress - assert len(metric.completed_attempts) == 1 - assert metric.state == State.BETWEEN_ATTEMPTS - assert metric.active_attempt is None - # unknown status from type error - assert metric.completed_attempts[0].end_status == StatusCode.UNKNOWN - # make sure operation is closed on end - assert metric.state == State.COMPLETED - - @pytest.mark.asyncio - async def test_wrap_attempt_fn_with_retry(self): - """ - wrap_attampt_fn is meant to be used with retry object. 
Test using them together - """ - from grpc import StatusCode - from google.api_core.retry import AsyncRetry - from google.api_core.exceptions import RetryError - - metric = self._make_one(object()) - with pytest.raises(RetryError): - # should eventually fail due to timeout - async with metric as context: - always_retry = lambda x: True # noqa - retry_obj = AsyncRetry( - predicate=always_retry, timeout=0.05, maximum=0.001 - ) - # mock.Mock will fail on await - double_wrapped_fn = retry_obj( - context.wrap_attempt_fn(mock.Mock(), extract_call_metadata=False) - ) - await double_wrapped_fn() - # make sure operation ended with expected state - assert metric.state == State.COMPLETED - # we expect > 30 retries in 0.05 seconds - assert len(metric.completed_attempts) > 5 - # unknown error due to TyperError - assert metric.completed_attempts[-1].end_status == StatusCode.UNKNOWN diff --git a/tests/unit/data/_metrics/test_metrics_controller.py b/tests/unit/data/_metrics/test_metrics_controller.py index 12cd32c92..701af737b 100644 --- a/tests/unit/data/_metrics/test_metrics_controller.py +++ b/tests/unit/data/_metrics/test_metrics_controller.py @@ -21,15 +21,19 @@ def _make_one(self, *args, **kwargs): BigtableClientSideMetricsController, ) + # add mock interceptor if called with no arguments + if not args and "interceptor" not in kwargs: + args = [mock.Mock()] + return BigtableClientSideMetricsController(*args, **kwargs) def test_ctor_defaults(self): """ should create instance with GCP Exporter handler by default """ - instance = self._make_one( - project_id="p", instance_id="i", table_id="t", app_profile_id="a" - ) + expected_interceptor = object() + instance = self._make_one(expected_interceptor) + assert instance.interceptor == expected_interceptor assert len(instance.handlers) == 0 def ctor_custom_handlers(self): @@ -37,7 +41,9 @@ def ctor_custom_handlers(self): if handlers are passed to init, use those instead """ custom_handler = object() - controller = self._make_one(handlers=[custom_handler]) + custom_interceptor = object() + controller = self._make_one(custom_interceptor, handlers=[custom_handler]) + assert controller.interceptor == custom_interceptor assert len(controller.handlers) == 1 assert controller.handlers[0] is custom_handler @@ -88,11 +94,20 @@ def test_create_operation(self): assert len(op.handlers) == 1 assert op.handlers[0] is handler - def test_create_operation_multiple_handlers(self): - orig_handler = object() - new_handler = object() - controller = self._make_one(handlers=[orig_handler]) - op = controller.create_operation(object(), handlers=[new_handler]) - assert len(op.handlers) == 2 - assert orig_handler in op.handlers - assert new_handler in op.handlers + def test_create_operation_registers_interceptor(self): + """ + creating an operation should link the operation with the controller's interceptor, + and add the interceptor as a handler to the operation + """ + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor, + ) + + custom_handler = object() + controller = self._make_one( + BigtableMetricsInterceptor(), handlers=[custom_handler] + ) + op = controller.create_operation(object()) + assert custom_handler in op.handlers + assert op.uuid in controller.interceptor.operation_map + assert controller.interceptor.operation_map[op.uuid] == op From 5ae7accc391d159a41c8cd86633334c745ba6317 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Aug 2025 16:42:41 -0700 Subject: [PATCH 34/60] added set_next to 
TrackedBackoffGenerator

---
 google/cloud/bigtable/data/_helpers.py | 29 ++++++++++++++++-
 tests/unit/data/test__helpers.py       | 44 ++++++++++++++++++++++++++
 2 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py
index ba64ffab4..e848ebc6f 100644
--- a/google/cloud/bigtable/data/_helpers.py
+++ b/google/cloud/bigtable/data/_helpers.py
@@ -264,18 +264,45 @@ def __init__(self, initial=0.01, maximum=60, multiplier=2):
         self.subgenerator = exponential_sleep_generator(
             initial=initial, maximum=maximum, multiplier=multiplier
         )
+        self._next_override: float | None = None
 
     def __iter__(self):
         return self
 
+    def set_next(self, next_value: float):
+        """
+        Set the next backoff value, instead of generating one from the subgenerator.
+        After the value is yielded, it will go back to using self.subgenerator.
+
+        If set_next is called twice before next() is called, only the latest
+        value will be used and the others discarded
+
+        Args:
+            next_value: the upcoming value to yield when next() is called
+        Raises:
+            ValueError: if next_value is negative
+        """
+        if next_value < 0:
+            raise ValueError("backoff value cannot be less than 0")
+        self._next_override = next_value
+
     def __next__(self) -> float:
-        next_backoff = next(self.subgenerator)
+        if self._next_override is not None:
+            next_backoff = self._next_override
+            self._next_override = None
+        else:
+            next_backoff = next(self.subgenerator)
         self.history.append(next_backoff)
         return next_backoff
 
     def get_attempt_backoff(self, attempt_idx) -> float:
         """
         returns the backoff time for a specific attempt index, starting at 0.
+
+        Args:
+            attempt_idx: the index of the attempt to return backoff for
+        Raises:
+            IndexError: if attempt_idx is negative, or not in history
         """
         if attempt_idx < 0:
             raise IndexError("received negative attempt number")
         return self.history[attempt_idx]
diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py
index fda0b7686..c8540024d 100644
--- a/tests/unit/data/test__helpers.py
+++ b/tests/unit/data/test__helpers.py
@@ -317,3 +317,47 @@ def test_get_attempt_backoff_out_of_bounds(self):
             generator.get_attempt_backoff(2)
         with pytest.raises(IndexError):
             generator.get_attempt_backoff(-3)
+
+    def test_set_next_full_set(self):
+        """
+        try always using set_next to populate generator
+        """
+        generator = _helpers.TrackedBackoffGenerator()
+        for idx, val in enumerate(range(100, 0, -1)):
+            generator.set_next(val)
+            got = next(generator)
+            assert got == val
+            assert generator.get_attempt_backoff(idx) == val
+
+    def test_set_next_negative_value(self):
+        generator = _helpers.TrackedBackoffGenerator()
+        with pytest.raises(ValueError):
+            generator.set_next(-1)
+
+    @mock.patch("random.uniform", side_effect=lambda a, b: b)
+    def test_interleaved_set_next(self, mock_uniform):
+        import itertools
+
+        generator = _helpers.TrackedBackoffGenerator(
+            initial=1, multiplier=2, maximum=128
+        )
+        # values we expect generator to create
+        expected_values = [2**i for i in range(8)]
+        # values we will insert
+        inserted_values = [9, 61, 0, 4, 33, 12, 18, 2]
+        for idx in range(8):
+            assert next(generator) == expected_values[idx]
+            generator.set_next(inserted_values[idx])
+            assert next(generator) == inserted_values[idx]
+        # check to make sure history is as we expect
+        assert generator.history == list(
+            itertools.chain.from_iterable(zip(expected_values, inserted_values))
+        )
+
+    @mock.patch("random.uniform", side_effect=lambda a, b: b)
+    def test_set_next_replacement(self, mock_uniform):
+        generator =
_helpers.TrackedBackoffGenerator(initial=1) + generator.set_next(99) + generator.set_next(88) + assert next(generator) == 88 + assert next(generator) == 1 From cb3229694b36476b7051ab746936acf6cc2a79e6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Aug 2025 17:17:00 -0700 Subject: [PATCH 35/60] added assertions to test_client --- tests/unit/data/_async/test_client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 4c2e0c628..24dfb4430 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1160,6 +1160,7 @@ def _make_one( @CrossSync.pytest async def test_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1201,6 +1202,8 @@ async def test_ctor(self): instance_key = _WarmedInstanceKey(table.instance_name, table.app_profile_id) assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} + assert isinstance(table._metrics, BigtableClientSideMetricsController) + assert table._metrics.interceptor == client._metrics_interceptor assert table.default_operation_timeout == expected_operation_timeout assert table.default_attempt_timeout == expected_attempt_timeout assert ( @@ -1490,6 +1493,7 @@ def _make_one( @CrossSync.pytest async def test_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1538,6 +1542,8 @@ async def test_ctor(self): instance_key = _WarmedInstanceKey(view.instance_name, view.app_profile_id) assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(view)} + assert isinstance(view._metrics, BigtableClientSideMetricsController) + assert view._metrics.interceptor == client._metrics_interceptor assert view.default_operation_timeout == expected_operation_timeout assert view.default_attempt_timeout == expected_attempt_timeout assert ( From f07e76594c701fb9e8e12574e4059a7b9e73859a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 10:42:56 -0700 Subject: [PATCH 36/60] added new test metrics interceptor file --- .../data/_async/test_metrics_interceptor.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/unit/data/_async/test_metrics_interceptor.py diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py new file mode 100644 index 000000000..ec61cfb35 --- /dev/null +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from google.cloud.bigtable.data._cross_sync import CrossSync + + +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor" + +@CrossSync.convert_class +class TestMetricsInterceptor: \ No newline at end of file From a34c01e302c1ea8007e42a59ab722ae3782507b4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 13:49:32 -0700 Subject: [PATCH 37/60] first round of tests --- .../data/_async/metrics_interceptor.py | 24 +- .../data/_async/test_metrics_interceptor.py | 276 +++++++++++++++++- 2 files changed, 283 insertions(+), 17 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index e5777ea0f..817e0d825 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -114,20 +114,17 @@ async def intercept_unary_unary( raise finally: if call is not None: - metadata = ( - await call.trailing_metadata() + await call.initial_metadata() - ) + metadata = (await call.trailing_metadata() or []) + (await call.initial_metadata() or []) operation.add_response_metadata(metadata) - if encountered_exc is not None: - # end attempt. If it succeeded, let higher levels decide when to end operation - operation.end_attempt_with_status(encountered_exc) + if encountered_exc is not None: + # end attempt. If it succeeded, let higher levels decide when to end operation + operation.end_attempt_with_status(encountered_exc) @CrossSync.convert @_with_operation_from_metadata async def intercept_unary_stream( self, operation, continuation, client_call_details, request ): - # TODO: benchmark async def response_wrapper(call): has_first_response = operation.first_response_latency is not None encountered_exc = None @@ -145,12 +142,11 @@ async def response_wrapper(call): encountered_exc = e raise finally: - metadata = ( - await call.trailing_metadata() + await call.initial_metadata() - ) - operation.add_response_metadata(metadata) - if encountered_exc is not None: - # end attempt. If it succeeded, let higher levels decide when to end operation - operation.end_attempt_with_status(encountered_exc) + if call is not None: + metadata = (await call.trailing_metadata() or []) + (await call.initial_metadata() or []) + operation.add_response_metadata(metadata) + if encountered_exc is not None: + # end attempt. If it succeeded, let higher levels decide when to end operation + operation.end_attempt_with_status(encountered_exc) return response_wrapper(await continuation(client_call_details, request)) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index ec61cfb35..55b66bfbd 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -10,12 +10,282 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License +# limitations under the License. 
+
+import pytest
+import asyncio

from google.cloud.bigtable.data._cross_sync import CrossSync

# try/except added for compatibility with python < 3.8
try:
    from unittest import mock
except ImportError:  # pragma: NO COVER
    import mock  # type: ignore

__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor"


class _AsyncIterator:
    """Helper class to wrap an iterator or async generator in an async iterator"""

    def __init__(self, iterable):
        if hasattr(iterable, "__anext__"):
            self._iterator = iterable
        else:
            self._iterator = iter(iterable)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if hasattr(self._iterator, "__anext__"):
            return await self._iterator.__anext__()
        try:
            return next(self._iterator)
        except StopIteration:
            raise StopAsyncIteration


@CrossSync.convert_class(sync_name="TestMetricsInterceptor")
class TestMetricsInterceptorAsync:
    @staticmethod
    @CrossSync.convert
    def _get_target_class():
        from google.cloud.bigtable.data._async import metrics_interceptor

        return metrics_interceptor.AsyncBigtableMetricsInterceptor

    def _make_one(self, *args, **kwargs):
        return self._get_target_class()(*args, **kwargs)

    def test_ctor(self):
        instance = self._make_one()
        assert instance.operation_map == {}

    def test_register_operation(self):
        """
        adding a new operation should register it in operation_map
        """
        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
        from google.cloud.bigtable.data._metrics.data_model import OperationType

        instance = self._make_one()
        op = ActiveOperationMetric(OperationType.READ_ROWS)
        instance.register_operation(op)
        assert instance.operation_map[op.uuid] == op
        assert instance in op.handlers

    def test_on_operation_complete_mock(self):
        """
        completing or cancelling an operation should call on_operation_complete on interceptor
        """
        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
        from google.cloud.bigtable.data._metrics.data_model import OperationType

        instance = self._make_one()
        instance.on_operation_complete = mock.Mock()
        op = ActiveOperationMetric(OperationType.READ_ROWS)
        instance.register_operation(op)
        op.end_with_success()
        assert instance.on_operation_complete.call_count == 1
        op.cancel()
        assert instance.on_operation_complete.call_count == 2

    def test_on_operation_complete(self):
        """
        completing an operation should remove it from the operation map
        """
        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
        from google.cloud.bigtable.data._metrics.data_model import OperationType

        instance = self._make_one()
        op = ActiveOperationMetric(OperationType.READ_ROWS)
        instance.register_operation(op)
        op.end_with_success()
        instance.on_operation_complete(op)
        assert op.uuid not in instance.operation_map

    def test_on_operation_cancelled(self):
        """
        cancelling an operation should remove it from the operation map
        """
        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
        from google.cloud.bigtable.data._metrics.data_model import OperationType

        instance = self._make_one()
        op = ActiveOperationMetric(OperationType.READ_ROWS)
        instance.register_operation(op)
        op.cancel()
        assert op.uuid not in instance.operation_map

    @CrossSync.pytest
    async def test_unary_unary_interceptor_op_not_found(self):
        """Test that
interceptor calls continuation if op is not found"""
        instance = self._make_one()
        continuation = CrossSync.Mock()
        details = mock.Mock()
        details.metadata = []
        request = mock.Mock()
        await instance.intercept_unary_unary(continuation, details, request)
        continuation.assert_called_once_with(details, request)

    @CrossSync.pytest
    async def test_unary_unary_interceptor_success(self):
        """Test that interceptor handles successful unary-unary calls"""
        from google.cloud.bigtable.data._metrics.data_model import (
            OPERATION_INTERCEPTOR_METADATA_KEY,
        )

        instance = self._make_one()
        op = mock.Mock()
        op.uuid = "test-uuid"
        op.state = 1  # ACTIVE_ATTEMPT
        instance.operation_map[op.uuid] = op
        continuation = CrossSync.Mock()
        call = continuation.return_value
        call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
        call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
        details = mock.Mock()
        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
        request = mock.Mock()
        result = await instance.intercept_unary_unary(continuation, details, request)
        assert result == call
        continuation.assert_called_once_with(details, request)
        op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
        op.end_attempt_with_status.assert_not_called()

    @CrossSync.pytest
    async def test_unary_unary_interceptor_failure(self):
        """Test that interceptor handles failed unary-unary calls"""
        from google.cloud.bigtable.data._metrics.data_model import (
            OPERATION_INTERCEPTOR_METADATA_KEY,
        )

        instance = self._make_one()
        op = mock.Mock()
        op.uuid = "test-uuid"
        op.state = 1  # ACTIVE_ATTEMPT
        instance.operation_map[op.uuid] = op
        exc = ValueError("test")
        continuation = CrossSync.Mock(side_effect=exc)
        call = continuation.return_value
        call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
        call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
        details = mock.Mock()
        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
        request = mock.Mock()
        with pytest.raises(ValueError) as e:
            await instance.intercept_unary_unary(continuation, details, request)
        assert e.value == exc
        continuation.assert_called_once_with(details, request)
        op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
        op.end_attempt_with_status.assert_called_once_with(exc)

    @CrossSync.pytest
    async def test_unary_stream_interceptor_op_not_found(self):
        """Test that interceptor calls continuation if op is not found"""
        instance = self._make_one()
        continuation = CrossSync.Mock()
        details = mock.Mock()
        details.metadata = []
        request = mock.Mock()
        await instance.intercept_unary_stream(continuation, details, request)
        continuation.assert_called_once_with(details, request)

    @CrossSync.pytest
    async def test_unary_stream_interceptor_success(self):
        """Test that interceptor handles successful unary-stream calls"""
        from google.cloud.bigtable.data._metrics.data_model import (
            OPERATION_INTERCEPTOR_METADATA_KEY,
        )

        instance = self._make_one()
        op = mock.Mock()
        op.uuid = "test-uuid"
        op.state = 1  # ACTIVE_ATTEMPT
        op.start_time_ns = 0
        op.first_response_latency = None
        instance.operation_map[op.uuid] = op

        continuation = CrossSync.Mock()
        call = continuation.return_value
        call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2]))
        call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
        call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + wrapper = await instance.intercept_unary_stream(continuation, details, request) + results = [val async for val in wrapper] + assert results == [1, 2] + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.end_attempt_with_status.assert_not_called() + + @CrossSync.pytest + async def test_unary_stream_interceptor_failure_mid_stream(self): + """Test that interceptor handles failures mid-stream""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = 1 # ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = ValueError("test") + + continuation = CrossSync.Mock() + call = continuation.return_value + async def mock_generator(): + yield 1 + raise exc + call.__aiter__ = mock.Mock(return_value=_AsyncIterator(mock_generator())) + call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) + call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + wrapper = await instance.intercept_unary_stream(continuation, details, request) + with pytest.raises(ValueError) as e: + [val async for val in wrapper] + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.end_attempt_with_status.assert_called_once_with(exc) + + @CrossSync.pytest + async def test_unary_stream_interceptor_failure_start_stream(self): + """Test that interceptor handles failures at start of stream""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = 1 # ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = ValueError("test") + + continuation = CrossSync.Mock() + continuation.side_effect = exc + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(ValueError) as e: + await instance.intercept_unary_stream(continuation, details, request) + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.end_attempt_with_status.assert_called_once_with(exc) \ No newline at end of file From 84f61ee659a5cb891f61b27a28cd4581bf4eb962 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 14:18:31 -0700 Subject: [PATCH 38/60] added metadata capture for failed rpcs --- .../data/_async/metrics_interceptor.py | 31 ++++- .../data/_async/test_metrics_interceptor.py | 124 +++++++++++++++++- 2 files changed, 147 insertions(+), 8 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 817e0d825..28aa89095 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ 
b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -15,6 +15,7 @@ import time from functools import wraps +from grpc import RpcError from google.cloud.bigtable.data._metrics.data_model import ( OPERATION_INTERCEPTOR_METADATA_KEY, ) @@ -105,16 +106,24 @@ async def intercept_unary_unary( self, operation, continuation, client_call_details, request ): encountered_exc: Exception | None = None - call = None + metadata = None try: call = await continuation(client_call_details, request) + metadata = (await call.trailing_metadata() or []) + (await call.initial_metadata() or []) return call + except RpcError as rpc_error: + # attempt extracting metadata from error + try: + metadata = (await rpc_error.trailing_metadata() or []) + (await rpc_error.initial_metadata() or []) + except Exception: + pass + encountered_exc = rpc_error + raise rpc_error except Exception as e: encountered_exc = e raise finally: - if call is not None: - metadata = (await call.trailing_metadata() or []) + (await call.initial_metadata() or []) + if metadata is not None: operation.add_response_metadata(metadata) if encountered_exc is not None: # end attempt. If it succeeded, let higher levels decide when to end operation @@ -138,6 +147,7 @@ async def response_wrapper(call): has_first_response = True yield response + except Exception as e: encountered_exc = e raise @@ -149,4 +159,17 @@ async def response_wrapper(call): # end attempt. If it succeeded, let higher levels decide when to end operation operation.end_attempt_with_status(encountered_exc) - return response_wrapper(await continuation(client_call_details, request)) + try: + return response_wrapper(await continuation(client_call_details, request)) + except RpcError as rpc_error: + # attempt extracting metadata from error + try: + metadata = (await rpc_error.trailing_metadata() or []) + (await rpc_error.initial_metadata() or []) + operation.add_response_metadata(metadata) + except Exception: + pass + operation.end_attempt_with_status(rpc_error) + raise rpc_error + except Exception as e: + operation.end_attempt_with_status(e) + raise \ No newline at end of file diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index 55b66bfbd..2674e488a 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -14,6 +14,7 @@ import pytest import asyncio +from grpc import RpcError from google.cloud.bigtable.data._cross_sync import CrossSync @@ -158,7 +159,61 @@ async def test_unary_unary_interceptor_success(self): @CrossSync.pytest async def test_unary_unary_interceptor_failure(self): - """Test that interceptor handles failed unary-unary calls""" + """test a failed RpcError with metadata""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = 1 # ACTIVE_ATTEMPT + instance.operation_map[op.uuid] = op + exc = RpcError("test") + exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) + exc.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) + continuation = CrossSync.Mock(side_effect=exc) + call = continuation.return_value + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(RpcError) as e: + await instance.intercept_unary_unary(continuation, details, request) + assert e.value == exc + 
continuation.assert_called_once_with(details, request)
+ op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+ op.end_attempt_with_status.assert_called_once_with(exc)
+
+ @CrossSync.pytest
+ async def test_unary_unary_interceptor_failure_no_metadata(self):
+ """test with RpcError without metadata attached"""
+ from google.cloud.bigtable.data._metrics.data_model import (
+ OPERATION_INTERCEPTOR_METADATA_KEY,
+ )
+
+ instance = self._make_one()
+ op = mock.Mock()
+ op.uuid = "test-uuid"
+ op.state = 1 # ACTIVE_ATTEMPT
+ instance.operation_map[op.uuid] = op
+ exc = RpcError("test")
+ continuation = CrossSync.Mock(side_effect=exc)
+ call = continuation.return_value
+ call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
+ call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
+ details = mock.Mock()
+ details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
+ request = mock.Mock()
+ with pytest.raises(RpcError) as e:
+ await instance.intercept_unary_unary(continuation, details, request)
+ assert e.value == exc
+ continuation.assert_called_once_with(details, request)
+ op.add_response_metadata.assert_not_called()
+ op.end_attempt_with_status.assert_called_once_with(exc)
+
+ @CrossSync.pytest
+ async def test_unary_unary_interceptor_failure_generic(self):
+ """test generic exception"""
 from google.cloud.bigtable.data._metrics.data_model import (
 OPERATION_INTERCEPTOR_METADATA_KEY,
 )
@@ -180,9 +235,10 @@ async def test_unary_unary_interceptor_failure(self):
 await instance.intercept_unary_unary(continuation, details, request)
 assert e.value == exc
 continuation.assert_called_once_with(details, request)
- op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+ op.add_response_metadata.assert_not_called()
 op.end_attempt_with_status.assert_called_once_with(exc)
+
 @CrossSync.pytest
 async def test_unary_stream_interceptor_op_not_found(self):
 """Test that interceptor calls continuation if op is not found"""
@@ -263,7 +319,67 @@ async def mock_generator():
 @CrossSync.pytest
 async def test_unary_stream_interceptor_failure_start_stream(self):
- """Test that interceptor handles failures at start of stream"""
+ """Test that interceptor handles failures at start of stream with RpcError with metadata"""
+ from google.cloud.bigtable.data._metrics.data_model import (
+ OPERATION_INTERCEPTOR_METADATA_KEY,
+ )
+
+ instance = self._make_one()
+ op = mock.Mock()
+ op.uuid = "test-uuid"
+ op.state = 1 # ACTIVE_ATTEMPT
+ op.start_time_ns = 0
+ op.first_response_latency = None
+ instance.operation_map[op.uuid] = op
+ exc = RpcError("test")
+ exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
+ exc.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
+
+ continuation = CrossSync.Mock()
+ continuation.side_effect = exc
+ details = mock.Mock()
+ details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
+ request = mock.Mock()
+ with pytest.raises(RpcError) as e:
+ await instance.intercept_unary_stream(continuation, details, request)
+ assert e.value == exc
+ continuation.assert_called_once_with(details, request)
+ assert op.first_response_latency_ns is not None
+ op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+ op.end_attempt_with_status.assert_called_once_with(exc)
+
+ @CrossSync.pytest
+ async def test_unary_stream_interceptor_failure_start_stream_no_metadata(self):
+ """Test that interceptor handles failures at start of stream with RpcError with no metadata"""
+ from
google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = 1 # ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = RpcError("test") + + continuation = CrossSync.Mock() + continuation.side_effect = exc + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(RpcError) as e: + await instance.intercept_unary_stream(continuation, details, request) + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_not_called() + op.end_attempt_with_status.assert_called_once_with(exc) + + @CrossSync.pytest + async def test_unary_stream_interceptor_failure_start_stream_generic(self): + """Test that interceptor handles failures at start of stream with generic exception""" from google.cloud.bigtable.data._metrics.data_model import ( OPERATION_INTERCEPTOR_METADATA_KEY, ) @@ -287,5 +403,5 @@ async def test_unary_stream_interceptor_failure_start_stream(self): assert e.value == exc continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.add_response_metadata.assert_not_called() op.end_attempt_with_status.assert_called_once_with(exc) \ No newline at end of file From d4ae6379ac6e9c95a953d59f1a2f7ec9ae9ac7cb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 14:35:31 -0700 Subject: [PATCH 39/60] added test for starting attempts --- .../data/_async/metrics_interceptor.py | 2 +- .../data/_async/test_metrics_interceptor.py | 69 ++++++++++++++++--- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 28aa89095..88a9f27fe 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -55,7 +55,7 @@ def wrapper(self, continuation, client_call_details, request): operation: "ActiveOperationMetric" = self.operation_map.get(key) if operation: # start a new attempt if not started - if operation.state != OperationState.ACTIVE_ATTEMPT: + if operation.state == OperationState.CREATED or operation.state == OperationState.BETWEEN_ATTEMPTS: operation.start_attempt() # wrap continuation in logic to process the operation return func(self, operation, continuation, client_call_details, request) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index 2674e488a..58d0ae34b 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -16,6 +16,7 @@ import asyncio from grpc import RpcError +from google.cloud.bigtable.data._metrics.data_model import OperationState from google.cloud.bigtable.data._cross_sync import CrossSync # try/except added for compatibility with python < 3.8 @@ -142,7 +143,7 @@ async def test_unary_unary_interceptor_success(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT instance.operation_map[op.uuid] = op continuation = CrossSync.Mock() call = continuation.return_value @@ 
-167,7 +168,7 @@ async def test_unary_unary_interceptor_failure(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT instance.operation_map[op.uuid] = op exc = RpcError("test") exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) @@ -194,7 +195,7 @@ async def test_unary_unary_interceptor_failure_no_metadata(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT instance.operation_map[op.uuid] = op exc = RpcError("test") continuation = CrossSync.Mock(side_effect=exc) @@ -221,7 +222,7 @@ async def test_unary_unary_interceptor_failure_generic(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT instance.operation_map[op.uuid] = op exc = ValueError("test") continuation = CrossSync.Mock(side_effect=exc) @@ -260,7 +261,7 @@ async def test_unary_stream_interceptor_success(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None instance.operation_map[op.uuid] = op @@ -291,7 +292,7 @@ async def test_unary_stream_interceptor_failure_mid_stream(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None instance.operation_map[op.uuid] = op @@ -327,7 +328,7 @@ async def test_unary_stream_interceptor_failure_start_stream(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None instance.operation_map[op.uuid] = op @@ -358,7 +359,7 @@ async def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None instance.operation_map[op.uuid] = op @@ -387,7 +388,7 @@ async def test_unary_stream_interceptor_failure_start_stream_generic(self): instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" - op.state = 1 # ACTIVE_ATTEMPT + op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None instance.operation_map[op.uuid] = op @@ -404,4 +405,52 @@ async def test_unary_stream_interceptor_failure_start_stream_generic(self): continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None op.add_response_metadata.assert_not_called() - op.end_attempt_with_status.assert_called_once_with(exc) \ No newline at end of file + op.end_attempt_with_status.assert_called_once_with(exc) + + @CrossSync.pytest + @pytest.mark.parametrize( + "initial_state", [OperationState.CREATED, OperationState.BETWEEN_ATTEMPTS] + ) + async def test_unary_unary_interceptor_start_operation(self, initial_state): + """if called with a newly created operation, it should be started""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = initial_state + instance.operation_map[op.uuid] = op + continuation = 
CrossSync.Mock() + call = continuation.return_value + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + await instance.intercept_unary_unary(continuation, details, request) + op.start_attempt.assert_called_once() + + @CrossSync.pytest + @pytest.mark.parametrize( + "initial_state", [OperationState.CREATED, OperationState.BETWEEN_ATTEMPTS] + ) + async def test_unary_stream_interceptor_start_operation(self, initial_state): + """if called with a newly created operation, it should be started""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = initial_state + instance.operation_map[op.uuid] = op + + continuation = CrossSync.Mock() + call = continuation.return_value + call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + await instance.intercept_unary_stream(continuation, details, request) + op.start_attempt.assert_called_once() \ No newline at end of file From 1fbcadd4c856ef9b7477639a4df1fe7b9adad31b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 14:41:56 -0700 Subject: [PATCH 40/60] added sync tests --- .../data/_sync_autogen/metrics_interceptor.py | 64 ++- .../data/_async/test_metrics_interceptor.py | 11 +- .../_sync_autogen/test_metrics_interceptor.py | 426 ++++++++++++++++++ 3 files changed, 485 insertions(+), 16 deletions(-) create mode 100644 tests/unit/data/_sync_autogen/test_metrics_interceptor.py diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py index 21dd6752a..9a14a2cfa 100644 --- a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py @@ -15,7 +15,9 @@ # This file is automatically generated by CrossSync. Do not edit manually. 
from __future__ import annotations +import time from functools import wraps +from grpc import RpcError from google.cloud.bigtable.data._metrics.data_model import ( OPERATION_INTERCEPTOR_METADATA_KEY, ) @@ -42,7 +44,10 @@ def wrapper(self, continuation, client_call_details, request): ) operation: "ActiveOperationMetric" = self.operation_map.get(key) if operation: - if operation.state != OperationState.ACTIVE_ATTEMPT: + if ( + operation.state == OperationState.CREATED + or operation.state == OperationState.BETWEEN_ATTEMPTS + ): operation.start_attempt() return func(self, operation, continuation, client_call_details, request) else: @@ -76,7 +81,8 @@ def register_operation(self, operation): operation.handlers.append(self) def on_operation_complete(self, op): - del self.operation_map[op.uuid] + if op.uuid in self.operation_map: + del self.operation_map[op.uuid] def on_operation_cancelled(self, op): self.on_operation_complete(op) @@ -86,36 +92,70 @@ def intercept_unary_unary( self, operation, continuation, client_call_details, request ): encountered_exc: Exception | None = None - call = None + metadata = None try: call = continuation(client_call_details, request) + metadata = (call.trailing_metadata() or []) + ( + call.initial_metadata() or [] + ) return call + except RpcError as rpc_error: + try: + metadata = (rpc_error.trailing_metadata() or []) + ( + rpc_error.initial_metadata() or [] + ) + except Exception: + pass + encountered_exc = rpc_error + raise rpc_error except Exception as e: encountered_exc = e raise finally: - if call is not None: - metadata = call.trailing_metadata() + call.initial_metadata() + if metadata is not None: operation.add_response_metadata(metadata) - if encountered_exc is not None: - operation.end_attempt_with_status(encountered_exc) + if encountered_exc is not None: + operation.end_attempt_with_status(encountered_exc) @_with_operation_from_metadata def intercept_unary_stream( self, operation, continuation, client_call_details, request ): def response_wrapper(call): + has_first_response = operation.first_response_latency is not None encountered_exc = None try: for response in call: + if not has_first_response: + operation.first_response_latency_ns = ( + time.monotonic_ns() - operation.start_time_ns + ) + has_first_response = True yield response except Exception as e: encountered_exc = e raise finally: - metadata = call.trailing_metadata() + call.initial_metadata() - operation.add_response_metadata(metadata) - if encountered_exc is not None: - operation.end_attempt_with_status(encountered_exc) + if call is not None: + metadata = (call.trailing_metadata() or []) + ( + call.initial_metadata() or [] + ) + operation.add_response_metadata(metadata) + if encountered_exc is not None: + operation.end_attempt_with_status(encountered_exc) - return response_wrapper(continuation(client_call_details, request)) + try: + return response_wrapper(continuation(client_call_details, request)) + except RpcError as rpc_error: + try: + metadata = (rpc_error.trailing_metadata() or []) + ( + rpc_error.initial_metadata() or [] + ) + operation.add_response_metadata(metadata) + except Exception: + pass + operation.end_attempt_with_status(rpc_error) + raise rpc_error + except Exception as e: + operation.end_attempt_with_status(e) + raise diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index 58d0ae34b..bdb8c3f6f 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py 
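For context on how the synchronous interceptor above is put to use: a sync client interceptor of this shape is attached by wrapping an existing channel. A minimal sketch, assuming the module path from this patch; the target address and credentials are placeholders, not taken from the diff:

```python
import grpc

from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (
    BigtableMetricsInterceptor,
)

# hypothetical wiring; any UnaryUnaryClientInterceptor /
# UnaryStreamClientInterceptor attaches the same way
interceptor = BigtableMetricsInterceptor()
base_channel = grpc.secure_channel(
    "bigtable.googleapis.com:443",  # placeholder target
    grpc.ssl_channel_credentials(),
)
# grpc.intercept_channel returns a wrapped channel; every RPC made on it
# flows through intercept_unary_unary / intercept_unary_stream
channel = grpc.intercept_channel(base_channel, interceptor)
```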
@@ -25,6 +25,11 @@ except ImportError: # pragma: NO COVER import mock # type: ignore +if CrossSync.is_async: + from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor +else: + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor + __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor" @@ -53,11 +58,9 @@ async def __anext__(self): @CrossSync.convert_class(sync_name="TestMetricsInterceptor") class TestMetricsInterceptorAsync: @staticmethod - @CrossSync.convert + @CrossSync.convert(replace_symbols={"AsyncBigtableMetricsInterceptor": "BigtableMetricsInterceptor"}) def _get_target_class(): - from google.cloud.bigtable.data._async import metrics_interceptor - - return metrics_interceptor.AsyncBigtableMetricsInterceptor + return AsyncBigtableMetricsInterceptor def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py new file mode 100644 index 000000000..932ddc03f --- /dev/null +++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py @@ -0,0 +1,426 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This file is automatically generated by CrossSync. Do not edit manually. 
+
+import pytest
+from grpc import RpcError
+from google.cloud.bigtable.data._metrics.data_model import OperationState
+from google.cloud.bigtable.data._cross_sync import CrossSync
+
+try:
+ from unittest import mock
+except ImportError:
+ import mock
+from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (
+ BigtableMetricsInterceptor,
+)
+
+
+class _AsyncIterator:
+ """Helper class to wrap an iterator or async generator in an async iterator"""
+
+ def __init__(self, iterable):
+ if hasattr(iterable, "__anext__"):
+ self._iterator = iterable
+ else:
+ self._iterator = iter(iterable)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ if hasattr(self._iterator, "__anext__"):
+ return await self._iterator.__anext__()
+ try:
+ return next(self._iterator)
+ except StopIteration:
+ raise StopAsyncIteration
+
+
+class TestMetricsInterceptor:
+ @staticmethod
+ def _get_target_class():
+ return BigtableMetricsInterceptor
+
+ def _make_one(self, *args, **kwargs):
+ return self._get_target_class()(*args, **kwargs)
+
+ def test_ctor(self):
+ instance = self._make_one()
+ assert instance.operation_map == {}
+
+ def test_register_operation(self):
+ """adding a new operation should register it in operation_map"""
+ from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+ from google.cloud.bigtable.data._metrics.data_model import OperationType
+
+ instance = self._make_one()
+ op = ActiveOperationMetric(OperationType.READ_ROWS)
+ instance.register_operation(op)
+ assert instance.operation_map[op.uuid] == op
+ assert instance in op.handlers
+
+ def test_on_operation_comple_mock(self):
+ """completing or cancelling an operation should call on_operation_complete on interceptor"""
+ from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+ from google.cloud.bigtable.data._metrics.data_model import OperationType
+
+ instance = self._make_one()
+ instance.on_operation_complete = mock.Mock()
+ op = ActiveOperationMetric(OperationType.READ_ROWS)
+ instance.register_operation(op)
+ op.end_with_success()
+ assert instance.on_operation_complete.call_count == 1
+ op.cancel()
+ assert instance.on_operation_complete.call_count == 2
+
+ def test_on_operation_complete(self):
+ """completing an operation should remove it from the operation map"""
+ from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+ from google.cloud.bigtable.data._metrics.data_model import OperationType
+
+ instance = self._make_one()
+ op = ActiveOperationMetric(OperationType.READ_ROWS)
+ instance.register_operation(op)
+ op.end_with_success()
+ instance.on_operation_complete(op)
+ assert op.uuid not in instance.operation_map
+
+ def test_on_operation_cancelled(self):
+ """cancelling an operation should remove it from the operation map"""
+ from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
+ from google.cloud.bigtable.data._metrics.data_model import OperationType
+
+ instance = self._make_one()
+ op = ActiveOperationMetric(OperationType.READ_ROWS)
+ instance.register_operation(op)
+ op.cancel()
+ assert op.uuid not in instance.operation_map
+
+ def test_unary_unary_interceptor_op_not_found(self):
+ """Test that interceptor calls continuation if op is not found"""
+ instance = self._make_one()
+ continuation = CrossSync._Sync_Impl.Mock()
+ details = mock.Mock()
+ details.metadata = []
+ request = mock.Mock()
+ instance.intercept_unary_unary(continuation, details, request)
+
continuation.assert_called_once_with(details, request)
+
+ def test_unary_unary_interceptor_success(self):
+ """Test that interceptor handles successful unary-unary calls"""
+ from google.cloud.bigtable.data._metrics.data_model import (
+ OPERATION_INTERCEPTOR_METADATA_KEY,
+ )
+
+ instance = self._make_one()
+ op = mock.Mock()
+ op.uuid = "test-uuid"
+ op.state = OperationState.ACTIVE_ATTEMPT
+ instance.operation_map[op.uuid] = op
+ continuation = CrossSync._Sync_Impl.Mock()
+ call = continuation.return_value
+ call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
+ call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
+ details = mock.Mock()
+ details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
+ request = mock.Mock()
+ result = instance.intercept_unary_unary(continuation, details, request)
+ assert result == call
+ continuation.assert_called_once_with(details, request)
+ op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+ op.end_attempt_with_status.assert_not_called()
+
+ def test_unary_unary_interceptor_failure(self):
+ """test a failed RpcError with metadata"""
+ from google.cloud.bigtable.data._metrics.data_model import (
+ OPERATION_INTERCEPTOR_METADATA_KEY,
+ )
+
+ instance = self._make_one()
+ op = mock.Mock()
+ op.uuid = "test-uuid"
+ op.state = OperationState.ACTIVE_ATTEMPT
+ instance.operation_map[op.uuid] = op
+ exc = RpcError("test")
+ exc.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
+ exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
+ continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
+ call = continuation.return_value
+ details = mock.Mock()
+ details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
+ request = mock.Mock()
+ with pytest.raises(RpcError) as e:
+ instance.intercept_unary_unary(continuation, details, request)
+ assert e.value == exc
+ continuation.assert_called_once_with(details, request)
+ op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+ op.end_attempt_with_status.assert_called_once_with(exc)
+
+ def test_unary_unary_interceptor_failure_no_metadata(self):
+ """test with RpcError without metadata attached"""
+ from google.cloud.bigtable.data._metrics.data_model import (
+ OPERATION_INTERCEPTOR_METADATA_KEY,
+ )
+
+ instance = self._make_one()
+ op = mock.Mock()
+ op.uuid = "test-uuid"
+ op.state = OperationState.ACTIVE_ATTEMPT
+ instance.operation_map[op.uuid] = op
+ exc = RpcError("test")
+ continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
+ call = continuation.return_value
+ call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
+ call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
+ details = mock.Mock()
+ details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
+ request = mock.Mock()
+ with pytest.raises(RpcError) as e:
+ instance.intercept_unary_unary(continuation, details, request)
+ assert e.value == exc
+ continuation.assert_called_once_with(details, request)
+ op.add_response_metadata.assert_not_called()
+ op.end_attempt_with_status.assert_called_once_with(exc)
+
+ def test_unary_unary_interceptor_failure_generic(self):
+ """test generic exception"""
+ from google.cloud.bigtable.data._metrics.data_model import (
+ OPERATION_INTERCEPTOR_METADATA_KEY,
+ )
+
+ instance = self._make_one()
+ op = mock.Mock()
+ op.uuid = "test-uuid"
+ op.state = OperationState.ACTIVE_ATTEMPT
+
instance.operation_map[op.uuid] = op + exc = ValueError("test") + continuation = CrossSync._Sync_Impl.Mock(side_effect=exc) + call = continuation.return_value + call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) + call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(ValueError) as e: + instance.intercept_unary_unary(continuation, details, request) + assert e.value == exc + continuation.assert_called_once_with(details, request) + op.add_response_metadata.assert_not_called() + op.end_attempt_with_status.assert_called_once_with(exc) + + def test_unary_stream_interceptor_op_not_found(self): + """Test that interceptor calls continuation if op is not found""" + instance = self._make_one() + continuation = CrossSync._Sync_Impl.Mock() + details = mock.Mock() + details.metadata = [] + request = mock.Mock() + instance.intercept_unary_stream(continuation, details, request) + continuation.assert_called_once_with(details, request) + + def test_unary_stream_interceptor_success(self): + """Test that interceptor handles successful unary-stream calls""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = OperationState.ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + continuation = CrossSync._Sync_Impl.Mock() + call = continuation.return_value + call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) + call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + wrapper = instance.intercept_unary_stream(continuation, details, request) + results = [val for val in wrapper] + assert results == [1, 2] + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.end_attempt_with_status.assert_not_called() + + def test_unary_stream_interceptor_failure_mid_stream(self): + """Test that interceptor handles failures mid-stream""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = OperationState.ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = ValueError("test") + continuation = CrossSync._Sync_Impl.Mock() + call = continuation.return_value + + def mock_generator(): + yield 1 + raise exc + + call.__aiter__ = mock.Mock(return_value=_AsyncIterator(mock_generator())) + call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) + call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + wrapper = instance.intercept_unary_stream(continuation, details, request) + with pytest.raises(ValueError) as e: + [val for val in wrapper] + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert 
op.first_response_latency_ns is not None + op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.end_attempt_with_status.assert_called_once_with(exc) + + def test_unary_stream_interceptor_failure_start_stream(self): + """Test that interceptor handles failures at start of stream with RpcError with metadata""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = OperationState.ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = RpcError("test") + exc.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) + exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) + continuation = CrossSync._Sync_Impl.Mock() + continuation.side_effect = exc + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(RpcError) as e: + instance.intercept_unary_stream(continuation, details, request) + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.end_attempt_with_status.assert_called_once_with(exc) + + def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): + """Test that interceptor handles failures at start of stream with RpcError with no metadata""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = OperationState.ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = RpcError("test") + continuation = CrossSync._Sync_Impl.Mock() + continuation.side_effect = exc + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(RpcError) as e: + instance.intercept_unary_stream(continuation, details, request) + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_not_called() + op.end_attempt_with_status.assert_called_once_with(exc) + + def test_unary_stream_interceptor_failure_start_stream_generic(self): + """Test that interceptor handles failures at start of stream with generic exception""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = OperationState.ACTIVE_ATTEMPT + op.start_time_ns = 0 + op.first_response_latency = None + instance.operation_map[op.uuid] = op + exc = ValueError("test") + continuation = CrossSync._Sync_Impl.Mock() + continuation.side_effect = exc + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + with pytest.raises(ValueError) as e: + instance.intercept_unary_stream(continuation, details, request) + assert e.value == exc + continuation.assert_called_once_with(details, request) + assert op.first_response_latency_ns is not None + op.add_response_metadata.assert_not_called() + op.end_attempt_with_status.assert_called_once_with(exc) + + @pytest.mark.parametrize( + "initial_state", 
[OperationState.CREATED, OperationState.BETWEEN_ATTEMPTS] + ) + def test_unary_unary_interceptor_start_operation(self, initial_state): + """if called with a newly created operation, it should be started""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = initial_state + instance.operation_map[op.uuid] = op + continuation = CrossSync._Sync_Impl.Mock() + call = continuation.return_value + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + instance.intercept_unary_unary(continuation, details, request) + op.start_attempt.assert_called_once() + + @pytest.mark.parametrize( + "initial_state", [OperationState.CREATED, OperationState.BETWEEN_ATTEMPTS] + ) + def test_unary_stream_interceptor_start_operation(self, initial_state): + """if called with a newly created operation, it should be started""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = initial_state + instance.operation_map[op.uuid] = op + continuation = CrossSync._Sync_Impl.Mock() + call = continuation.return_value + call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + details = mock.Mock() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] + request = mock.Mock() + instance.intercept_unary_stream(continuation, details, request) + op.start_attempt.assert_called_once() From edacd04e4f90cec53325481ee01b802ddf154401 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 15:02:22 -0700 Subject: [PATCH 41/60] got tests passing --- .../data/_async/test_metrics_interceptor.py | 29 +++++++++++++---- .../_sync_autogen/test_metrics_interceptor.py | 31 +++++-------------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index bdb8c3f6f..d49ba4c95 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -34,6 +34,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor" +@CrossSync.drop class _AsyncIterator: """Helper class to wrap an iterator or async generator in an async iterator""" @@ -271,7 +272,10 @@ async def test_unary_stream_interceptor_success(self): continuation = CrossSync.Mock() call = continuation.return_value - call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + if CrossSync.is_async: + call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + else: + call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -303,10 +307,16 @@ async def test_unary_stream_interceptor_failure_mid_stream(self): continuation = CrossSync.Mock() call = continuation.return_value - async def mock_generator(): - yield 1 - raise exc - call.__aiter__ = mock.Mock(return_value=_AsyncIterator(mock_generator())) + if CrossSync.is_async: + async def mock_generator(): + yield 1 + raise exc + call.__aiter__ = mock.Mock(return_value=_AsyncIterator(mock_generator())) + else: + def mock_generator(): + yield 1 + raise exc + call.__iter__ = mock.Mock(return_value=mock_generator()) 
call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -427,6 +437,8 @@ async def test_unary_unary_interceptor_start_operation(self, initial_state): instance.operation_map[op.uuid] = op continuation = CrossSync.Mock() call = continuation.return_value + call.trailing_metadata = CrossSync.Mock(return_value=[]) + call.initial_metadata = CrossSync.Mock(return_value=[]) details = mock.Mock() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() @@ -451,7 +463,12 @@ async def test_unary_stream_interceptor_start_operation(self, initial_state): continuation = CrossSync.Mock() call = continuation.return_value - call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + if CrossSync.is_async: + call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + else: + call.__iter__ = mock.Mock(return_value=iter([1, 2])) + call.trailing_metadata = CrossSync.Mock(return_value=[]) + call.initial_metadata = CrossSync.Mock(return_value=[]) details = mock.Mock() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py index 932ddc03f..b9b9bdd78 100644 --- a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py +++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py @@ -29,27 +29,6 @@ ) -class _AsyncIterator: - """Helper class to wrap an iterator or async generator in an async iterator""" - - def __init__(self, iterable): - if hasattr(iterable, "__anext__"): - self._iterator = iterable - else: - self._iterator = iter(iterable) - - def __aiter__(self): - return self - - async def __anext__(self): - if hasattr(self._iterator, "__anext__"): - return await self._iterator.__anext__() - try: - return next(self._iterator) - except StopIteration: - raise StopAsyncIteration - - class TestMetricsInterceptor: @staticmethod def _get_target_class(): @@ -247,7 +226,7 @@ def test_unary_stream_interceptor_success(self): instance.operation_map[op.uuid] = op continuation = CrossSync._Sync_Impl.Mock() call = continuation.return_value - call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -282,7 +261,7 @@ def mock_generator(): yield 1 raise exc - call.__aiter__ = mock.Mock(return_value=_AsyncIterator(mock_generator())) + call.__iter__ = mock.Mock(return_value=mock_generator()) call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -396,6 +375,8 @@ def test_unary_unary_interceptor_start_operation(self, initial_state): instance.operation_map[op.uuid] = op continuation = CrossSync._Sync_Impl.Mock() call = continuation.return_value + call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) + call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) details = mock.Mock() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() @@ -418,7 +399,9 @@ def test_unary_stream_interceptor_start_operation(self, initial_state): instance.operation_map[op.uuid] = op continuation = CrossSync._Sync_Impl.Mock() call = 
continuation.return_value - call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + call.__iter__ = mock.Mock(return_value=iter([1, 2])) + call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) + call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) details = mock.Mock() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() From c628d21309f83652a0b7012a765d1f939b53bf65 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 15:13:35 -0700 Subject: [PATCH 42/60] removed helper class --- .../data/_async/test_metrics_interceptor.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index d49ba4c95..ccbe7869b 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -34,26 +34,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor" -@CrossSync.drop -class _AsyncIterator: - """Helper class to wrap an iterator or async generator in an async iterator""" - def __init__(self, iterable): - if hasattr(iterable, "__anext__"): - self._iterator = iterable - else: - self._iterator = iter(iterable) - - def __aiter__(self): - return self - - async def __anext__(self): - if hasattr(self._iterator, "__anext__"): - return await self._iterator.__anext__() - try: - return next(self._iterator) - except StopIteration: - raise StopAsyncIteration @CrossSync.convert_class(sync_name="TestMetricsInterceptor") @@ -273,7 +254,10 @@ async def test_unary_stream_interceptor_success(self): continuation = CrossSync.Mock() call = continuation.return_value if CrossSync.is_async: - call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + async def gen(): + yield 1 + yield 2 + call.__aiter__ = mock.Mock(return_value=gen()) else: call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) @@ -311,7 +295,7 @@ async def test_unary_stream_interceptor_failure_mid_stream(self): async def mock_generator(): yield 1 raise exc - call.__aiter__ = mock.Mock(return_value=_AsyncIterator(mock_generator())) + call.__aiter__ = mock.Mock(return_value=mock_generator()) else: def mock_generator(): yield 1 @@ -464,7 +448,10 @@ async def test_unary_stream_interceptor_start_operation(self, initial_state): continuation = CrossSync.Mock() call = continuation.return_value if CrossSync.is_async: - call.__aiter__ = mock.Mock(return_value=_AsyncIterator([1, 2])) + async def gen(): + yield 1 + yield 2 + call.__aiter__ = mock.Mock(return_value=gen()) else: call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync.Mock(return_value=[]) From 6d585ec2e5d441cf46065fe4fc21c510f9041fcf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 15:21:58 -0700 Subject: [PATCH 43/60] refactoring --- .../data/_async/test_metrics_interceptor.py | 46 +++++++------------ .../_sync_autogen/test_metrics_interceptor.py | 34 +++++++++----- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index ccbe7869b..ab56e9a37 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -33,8 +33,19 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor" - - 
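The `_make_mock_stream_call` helper added just below leans on unittest.mock's magic-method support to make a bare mock behave like a streaming call. A minimal standalone sketch of that documented behavior, independent of this patch:

```python
import asyncio
from unittest.mock import MagicMock

async def demo():
    call = MagicMock()
    # MagicMock allows __aiter__ to be configured directly (Python 3.8+);
    # assigning a plain list makes the mock usable in `async for` loops
    call.__aiter__.return_value = [1, 2]
    return [item async for item in call]

assert asyncio.run(demo()) == [1, 2]
```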
+@CrossSync.convert(replace_symbols={"__aiter__": "__iter__"}) +def _make_mock_stream_call(values, exc=None): + """ + Create a mock call object that can be used for streaming calls + """ + call = CrossSync.Mock() + async def gen(): + for val in values: + yield val + if exc: + raise exc + call.__aiter__ = mock.Mock(return_value=gen()) + return call @CrossSync.convert_class(sync_name="TestMetricsInterceptor") @@ -251,15 +262,8 @@ async def test_unary_stream_interceptor_success(self): op.first_response_latency = None instance.operation_map[op.uuid] = op - continuation = CrossSync.Mock() + continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1, 2])) call = continuation.return_value - if CrossSync.is_async: - async def gen(): - yield 1 - yield 2 - call.__aiter__ = mock.Mock(return_value=gen()) - else: - call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -288,19 +292,8 @@ async def test_unary_stream_interceptor_failure_mid_stream(self): op.first_response_latency = None instance.operation_map[op.uuid] = op exc = ValueError("test") - - continuation = CrossSync.Mock() + continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1], exc=exc)) call = continuation.return_value - if CrossSync.is_async: - async def mock_generator(): - yield 1 - raise exc - call.__aiter__ = mock.Mock(return_value=mock_generator()) - else: - def mock_generator(): - yield 1 - raise exc - call.__iter__ = mock.Mock(return_value=mock_generator()) call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -445,15 +438,8 @@ async def test_unary_stream_interceptor_start_operation(self, initial_state): op.state = initial_state instance.operation_map[op.uuid] = op - continuation = CrossSync.Mock() + continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1, 2])) call = continuation.return_value - if CrossSync.is_async: - async def gen(): - yield 1 - yield 2 - call.__aiter__ = mock.Mock(return_value=gen()) - else: - call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync.Mock(return_value=[]) call.initial_metadata = CrossSync.Mock(return_value=[]) details = mock.Mock() diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py index b9b9bdd78..8990e8693 100644 --- a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py +++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py @@ -29,6 +29,20 @@ ) +def _make_mock_stream_call(values, exc=None): + """Create a mock call object that can be used for streaming calls""" + call = CrossSync._Sync_Impl.Mock() + + def gen(): + for val in values: + yield val + if exc: + raise exc + + call.__iter__ = mock.Mock(return_value=gen()) + return call + + class TestMetricsInterceptor: @staticmethod def _get_target_class(): @@ -224,9 +238,10 @@ def test_unary_stream_interceptor_success(self): op.start_time_ns = 0 op.first_response_latency = None instance.operation_map[op.uuid] = op - continuation = CrossSync._Sync_Impl.Mock() + continuation = CrossSync._Sync_Impl.Mock( + return_value=_make_mock_stream_call([1, 2]) + ) call = continuation.return_value - call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = 
CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -254,14 +269,10 @@ def test_unary_stream_interceptor_failure_mid_stream(self): op.first_response_latency = None instance.operation_map[op.uuid] = op exc = ValueError("test") - continuation = CrossSync._Sync_Impl.Mock() + continuation = CrossSync._Sync_Impl.Mock( + return_value=_make_mock_stream_call([1], exc=exc) + ) call = continuation.return_value - - def mock_generator(): - yield 1 - raise exc - - call.__iter__ = mock.Mock(return_value=mock_generator()) call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) details = mock.Mock() @@ -397,9 +408,10 @@ def test_unary_stream_interceptor_start_operation(self, initial_state): op.uuid = "test-uuid" op.state = initial_state instance.operation_map[op.uuid] = op - continuation = CrossSync._Sync_Impl.Mock() + continuation = CrossSync._Sync_Impl.Mock( + return_value=_make_mock_stream_call([1, 2]) + ) call = continuation.return_value - call.__iter__ = mock.Mock(return_value=iter([1, 2])) call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) details = mock.Mock() From 01e6b369892116141d96a80b9c95fbe166d3723a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 26 Aug 2025 15:41:10 -0700 Subject: [PATCH 44/60] refactored interceptor --- .../data/_async/metrics_interceptor.py | 65 +++++++++---------- .../data/_sync_autogen/metrics_interceptor.py | 59 +++++++---------- 2 files changed, 53 insertions(+), 71 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 88a9f27fe..a92550c1b 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -15,7 +15,6 @@ import time from functools import wraps -from grpc import RpcError from google.cloud.bigtable.data._metrics.data_model import ( OPERATION_INTERCEPTOR_METADATA_KEY, ) @@ -66,6 +65,26 @@ def wrapper(self, continuation, client_call_details, request): return wrapper +def _end_attempt(operation, exc, metadata): + """Helper to add metadata and exception to an operation""" + if metadata is not None: + operation.add_response_metadata(metadata) + if exc is not None: + # end attempt. 
If it succeeded, let higher levels decide when to end operation + operation.end_attempt_with_status(exc) + + +@CrossSync.convert +async def _get_metadata(source): + """Helper to extract metadata from a call or RpcError""" + try: + return (await source.trailing_metadata() or []) + ( + await source.initial_metadata() or [] + ) + except Exception: + # ignore errors while fetching metadata + return None + @CrossSync.convert_class(sync_name="BigtableMetricsInterceptor") class AsyncBigtableMetricsInterceptor( UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, MetricsHandler @@ -109,25 +128,14 @@ async def intercept_unary_unary( metadata = None try: call = await continuation(client_call_details, request) - metadata = (await call.trailing_metadata() or []) + (await call.initial_metadata() or []) + metadata = await _get_metadata(call) return call - except RpcError as rpc_error: - # attempt extracting metadata from error - try: - metadata = (await rpc_error.trailing_metadata() or []) + (await rpc_error.initial_metadata() or []) - except Exception: - pass + except Exception as rpc_error: + metadata = await _get_metadata(rpc_error) encountered_exc = rpc_error raise rpc_error - except Exception as e: - encountered_exc = e - raise finally: - if metadata is not None: - operation.add_response_metadata(metadata) - if encountered_exc is not None: - # end attempt. If it succeeded, let higher levels decide when to end operation - operation.end_attempt_with_status(encountered_exc) + _end_attempt(operation, encountered_exc, metadata) @CrossSync.convert @_with_operation_from_metadata @@ -146,30 +154,17 @@ async def response_wrapper(call): ) has_first_response = True yield response - - except Exception as e: + # handle errors while processing stream encountered_exc = e raise finally: if call is not None: - metadata = (await call.trailing_metadata() or []) + (await call.initial_metadata() or []) - operation.add_response_metadata(metadata) - if encountered_exc is not None: - # end attempt. 
If it succeeded, let higher levels decide when to end operation
- operation.end_attempt_with_status(encountered_exc)
+ _end_attempt(operation, encountered_exc, await _get_metadata(call))
 try:
 return response_wrapper(await continuation(client_call_details, request))
- except RpcError as rpc_error:
- # attempt extracting metadata from error
- try:
- metadata = (await rpc_error.trailing_metadata() or []) + (await rpc_error.initial_metadata() or [])
- operation.add_response_metadata(metadata)
- except Exception:
- pass
- operation.end_attempt_with_status(rpc_error)
- raise rpc_error
- except Exception as e:
- operation.end_attempt_with_status(e)
- raise \ No newline at end of file
+ except Exception as rpc_error:
+ # handle errors while initializing stream
+ _end_attempt(operation, rpc_error, await _get_metadata(rpc_error))
+ raise rpc_error \ No newline at end of file
diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
index 9a14a2cfa..dc262caa9 100644
--- a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
+++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 import time
 from functools import wraps
-from grpc import RpcError
 from google.cloud.bigtable.data._metrics.data_model import (
 OPERATION_INTERCEPTOR_METADATA_KEY,
 )
@@ -56,6 +55,22 @@ def wrapper(self, continuation, client_call_details, request):
 return wrapper
+def _end_attempt(operation, exc, metadata):
+ """Helper to add metadata and exception to an operation"""
+ if metadata is not None:
+ operation.add_response_metadata(metadata)
+ if exc is not None:
+ operation.end_attempt_with_status(exc)
+
+
+def _get_metadata(source):
+ """Helper to extract metadata from a call or RpcError"""
+ try:
+ return (source.trailing_metadata() or []) + (source.initial_metadata() or [])
+ except Exception:
+ return None
+
+
 class BigtableMetricsInterceptor(
 UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, MetricsHandler
 ):
@@ -95,27 +110,14 @@ def intercept_unary_unary(
 metadata = None
 try:
 call = continuation(client_call_details, request)
- metadata = (call.trailing_metadata() or []) + (
- call.initial_metadata() or []
- )
+ metadata = _get_metadata(call)
 return call
- except RpcError as rpc_error:
- try:
- metadata = (rpc_error.trailing_metadata() or []) + (
- rpc_error.initial_metadata() or []
- )
- except Exception:
- pass
+ except Exception as rpc_error:
+ metadata = _get_metadata(rpc_error)
 encountered_exc = rpc_error
 raise rpc_error
- except Exception as e:
- encountered_exc = e
- raise
 finally:
- if metadata is not None:
-
-                    operation.add_response_metadata(metadata)
-            except Exception:
-                pass
-            operation.end_attempt_with_status(rpc_error)
+        except Exception as rpc_error:
+            _end_attempt(operation, rpc_error, _get_metadata(rpc_error))
             raise rpc_error
-        except Exception as e:
-            operation.end_attempt_with_status(e)
-            raise

From b6eac6cfcc5ee4a90fb9b75f1ab56360af883eb4 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 26 Aug 2025 15:57:10 -0700
Subject: [PATCH 45/60] fixed lint

---
 .../data/_async/metrics_interceptor.py        |  8 +++++--
 tests/unit/data/_async/test_client.py         |  8 +++++--
 .../data/_async/test_metrics_interceptor.py   | 22 +++++++++++++------
 .../_sync_autogen/test_metrics_interceptor.py |  1 -
 4 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py
index a92550c1b..54f8ed107 100644
--- a/google/cloud/bigtable/data/_async/metrics_interceptor.py
+++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py
@@ -54,7 +54,10 @@ def wrapper(self, continuation, client_call_details, request):
         operation: "ActiveOperationMetric" = self.operation_map.get(key)
         if operation:
             # start a new attempt if not started
-            if operation.state == OperationState.CREATED or operation.state == OperationState.BETWEEN_ATTEMPTS:
+            if (
+                operation.state == OperationState.CREATED
+                or operation.state == OperationState.BETWEEN_ATTEMPTS
+            ):
                 operation.start_attempt()
             # wrap continuation in logic to process the operation
             return func(self, operation, continuation, client_call_details, request)
@@ -85,6 +88,7 @@ async def _get_metadata(source):
         # ignore errors while fetching metadata
         return None
 
+
 @CrossSync.convert_class(sync_name="BigtableMetricsInterceptor")
 class AsyncBigtableMetricsInterceptor(
     UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, MetricsHandler
@@ -167,4 +171,4 @@ async def response_wrapper(call):
     except Exception as rpc_error:
         # handle errors while initializing stream
         _end_attempt(operation, rpc_error, await _get_metadata(rpc_error))
-        raise rpc_error
\ No newline at end of file
+        raise rpc_error
diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
index 24dfb4430..48ba94ba0 100644
--- a/tests/unit/data/_async/test_client.py
+++ b/tests/unit/data/_async/test_client.py
@@ -1160,7 +1160,9 @@ def _make_one(
     @CrossSync.pytest
     async def test_ctor(self):
         from google.cloud.bigtable.data._helpers import _WarmedInstanceKey
-        from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController
+        from google.cloud.bigtable.data._metrics import (
+            BigtableClientSideMetricsController,
+        )
 
         expected_table_id = "table-id"
         expected_instance_id = "instance-id"
@@ -1493,7 +1495,9 @@ def _make_one(
     @CrossSync.pytest
     async def test_ctor(self):
         from google.cloud.bigtable.data._helpers import _WarmedInstanceKey
-        from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController
+        from google.cloud.bigtable.data._metrics import (
+            BigtableClientSideMetricsController,
+        )
 
         expected_table_id = "table-id"
         expected_instance_id = "instance-id"
diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py
index ab56e9a37..8510e21a2 100644
--- a/tests/unit/data/_async/test_metrics_interceptor.py
+++ b/tests/unit/data/_async/test_metrics_interceptor.py
@@ -13,7 +13,6 @@
 # limitations under the License.
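# A self-contained sketch of the pattern behind the _get_metadata/_end_attempt
# helpers introduced in the two patches above: grpc exposes server metadata on
# both the call object (on success) and the RpcError (on failure), split across
# initial_metadata() and trailing_metadata(), so the helper merges the two and
# treats any lookup failure as "no metadata". FakeCall and its values are
# hypothetical stand-ins so the sketch runs without a live channel; this is not
# the library's code.

class FakeCall:
    def initial_metadata(self):
        return [("x-initial-key", "a")]  # illustrative values

    def trailing_metadata(self):
        return [("x-trailing-key", "b")]


def get_metadata(source):
    # mirror the sync helper: merge both metadata sets, swallow lookup errors
    try:
        return (source.trailing_metadata() or []) + (source.initial_metadata() or [])
    except Exception:
        return None


assert get_metadata(FakeCall()) == [("x-trailing-key", "b"), ("x-initial-key", "a")]
assert get_metadata(object()) is None  # objects without the accessors yield None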
import pytest -import asyncio from grpc import RpcError from google.cloud.bigtable.data._metrics.data_model import OperationState @@ -26,24 +25,31 @@ import mock # type: ignore if CrossSync.is_async: - from google.cloud.bigtable.data._async.metrics_interceptor import AsyncBigtableMetricsInterceptor + from google.cloud.bigtable.data._async.metrics_interceptor import ( + AsyncBigtableMetricsInterceptor, + ) else: - from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import BigtableMetricsInterceptor + from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( # noqa: F401 + BigtableMetricsInterceptor, + ) __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor" + @CrossSync.convert(replace_symbols={"__aiter__": "__iter__"}) def _make_mock_stream_call(values, exc=None): """ Create a mock call object that can be used for streaming calls """ call = CrossSync.Mock() + async def gen(): for val in values: yield val if exc: raise exc + call.__aiter__ = mock.Mock(return_value=gen()) return call @@ -51,7 +57,11 @@ async def gen(): @CrossSync.convert_class(sync_name="TestMetricsInterceptor") class TestMetricsInterceptorAsync: @staticmethod - @CrossSync.convert(replace_symbols={"AsyncBigtableMetricsInterceptor": "BigtableMetricsInterceptor"}) + @CrossSync.convert( + replace_symbols={ + "AsyncBigtableMetricsInterceptor": "BigtableMetricsInterceptor" + } + ) def _get_target_class(): return AsyncBigtableMetricsInterceptor @@ -170,7 +180,6 @@ async def test_unary_unary_interceptor_failure(self): exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) exc.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) continuation = CrossSync.Mock(side_effect=exc) - call = continuation.return_value details = mock.Mock() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() @@ -235,7 +244,6 @@ async def test_unary_unary_interceptor_failure_generic(self): op.add_response_metadata.assert_not_called() op.end_attempt_with_status.assert_called_once_with(exc) - @CrossSync.pytest async def test_unary_stream_interceptor_op_not_found(self): """Test that interceptor calls continuation if op is not found""" @@ -446,4 +454,4 @@ async def test_unary_stream_interceptor_start_operation(self, initial_state): details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() await instance.intercept_unary_stream(continuation, details, request) - op.start_attempt.assert_called_once() \ No newline at end of file + op.start_attempt.assert_called_once() diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py index 8990e8693..e545c3503 100644 --- a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py +++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py @@ -152,7 +152,6 @@ def test_unary_unary_interceptor_failure(self): exc.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) continuation = CrossSync._Sync_Impl.Mock(side_effect=exc) - call = continuation.return_value details = mock.Mock() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() From 4ccfdabb6c5553c9885c595061f188bea9dd9553 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 14:45:41 -0700 Subject: [PATCH 46/60] removed duplicate import --- google/cloud/bigtable/data/_async/client.py | 6 ------ 
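# The _make_mock_stream_call helper in the test diff above fakes a
# server-streaming grpc call by pointing a mock's __aiter__ at an async
# generator. The same idea reduced to plain unittest.mock, runnable on its own
# (the helper name and values here are a sketch, not the library's test code):

import asyncio
from unittest import mock


def make_mock_stream_call(values, exc=None):
    call = mock.MagicMock()

    async def gen():
        for val in values:
            yield val
        if exc:
            raise exc  # simulate a failure after the stream is exhausted

    # assigning a magic method installs it on the mock's class, so `async for`
    # finds it; note the generator is single-use, one iteration per mock
    call.__aiter__ = mock.Mock(return_value=gen())
    return call


async def consume():
    return [v async for v in make_mock_stream_call([1, 2, 3])]


assert asyncio.run(consume()) == [1, 2, 3]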
google/cloud/bigtable/data/_sync_autogen/client.py | 3 --- 2 files changed, 9 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1944cf016..d551f5b7a 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -98,9 +98,6 @@ BigtableAsyncClient as GapicClient, ) from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._async.metrics_interceptor import ( - AsyncBigtableMetricsInterceptor as MetricInterceptorType, - ) from google.cloud.bigtable.data._async._swappable_channel import ( AsyncSwappableChannel as SwappableChannelType, ) @@ -114,9 +111,6 @@ from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE - from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( - BigtableMetricsInterceptor as MetricInterceptorType, - ) from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( # noqa: F401 SwappableChannel as SwappableChannelType, ) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 19aca8ee1..bcd13c068 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -82,9 +82,6 @@ ) from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE -from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( - BigtableMetricsInterceptor as MetricInterceptorType, -) from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( SwappableChannel as SwappableChannelType, ) From bd9ab70022476a05d8c4841d5bd86f399a00ef59 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 15:38:07 -0700 Subject: [PATCH 47/60] added more tests --- .../bigtable/data/_metrics/data_model.py | 5 ++ tests/unit/data/_metrics/test_data_model.py | 51 ++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 2bcd40021..be9e34fc4 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -165,6 +165,11 @@ class ActiveOperationMetric: @property def interceptor_metadata(self) -> tuple[str, str]: + """ + returns a tuple to attach to the grpc metadata. 
+ + This metadata field will be read by the BigtableMetricsInterceptor to associate a request with an operation + """ return OPERATION_INTERCEPTOR_METADATA_KEY, self.uuid @property diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py index b281ccec0..868e02719 100644 --- a/tests/unit/data/_metrics/test_data_model.py +++ b/tests/unit/data/_metrics/test_data_model.py @@ -391,6 +391,7 @@ def test_end_attempt_with_status(self): - add one to completed_attempts - reset active_attempt to None - update state + - notify handlers """ expected_start_time = 1 expected_status = object() @@ -398,8 +399,9 @@ def test_end_attempt_with_status(self): expected_app_blocking = 12 expected_backoff = 2 expected_grpc_throttle = 3 + handlers = [mock.Mock(), mock.Mock()] - metric = self._make_one(mock.Mock()) + metric = self._make_one(mock.Mock(), handlers=handlers) assert metric.active_attempt is None assert len(metric.completed_attempts) == 0 metric.start_attempt() @@ -420,6 +422,11 @@ def test_end_attempt_with_status(self): assert got_attempt.backoff_before_attempt_ns == expected_backoff # state should be changed to BETWEEN_ATTEMPTS assert metric.state == State.BETWEEN_ATTEMPTS + # check handlers + for h in handlers: + assert h.on_attempt_complete.call_count == 1 + assert h.on_attempt_complete.call_args[0][0] == got_attempt + assert h.on_attempt_complete.call_args[0][1] == metric def test_end_attempt_with_status_w_exception(self): """ @@ -528,6 +535,48 @@ def test_end_with_status_w_exception(self): final_op = handlers[0].on_operation_complete.call_args[0][0] assert final_op.final_status == expected_status + def test_interceptor_metadata(self): + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + metric = self._make_one(mock.Mock()) + key, value = metric.interceptor_metadata + assert key == OPERATION_INTERCEPTOR_METADATA_KEY + assert value == metric.uuid + + def test_cancel(self): + """ + cancel should call on_operation_cancelled on handlers + """ + handlers = [mock.Mock(), mock.Mock()] + metric = self._make_one(mock.Mock(), handlers=handlers) + metric.cancel() + for h in handlers: + assert h.on_operation_cancelled.call_count == 1 + assert h.on_operation_cancelled.call_args[0][0] == metric + + def test_end_with_status_with_default_cluster_zone(self): + """ + ending the operation should use default cluster and zone if not set + """ + from google.cloud.bigtable.data._metrics.data_model import ( + DEFAULT_CLUSTER_ID, + DEFAULT_ZONE, + ) + + handlers = [mock.Mock()] + metric = self._make_one(mock.Mock(), handlers=handlers) + assert metric.cluster_id is None + assert metric.zone is None + metric.end_with_status(mock.Mock()) + assert metric.state == State.COMPLETED + # check that finalized operation was passed to handlers + for h in handlers: + assert h.on_operation_complete.call_count == 1 + called_with = h.on_operation_complete.call_args[0][0] + assert called_with.cluster_id == DEFAULT_CLUSTER_ID + assert called_with.zone == DEFAULT_ZONE def test_end_with_success(self): """ end with success should be a pass-through helper for end_with_status From 50b3e48c3415263e8f57a0f3077f740b3b270622 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 16:26:25 -0700 Subject: [PATCH 48/60] remove operation metadata key --- .../data/_async/metrics_interceptor.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py 
b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 54f8ed107..10c5f3f25 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -43,16 +43,22 @@ def _with_operation_from_metadata(func): @wraps(func) def wrapper(self, continuation, client_call_details, request): - key = next( - ( - m[1] - for m in client_call_details.metadata - if m[0] == OPERATION_INTERCEPTOR_METADATA_KEY - ), - None, - ) - operation: "ActiveOperationMetric" = self.operation_map.get(key) + found_operation_id: str | None = None + new_metadata = client_call_details.metadata + if client_call_details.metadata: + # find operation key and strip it from metadata + temp_metadata = [] + for k, v in client_call_details.metadata: + if k == OPERATION_INTERCEPTOR_METADATA_KEY: + found_operation_id = v + else: + temp_metadata.append((k, v)) + new_metadata = temp_metadata + + operation: "ActiveOperationMetric" = self.operation_map.get(found_operation_id) if operation: + # create new client_call_details without the operation key + client_call_details = client_call_details._replace(metadata=new_metadata) # start a new attempt if not started if ( operation.state == OperationState.CREATED From 2b35127f8557e64bb904d0c1fca2367225270a7f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 16:45:36 -0700 Subject: [PATCH 49/60] assign metadata directly --- google/cloud/bigtable/data/_async/metrics_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 10c5f3f25..e524b8104 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -58,7 +58,7 @@ def wrapper(self, continuation, client_call_details, request): operation: "ActiveOperationMetric" = self.operation_map.get(found_operation_id) if operation: # create new client_call_details without the operation key - client_call_details = client_call_details._replace(metadata=new_metadata) + client_call_details.metadata = new_metadata # start a new attempt if not started if ( operation.state == OperationState.CREATED From 9cbda9975e24ecdfea71af94dd21f6ad4392216d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 16:45:52 -0700 Subject: [PATCH 50/60] added test --- .../data/_async/test_metrics_interceptor.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index 8510e21a2..d56f6eeb9 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -14,6 +14,7 @@ import pytest from grpc import RpcError +from grpc import ClientCallDetails from google.cloud.bigtable.data._metrics.data_model import OperationState from google.cloud.bigtable.data._cross_sync import CrossSync @@ -128,6 +129,28 @@ def test_on_operation_cancelled(self): op.cancel() assert op.uuid not in instance.operation_map + @CrossSync.pytest + async def test_strip_operation_id_metadata(self): + """ + After operation id is detected in metadata, the field should be stripped out before calling continuation + """ + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = 
OperationState.ACTIVE_ATTEMPT + instance.operation_map[op.uuid] = op + continuation = CrossSync.Mock() + details = ClientCallDetails() + details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid), ("other_key", "other_value")] + await instance.intercept_unary_unary(continuation, details, mock.Mock()) + assert details.metadata == [("other_key", "other_value")] + assert continuation.call_count == 1 + assert continuation.call_args[0][0].metadata == [("other_key", "other_value")] + @CrossSync.pytest async def test_unary_unary_interceptor_op_not_found(self): """Test that interceptor call cuntinuation if op is not found""" From 73f4b3cb23a8158d250e720680ef52ae6cf09243 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 16:46:42 -0700 Subject: [PATCH 51/60] replace details mocks with real type --- .../data/_async/test_metrics_interceptor.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index d56f6eeb9..b481850ed 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -156,7 +156,7 @@ async def test_unary_unary_interceptor_op_not_found(self): """Test that interceptor call cuntinuation if op is not found""" instance = self._make_one() continuation = CrossSync.Mock() - details = mock.Mock() + details = ClientCallDetails() details.metadata = [] request = mock.Mock() await instance.intercept_unary_unary(continuation, details, request) @@ -178,7 +178,7 @@ async def test_unary_unary_interceptor_success(self): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() result = await instance.intercept_unary_unary(continuation, details, request) @@ -203,7 +203,7 @@ async def test_unary_unary_interceptor_failure(self): exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) exc.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) continuation = CrossSync.Mock(side_effect=exc) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -230,7 +230,7 @@ async def test_unary_unary_interceptor_failure_no_metadata(self): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -257,7 +257,7 @@ async def test_unary_unary_interceptor_failure_generic(self): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(ValueError) as e: @@ -272,7 +272,7 @@ async def test_unary_stream_interceptor_op_not_found(self): """Test that interceptor calls continuation if op is not found""" instance = self._make_one() continuation = CrossSync.Mock() - details 
= mock.Mock() + details = ClientCallDetails() details.metadata = [] request = mock.Mock() await instance.intercept_unary_stream(continuation, details, request) @@ -297,7 +297,7 @@ async def test_unary_stream_interceptor_success(self): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = await instance.intercept_unary_stream(continuation, details, request) @@ -327,7 +327,7 @@ async def test_unary_stream_interceptor_failure_mid_stream(self): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = await instance.intercept_unary_stream(continuation, details, request) @@ -359,7 +359,7 @@ async def test_unary_stream_interceptor_failure_start_stream(self): continuation = CrossSync.Mock() continuation.side_effect = exc - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -388,7 +388,7 @@ async def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): continuation = CrossSync.Mock() continuation.side_effect = exc - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -417,7 +417,7 @@ async def test_unary_stream_interceptor_failure_start_stream_generic(self): continuation = CrossSync.Mock() continuation.side_effect = exc - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(ValueError) as e: @@ -447,7 +447,7 @@ async def test_unary_unary_interceptor_start_operation(self, initial_state): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[]) call.initial_metadata = CrossSync.Mock(return_value=[]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() await instance.intercept_unary_unary(continuation, details, request) @@ -473,7 +473,7 @@ async def test_unary_stream_interceptor_start_operation(self, initial_state): call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[]) call.initial_metadata = CrossSync.Mock(return_value=[]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() await instance.intercept_unary_stream(continuation, details, request) From 486068b99c4773b44688638b59a08c9d18b60dd0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Sep 2025 16:55:48 -0700 Subject: [PATCH 52/60] added try; generated sync --- .../data/_async/metrics_interceptor.py | 25 +++++----- .../bigtable/data/_metrics/data_model.py | 2 +- .../data/_sync_autogen/metrics_interceptor.py | 22 +++++---- .../data/_async/test_metrics_interceptor.py | 5 +- tests/unit/data/_metrics/test_data_model.py | 1 + .../_sync_autogen/test_metrics_interceptor.py | 49 ++++++++++++++----- 6 
files changed, 68 insertions(+), 36 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index c900f2c49..dc99d2fad 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -44,18 +44,19 @@ def _with_operation_from_metadata(func): @wraps(func) def wrapper(self, continuation, client_call_details, request): found_operation_id: str | None = None - new_metadata = client_call_details.metadata - if client_call_details.metadata: - # find operation key from metadata - temp_metadata = [] - for k, v in client_call_details.metadata: - if k == OPERATION_INTERCEPTOR_METADATA_KEY: - found_operation_id = v - else: - temp_metadata.append((k, v)) - new_metadata = temp_metadata - # update client_call_details to drop the operation key metadata - client_call_details.metadata = new_metadata + try: + new_metadata: list[tuple[str, str]] = [] + if client_call_details.metadata: + # find operation key from metadata + for k, v in client_call_details.metadata: + if k == OPERATION_INTERCEPTOR_METADATA_KEY: + found_operation_id = v + else: + new_metadata.append((k, v)) + # update client_call_details to drop the operation key metadata + client_call_details.metadata = new_metadata + except Exception: + pass operation: "ActiveOperationMetric" = self.operation_map.get(found_operation_id) if operation: diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index be9e34fc4..6c4572d24 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -166,7 +166,7 @@ class ActiveOperationMetric: @property def interceptor_metadata(self) -> tuple[str, str]: """ - returns a tuple to attach to the grpc metadata. + returns a tuple to attach to the grpc metadata. 
This metadata field will be read by the BigtableMetricsInterceptor to associate a request with an operation """ diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py index dc262caa9..4cefed824 100644 --- a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py @@ -33,15 +33,19 @@ def _with_operation_from_metadata(func): @wraps(func) def wrapper(self, continuation, client_call_details, request): - key = next( - ( - m[1] - for m in client_call_details.metadata - if m[0] == OPERATION_INTERCEPTOR_METADATA_KEY - ), - None, - ) - operation: "ActiveOperationMetric" = self.operation_map.get(key) + found_operation_id: str | None = None + try: + new_metadata: list[tuple[str, str]] = [] + if client_call_details.metadata: + for k, v in client_call_details.metadata: + if k == OPERATION_INTERCEPTOR_METADATA_KEY: + found_operation_id = v + else: + new_metadata.append((k, v)) + client_call_details.metadata = new_metadata + except Exception: + pass + operation: "ActiveOperationMetric" = self.operation_map.get(found_operation_id) if operation: if ( operation.state == OperationState.CREATED diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index b481850ed..caa9bbb46 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -145,7 +145,10 @@ async def test_strip_operation_id_metadata(self): instance.operation_map[op.uuid] = op continuation = CrossSync.Mock() details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid), ("other_key", "other_value")] + details.metadata = [ + (OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid), + ("other_key", "other_value"), + ] await instance.intercept_unary_unary(continuation, details, mock.Mock()) assert details.metadata == [("other_key", "other_value")] assert continuation.call_count == 1 diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py index 868e02719..7d9b6671f 100644 --- a/tests/unit/data/_metrics/test_data_model.py +++ b/tests/unit/data/_metrics/test_data_model.py @@ -577,6 +577,7 @@ def test_end_with_status_with_default_cluster_zone(self): called_with = h.on_operation_complete.call_args[0][0] assert called_with.cluster_id == DEFAULT_CLUSTER_ID assert called_with.zone == DEFAULT_ZONE + def test_end_with_success(self): """ end with success should be a pass-through helper for end_with_status diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py index e545c3503..283814e27 100644 --- a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py +++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py @@ -17,6 +17,7 @@ import pytest from grpc import RpcError +from grpc import ClientCallDetails from google.cloud.bigtable.data._metrics.data_model import OperationState from google.cloud.bigtable.data._cross_sync import CrossSync @@ -103,11 +104,33 @@ def test_on_operation_cancelled(self): op.cancel() assert op.uuid not in instance.operation_map + def test_strip_operation_id_metadata(self): + """After operation id is detected in metadata, the field should be stripped out before calling continuation""" + from google.cloud.bigtable.data._metrics.data_model import ( + OPERATION_INTERCEPTOR_METADATA_KEY, + ) + + instance = 
self._make_one() + op = mock.Mock() + op.uuid = "test-uuid" + op.state = OperationState.ACTIVE_ATTEMPT + instance.operation_map[op.uuid] = op + continuation = CrossSync._Sync_Impl.Mock() + details = ClientCallDetails() + details.metadata = [ + (OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid), + ("other_key", "other_value"), + ] + instance.intercept_unary_unary(continuation, details, mock.Mock()) + assert details.metadata == [("other_key", "other_value")] + assert continuation.call_count == 1 + assert continuation.call_args[0][0].metadata == [("other_key", "other_value")] + def test_unary_unary_interceptor_op_not_found(self): """Test that interceptor call cuntinuation if op is not found""" instance = self._make_one() continuation = CrossSync._Sync_Impl.Mock() - details = mock.Mock() + details = ClientCallDetails() details.metadata = [] request = mock.Mock() instance.intercept_unary_unary(continuation, details, request) @@ -128,7 +151,7 @@ def test_unary_unary_interceptor_success(self): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() result = instance.intercept_unary_unary(continuation, details, request) @@ -152,7 +175,7 @@ def test_unary_unary_interceptor_failure(self): exc.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) continuation = CrossSync._Sync_Impl.Mock(side_effect=exc) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -178,7 +201,7 @@ def test_unary_unary_interceptor_failure_no_metadata(self): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -204,7 +227,7 @@ def test_unary_unary_interceptor_failure_generic(self): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(ValueError) as e: @@ -218,7 +241,7 @@ def test_unary_stream_interceptor_op_not_found(self): """Test that interceptor calls continuation if op is not found""" instance = self._make_one() continuation = CrossSync._Sync_Impl.Mock() - details = mock.Mock() + details = ClientCallDetails() details.metadata = [] request = mock.Mock() instance.intercept_unary_stream(continuation, details, request) @@ -243,7 +266,7 @@ def test_unary_stream_interceptor_success(self): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = 
instance.intercept_unary_stream(continuation, details, request) @@ -274,7 +297,7 @@ def test_unary_stream_interceptor_failure_mid_stream(self): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = instance.intercept_unary_stream(continuation, details, request) @@ -304,7 +327,7 @@ def test_unary_stream_interceptor_failure_start_stream(self): exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) continuation = CrossSync._Sync_Impl.Mock() continuation.side_effect = exc - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -331,7 +354,7 @@ def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): exc = RpcError("test") continuation = CrossSync._Sync_Impl.Mock() continuation.side_effect = exc - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: @@ -358,7 +381,7 @@ def test_unary_stream_interceptor_failure_start_stream_generic(self): exc = ValueError("test") continuation = CrossSync._Sync_Impl.Mock() continuation.side_effect = exc - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(ValueError) as e: @@ -387,7 +410,7 @@ def test_unary_unary_interceptor_start_operation(self, initial_state): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() instance.intercept_unary_unary(continuation, details, request) @@ -413,7 +436,7 @@ def test_unary_stream_interceptor_start_operation(self, initial_state): call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) - details = mock.Mock() + details = ClientCallDetails() details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() instance.intercept_unary_stream(continuation, details, request) From bebeb70557c5016a0a2c6d61bdf5533b2cefe5f7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 29 Sep 2025 22:56:40 -0700 Subject: [PATCH 53/60] use contextvars --- .../data/_async/metrics_interceptor.py | 46 +------------------ .../bigtable/data/_metrics/data_model.py | 22 +++++---- .../data/_metrics/metrics_controller.py | 21 +-------- 3 files changed, 15 insertions(+), 74 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index dc99d2fad..2bee4de43 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -15,9 +15,6 @@ import time from functools import wraps -from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, -) from google.cloud.bigtable.data._metrics.data_model import 
ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler @@ -43,22 +40,7 @@ def _with_operation_from_metadata(func): @wraps(func) def wrapper(self, continuation, client_call_details, request): - found_operation_id: str | None = None - try: - new_metadata: list[tuple[str, str]] = [] - if client_call_details.metadata: - # find operation key from metadata - for k, v in client_call_details.metadata: - if k == OPERATION_INTERCEPTOR_METADATA_KEY: - found_operation_id = v - else: - new_metadata.append((k, v)) - # update client_call_details to drop the operation key metadata - client_call_details.metadata = new_metadata - except Exception: - pass - - operation: "ActiveOperationMetric" = self.operation_map.get(found_operation_id) + operation: "ActiveOperationMetric" | None = ActiveOperationMetric.get_active() if operation: # start a new attempt if not started if ( @@ -104,32 +86,6 @@ class AsyncBigtableMetricsInterceptor( An async gRPC interceptor to add client metadata and print server metadata. """ - def __init__(self): - super().__init__() - self.operation_map = {} - - def register_operation(self, operation): - """ - Register an operation object to be tracked my the interceptor - - When registered, the operation will receive metadata updates: - - start_attempt if attempt not started when rpc is being sent - - add_response_metadata after call is complete - - end_attempt_with_status if attempt receives an error - - The interceptor will register itself as a handeler for the operation, - so it can unregister the operation when it is complete - """ - self.operation_map[operation.uuid] = operation - operation.handlers.append(self) - - def on_operation_complete(self, op): - if op.uuid in self.operation_map: - del self.operation_map[op.uuid] - - def on_operation_cancelled(self, op): - self.on_operation_complete(op) - @CrossSync.convert @_with_operation_from_metadata async def intercept_unary_unary( diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 6c4572d24..f840f512b 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -13,12 +13,13 @@ # limitations under the License. from __future__ import annotations -from typing import Tuple, cast, TYPE_CHECKING +from typing import Callable, ClassVar, List, Tuple, Optional, cast, TYPE_CHECKING import time import re import logging import uuid +import contextvars from enum import Enum from functools import lru_cache @@ -48,7 +49,6 @@ INVALID_STATE_ERROR = "Invalid state for {}: {}" -OPERATION_INTERCEPTOR_METADATA_KEY = "x-goog-operation-key" class OperationType(Enum): @@ -163,14 +163,13 @@ class ActiveOperationMetric: # time waiting on flow control, in nanoseconds flow_throttling_time_ns: int = 0 - @property - def interceptor_metadata(self) -> tuple[str, str]: - """ - returns a tuple to attach to the grpc metadata. 
+ _active_operation_context: ClassVar[ + contextvars.ContextVar[ActiveOperationMetric] + ] = contextvars.ContextVar("active_operation_context") - This metadata field will be read by the BigtableMetricsInterceptor to associate a request with an operation - """ - return OPERATION_INTERCEPTOR_METADATA_KEY, self.uuid + @classmethod + def get_active(cls): + return cls._active_operation_context.get(None) @property def state(self) -> OperationState: @@ -184,6 +183,9 @@ def state(self) -> OperationState: else: return OperationState.ACTIVE_ATTEMPT + def __post_init__(self): + self._active_operation_context.set(self) + def start(self) -> None: """ Optionally called to mark the start of the operation. If not called, @@ -194,6 +196,7 @@ def start(self) -> None: if self.state != OperationState.CREATED: return self._handle_error(INVALID_STATE_ERROR.format("start", self.state)) self.start_time_ns = time.monotonic_ns() + self._active_operation_context.set(self) def start_attempt(self) -> ActiveAttemptMetric | None: """ @@ -208,6 +211,7 @@ def start_attempt(self) -> ActiveAttemptMetric | None: return self._handle_error( INVALID_STATE_ERROR.format("start_attempt", self.state) ) + self._active_operation_context.set(self) try: # find backoff value before this attempt diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py index f13590f7c..8e7eb373d 100644 --- a/google/cloud/bigtable/data/_metrics/metrics_controller.py +++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py @@ -19,14 +19,6 @@ from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler from google.cloud.bigtable.data._metrics.data_model import OperationType -if TYPE_CHECKING: - from google.cloud.bigtable.data._async.metrics_interceptor import ( - AsyncBigtableMetricsInterceptor, - ) - from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( - BigtableMetricsInterceptor, - ) - class BigtableClientSideMetricsController: """ @@ -38,24 +30,15 @@ class BigtableClientSideMetricsController: def __init__( self, - interceptor: AsyncBigtableMetricsInterceptor | BigtableMetricsInterceptor, handlers: list[MetricsHandler] | None = None, - **kwargs, ): """ Initializes the metrics controller. Args: - - interceptor: A metrics interceptor to use for triggering Operation lifecycle events - handlers: A list of MetricsHandler objects to subscribe to metrics events. - - **kwargs: Optional arguments to pass to the metrics handlers. """ - self.interceptor = interceptor self.handlers: list[MetricsHandler] = handlers or [] - if handlers is None: - # handlers not given. Use default handlers. - # TODO: add default handlers - pass def add_handler(self, handler: MetricsHandler) -> None: """ @@ -72,6 +55,4 @@ def create_operation( """ Creates a new operation and registers it with the subscribed handlers. 
""" - new_op = ActiveOperationMetric(op_type, **kwargs, handlers=self.handlers) - self.interceptor.register_operation(new_op) - return new_op + return ActiveOperationMetric(op_type, **kwargs, handlers=self.handlers) From 5ba2bbe9960d8af83e29cec103df480f183412be Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 29 Sep 2025 23:03:34 -0700 Subject: [PATCH 54/60] pulled in improvements to data model --- .../data/_async/metrics_interceptor.py | 113 ++++++++++++------ .../bigtable/data/_metrics/data_model.py | 95 +++++++++++++-- 2 files changed, 160 insertions(+), 48 deletions(-) diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py index 2bee4de43..0bd401a78 100644 --- a/google/cloud/bigtable/data/_async/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py @@ -13,10 +13,14 @@ # limitations under the License from __future__ import annotations +from typing import Sequence + import time from functools import wraps + from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState +from google.cloud.bigtable.data._metrics.data_model import OperationType from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler from google.cloud.bigtable.data._cross_sync import CrossSync @@ -24,6 +28,7 @@ if CrossSync.is_async: from grpc.aio import UnaryUnaryClientInterceptor from grpc.aio import UnaryStreamClientInterceptor + from grpc.aio import AioRpcError else: from grpc import UnaryUnaryClientInterceptor from grpc import UnaryStreamClientInterceptor @@ -41,6 +46,7 @@ def _with_operation_from_metadata(func): @wraps(func) def wrapper(self, continuation, client_call_details, request): operation: "ActiveOperationMetric" | None = ActiveOperationMetric.get_active() + if operation: # start a new attempt if not started if ( @@ -57,22 +63,26 @@ def wrapper(self, continuation, client_call_details, request): return wrapper -def _end_attempt(operation, exc, metadata): - """Helper to add metadata and exception to an operation""" - if metadata is not None: - operation.add_response_metadata(metadata) - if exc is not None: - # end attempt. 
If it succeeded, let higher levels decide when to end operation - operation.end_attempt_with_status(exc) - - @CrossSync.convert -async def _get_metadata(source): +async def _get_metadata(source) -> dict[str, str | bytes] | None: """Helper to extract metadata from a call or RpcError""" try: - return (await source.trailing_metadata() or []) + ( - await source.initial_metadata() or [] - ) + metadata: Sequence[tuple[str, str | bytes]] + if CrossSync.is_async: + # grpc.aio returns metadata in Metadata objects + if isinstance(source, AioRpcError): + metadata = list(source.trailing_metadata()) + list( + source.initial_metadata() + ) + else: + metadata = list(await source.trailing_metadata()) + list( + await source.initial_metadata() + ) + else: + # sync grpc returns metadata as a sequence of tuples + metadata = source.trailing_metadata() + source.initial_metadata() + # convert metadata to dict format + return {k: v for (k, v) in metadata} except Exception: # ignore errors while fetching metadata return None @@ -91,7 +101,12 @@ class AsyncBigtableMetricsInterceptor( async def intercept_unary_unary( self, operation, continuation, client_call_details, request ): - encountered_exc: Exception | None = None + """ + Interceptor for unary rpcs: + - MutateRow + - CheckAndMutateRow + - ReadModifyWriteRow + """ metadata = None try: call = await continuation(client_call_details, request) @@ -99,39 +114,59 @@ async def intercept_unary_unary( return call except Exception as rpc_error: metadata = await _get_metadata(rpc_error) - encountered_exc = rpc_error raise rpc_error finally: - _end_attempt(operation, encountered_exc, metadata) + if metadata is not None: + operation.add_response_metadata(metadata) @CrossSync.convert @_with_operation_from_metadata async def intercept_unary_stream( self, operation, continuation, client_call_details, request ): - async def response_wrapper(call): - has_first_response = operation.first_response_latency is not None - encountered_exc = None - try: - async for response in call: - # record time to first response. Currently only used for READ_ROWs - if not has_first_response: - operation.first_response_latency_ns = ( - time.monotonic_ns() - operation.start_time_ns - ) - has_first_response = True - yield response - except Exception as e: - # handle errors while processing stream - encountered_exc = e - raise - finally: - if call is not None: - _end_attempt(operation, encountered_exc, await _get_metadata(call)) - + """ + Interceptor for streaming rpcs: + - ReadRows + - MutateRows + - SampleRowKeys + """ try: - return response_wrapper(await continuation(client_call_details, request)) + return self._streaming_generator_wrapper( + operation, await continuation(client_call_details, request) + ) except Exception as rpc_error: - # handle errors while intializing stream - _end_attempt(operation, rpc_error, await _get_metadata(rpc_error)) + metadata = await _get_metadata(rpc_error) + if metadata is not None: + operation.add_response_metadata(metadata) raise rpc_error + + @staticmethod + @CrossSync.convert + async def _streaming_generator_wrapper(operation, call): + """ + Wrapped generator to be returned by intercept_unary_stream + """ + # only track has_first response for READ_ROWS + has_first_response = ( + operation.first_response_latency_ns is not None + or operation.op_type != OperationType.READ_ROWS + ) + encountered_exc = None + try: + async for response in call: + # record time to first response. 
Currently only used for READ_ROWs + if not has_first_response: + operation.first_response_latency_ns = ( + time.monotonic_ns() - operation.start_time_ns + ) + has_first_response = True + yield response + except Exception as e: + # handle errors while processing stream + encountered_exc = e + raise + finally: + if call is not None: + metadata = await _get_metadata(encountered_exc or call) + if metadata is not None: + operation.add_response_metadata(metadata) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index f840f512b..041f1de4a 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -26,10 +26,15 @@ from dataclasses import dataclass from dataclasses import field from grpc import StatusCode +from grpc import RpcError +from grpc.aio import AioRpcError +from google.api_core.exceptions import GoogleAPICallError +from google.api_core.retry import RetryFailureReason import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable_v2.types.response_params import ResponseParams from google.cloud.bigtable.data._helpers import TrackedBackoffGenerator +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.protobuf.message import DecodeError if TYPE_CHECKING: @@ -49,6 +54,10 @@ INVALID_STATE_ERROR = "Invalid state for {}: {}" +ExceptionFactoryType = Callable[ + [List[Exception], RetryFailureReason, Optional[float]], + Tuple[Exception, Optional[Exception]], +] class OperationType(Enum): @@ -284,7 +293,7 @@ def _parse_response_metadata_blob(blob: bytes) -> Tuple[str, str] | None: # failed to parse metadata return None - def end_attempt_with_status(self, status: StatusCode | Exception) -> None: + def end_attempt_with_status(self, status: StatusCode | BaseException) -> None: """ Called to mark the end of an attempt for the operation. @@ -301,7 +310,7 @@ def end_attempt_with_status(self, status: StatusCode | Exception) -> None: return self._handle_error( INVALID_STATE_ERROR.format("end_attempt_with_status", self.state) ) - if isinstance(status, Exception): + if isinstance(status, BaseException): status = self._exc_to_status(status) complete_attempt = CompletedAttemptMetric( duration_ns=time.monotonic_ns() - self.active_attempt.start_time_ns, @@ -316,7 +325,7 @@ def end_attempt_with_status(self, status: StatusCode | Exception) -> None: for handler in self.handlers: handler.on_attempt_complete(complete_attempt, self) - def end_with_status(self, status: StatusCode | Exception) -> None: + def end_with_status(self, status: StatusCode | BaseException) -> None: """ Called to mark the end of the operation. If there is an active attempt, end_attempt_with_status will be called with the same status. @@ -333,7 +342,7 @@ def end_with_status(self, status: StatusCode | Exception) -> None: INVALID_STATE_ERROR.format("end_with_status", self.state) ) final_status = ( - self._exc_to_status(status) if isinstance(status, Exception) else status + self._exc_to_status(status) if isinstance(status, BaseException) else status ) if self.state == OperationState.ACTIVE_ATTEMPT: self.end_attempt_with_status(final_status) @@ -371,7 +380,7 @@ def cancel(self): handler.on_operation_cancelled(self) @staticmethod - def _exc_to_status(exc: Exception) -> StatusCode: + def _exc_to_status(exc: BaseException) -> StatusCode: """ Extracts the grpc status code from an exception. 
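# The hunk below adds track_retryable_error/track_terminal_error, whose core is
# a standard callback-wrapping shape: delegate to the original retry
# exception_factory, record the terminal error it builds, and return its result
# unchanged. A reduced, runnable model of that shape follows; the real factory
# signature takes google.api_core.retry's RetryFailureReason, and `record` here
# is a hypothetical stand-in for the operation's bookkeeping:

from typing import Callable, List, Optional, Tuple

ExcFactory = Callable[
    [List[Exception], object, Optional[float]], Tuple[Exception, Optional[Exception]]
]


def track_terminal(factory: ExcFactory, record) -> ExcFactory:
    def wrapper(exc_list, reason, timeout_val):
        source_exc, cause_exc = factory(exc_list, reason, timeout_val)
        record(source_exc)  # observe, never mutate, the terminal exception
        return source_exc, cause_exc

    return wrapper


seen: list = []


def base_factory(exc_list, reason, timeout_val):
    return TimeoutError(f"retries exhausted ({reason})"), exc_list[-1]


wrapped = track_terminal(base_factory, seen.append)
err, cause = wrapped([ValueError("boom")], "timeout", 1.0)
assert seen == [err] and isinstance(cause, ValueError)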
@@ -393,8 +402,73 @@ def _exc_to_status(exc: Exception) -> StatusCode: and exc.__cause__.grpc_status_code is not None ): return exc.__cause__.grpc_status_code + if isinstance(exc, AioRpcError) or isinstance(exc, RpcError): + return exc.code() return StatusCode.UNKNOWN + def track_retryable_error(self, exc: Exception) -> None: + """ + Used as input to api_core.Retry classes, to track when retryable errors are encountered + + Should be passed as on_error callback + """ + try: + # record metadata from failed rpc + if isinstance(exc, GoogleAPICallError) and exc.errors: + rpc_error = exc.errors[-1] + metadata = list(rpc_error.trailing_metadata()) + list( + rpc_error.initial_metadata() + ) + self.add_response_metadata({k: v for k, v in metadata}) + except Exception: + # ignore errors in metadata collection + pass + if isinstance(exc, _MutateRowsIncomplete): + # _MutateRowsIncomplete represents a successful rpc with some failed mutations + # mark the attempt as successful + self.end_attempt_with_status(StatusCode.OK) + else: + self.end_attempt_with_status(exc) + + def track_terminal_error( + self, exception_factory: ExceptionFactoryType + ) -> ExceptionFactoryType: + """ + Used as input to api_core.Retry classes, to track when terminal errors are encountered + + Should be used as a wrapper over an exception_factory callback + """ + + def wrapper( + exc_list: list[Exception], + reason: RetryFailureReason, + timeout_val: float | None, + ) -> tuple[Exception, Exception | None]: + source_exc, cause_exc = exception_factory(exc_list, reason, timeout_val) + try: + # record metadata from failed rpc + if isinstance(source_exc, GoogleAPICallError) and source_exc.errors: + rpc_error = source_exc.errors[-1] + metadata = list(rpc_error.trailing_metadata()) + list( + rpc_error.initial_metadata() + ) + self.add_response_metadata({k: v for k, v in metadata}) + except Exception: + # ignore errors in metadata collection + pass + if ( + reason == RetryFailureReason.TIMEOUT + and self.state == OperationState.ACTIVE_ATTEMPT + and exc_list + ): + # record ending attempt for timeout failures + attempt_exc = exc_list[-1] + self.track_retryable_error(attempt_exc) + self.end_with_status(source_exc) + return source_exc, cause_exc + + return wrapper + @staticmethod def _handle_error(message: str) -> None: """ @@ -422,8 +496,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): The operation is automatically ended on exit, with the status determined by the exception type and value. + + If operation was already ended manually, do nothing. 
""" - if exc_val is None: - self.end_with_success() - else: - self.end_with_status(exc_val) + if not self.state == OperationState.COMPLETED: + if exc_val is None: + self.end_with_success() + else: + self.end_with_status(exc_val) From 4098fd9fef4ac8de341d0d560712749f27354bba Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 29 Sep 2025 23:04:18 -0700 Subject: [PATCH 55/60] removed cancel from spec --- google/cloud/bigtable/data/_metrics/data_model.py | 7 ------- google/cloud/bigtable/data/_metrics/handlers/_base.py | 3 --- 2 files changed, 10 deletions(-) diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index 041f1de4a..d0d9b5f52 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -372,13 +372,6 @@ def end_with_success(self): """ return self.end_with_status(StatusCode.OK) - def cancel(self): - """ - Called to cancel an operation without processing emitting it. - """ - for handler in self.handlers: - handler.on_operation_cancelled(self) - @staticmethod def _exc_to_status(exc: BaseException) -> StatusCode: """ diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py index 64cc89b05..72f5aa550 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/_base.py +++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py @@ -29,9 +29,6 @@ def __init__(self, **kwargs): def on_operation_complete(self, op: CompletedOperationMetric) -> None: pass - def on_operation_cancelled(self, op: ActiveOperationMetric) -> None: - pass - def on_attempt_complete( self, attempt: CompletedAttemptMetric, op: ActiveOperationMetric ) -> None: From c1cc24dcd031bdd145c01788be66bd1c11f0c65b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 29 Sep 2025 23:35:15 -0700 Subject: [PATCH 56/60] fixed tests --- google/cloud/bigtable/data/_async/client.py | 8 +- .../bigtable/data/_sync_autogen/client.py | 8 +- tests/unit/data/_async/test_client.py | 2 - .../data/_async/test_metrics_interceptor.py | 189 ++---------------- tests/unit/data/_metrics/test_data_model.py | 21 -- .../data/_metrics/test_metrics_controller.py | 28 +-- tests/unit/data/_sync_autogen/test_client.py | 2 - .../_sync_autogen/test_metrics_interceptor.py | 178 ++--------------- 8 files changed, 46 insertions(+), 390 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index d551f5b7a..ace060a50 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -970,13 +970,7 @@ def __init__( default_retryable_errors or () ) - self._metrics = BigtableClientSideMetricsController( - client._metrics_interceptor, - project_id=self.client.project, - instance_id=instance_id, - table_id=table_id, - app_profile_id=app_profile_id, - ) + self._metrics = BigtableClientSideMetricsController() try: self._register_instance_future = CrossSync.create_task( diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index bcd13c068..28a11b91e 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -757,13 +757,7 @@ def __init__( self.default_retryable_errors: Sequence[type[Exception]] = ( default_retryable_errors or () ) - self._metrics = BigtableClientSideMetricsController( - client._metrics_interceptor, - project_id=self.client.project, - 
instance_id=instance_id, - table_id=table_id, - app_profile_id=app_profile_id, - ) + self._metrics = BigtableClientSideMetricsController() try: self._register_instance_future = CrossSync._Sync_Impl.create_task( self.client._register_instance, diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 48ba94ba0..0cb19f4fb 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1205,7 +1205,6 @@ async def test_ctor(self): assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} assert isinstance(table._metrics, BigtableClientSideMetricsController) - assert table._metrics.interceptor == client._metrics_interceptor assert table.default_operation_timeout == expected_operation_timeout assert table.default_attempt_timeout == expected_attempt_timeout assert ( @@ -1547,7 +1546,6 @@ async def test_ctor(self): assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(view)} assert isinstance(view._metrics, BigtableClientSideMetricsController) - assert view._metrics.interceptor == client._metrics_interceptor assert view.default_operation_timeout == expected_operation_timeout assert view.default_attempt_timeout == expected_attempt_timeout assert ( diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py index caa9bbb46..1593b8c99 100644 --- a/tests/unit/data/_async/test_metrics_interceptor.py +++ b/tests/unit/data/_async/test_metrics_interceptor.py @@ -16,6 +16,7 @@ from grpc import RpcError from grpc import ClientCallDetails +from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState from google.cloud.bigtable.data._cross_sync import CrossSync @@ -69,94 +70,9 @@ def _get_target_class(): def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) - def test_ctor(self): - instance = self._make_one() - assert instance.operation_map == {} - - def test_register_operation(self): - """ - adding a new operation should register it in operation_map - """ - from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric - from google.cloud.bigtable.data._metrics.data_model import OperationType - - instance = self._make_one() - op = ActiveOperationMetric(OperationType.READ_ROWS) - instance.register_operation(op) - assert instance.operation_map[op.uuid] == op - assert instance in op.handlers - - def test_on_operation_comple_mock(self): - """ - completing or cancelling an operation should call on_operation_complete on interceptor - """ - from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric - from google.cloud.bigtable.data._metrics.data_model import OperationType - - instance = self._make_one() - instance.on_operation_complete = mock.Mock() - op = ActiveOperationMetric(OperationType.READ_ROWS) - instance.register_operation(op) - op.end_with_success() - assert instance.on_operation_complete.call_count == 1 - op.cancel() - assert instance.on_operation_complete.call_count == 2 - - def test_on_operation_complete(self): - """ - completing an operation should remove it from the operation map - """ - from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric - from google.cloud.bigtable.data._metrics.data_model import OperationType - - instance = self._make_one() - op = 
ActiveOperationMetric(OperationType.READ_ROWS)
-        instance.register_operation(op)
-        op.end_with_success()
-        instance.on_operation_complete(op)
-        assert op.uuid not in instance.operation_map
-
-    def test_on_operation_cancelled(self):
-        """
-        completing an operation should remove it from the operation map
-        """
-        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
-        from google.cloud.bigtable.data._metrics.data_model import OperationType
-
-        instance = self._make_one()
-        op = ActiveOperationMetric(OperationType.READ_ROWS)
-        instance.register_operation(op)
-        op.cancel()
-        assert op.uuid not in instance.operation_map
-
-    @CrossSync.pytest
-    async def test_strip_operation_id_metadata(self):
-        """
-        After operation id is detected in metadata, the field should be stripped out before calling continuation
-        """
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
-        instance = self._make_one()
-        op = mock.Mock()
-        op.uuid = "test-uuid"
-        op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
-        continuation = CrossSync.Mock()
-        details = ClientCallDetails()
-        details.metadata = [
-            (OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid),
-            ("other_key", "other_value"),
-        ]
-        await instance.intercept_unary_unary(continuation, details, mock.Mock())
-        assert details.metadata == [("other_key", "other_value")]
-        assert continuation.call_count == 1
-        assert continuation.call_args[0][0].metadata == [("other_key", "other_value")]
-
     @CrossSync.pytest
     async def test_unary_unary_interceptor_op_not_found(self):
-        """Test that interceptor call cuntinuation if op is not found"""
+        """Test that interceptor calls the continuation if op is not found"""
         instance = self._make_one()
         continuation = CrossSync.Mock()
         details = ClientCallDetails()
@@ -168,107 +84,84 @@ async def test_unary_unary_interceptor_op_not_found(self):
     @CrossSync.pytest
     async def test_unary_unary_interceptor_success(self):
         """Test that interceptor handles successful unary-unary calls"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         continuation = CrossSync.Mock()
         call = continuation.return_value
         call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
         call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         result = await instance.intercept_unary_unary(continuation, details, request)
         assert result == call
         continuation.assert_called_once_with(details, request)
-        op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+        op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"})
         op.end_attempt_with_status.assert_not_called()
 
     @CrossSync.pytest
     async def test_unary_unary_interceptor_failure(self):
         """test a failed RpcError with metadata"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         exc = RpcError("test")
         exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
exc.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
         continuation = CrossSync.Mock(side_effect=exc)
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         with pytest.raises(RpcError) as e:
             await instance.intercept_unary_unary(continuation, details, request)
         assert e.value == exc
         continuation.assert_called_once_with(details, request)
-        op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
-        op.end_attempt_with_status.assert_called_once_with(exc)
+        op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"})
 
     @CrossSync.pytest
     async def test_unary_unary_interceptor_failure_no_metadata(self):
         """test with RpcError without metadata attached"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         exc = RpcError("test")
         continuation = CrossSync.Mock(side_effect=exc)
         call = continuation.return_value
         call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
         call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         with pytest.raises(RpcError) as e:
             await instance.intercept_unary_unary(continuation, details, request)
         assert e.value == exc
         continuation.assert_called_once_with(details, request)
         op.add_response_metadata.assert_not_called()
-        op.end_attempt_with_status.assert_called_once_with(exc)
 
     @CrossSync.pytest
     async def test_unary_unary_interceptor_failure_generic(self):
         """test generic exception"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         exc = ValueError("test")
         continuation = CrossSync.Mock(side_effect=exc)
         call = continuation.return_value
         call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")])
         call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         with pytest.raises(ValueError) as e:
             await instance.intercept_unary_unary(continuation, details, request)
         assert e.value == exc
         continuation.assert_called_once_with(details, request)
         op.add_response_metadata.assert_not_called()
-        op.end_attempt_with_status.assert_called_once_with(exc)
 
     @CrossSync.pytest
     async def test_unary_stream_interceptor_op_not_found(self):
@@ -284,39 +177,32 @@ async def test_unary_stream_interceptor_op_not_found(self):
     @CrossSync.pytest
     async def test_unary_stream_interceptor_success(self):
         """Test that interceptor handles successful unary-stream calls"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
         op.start_time_ns = 0
         op.first_response_latency = None
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1, 2]))
         call =
continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = await instance.intercept_unary_stream(continuation, details, request) results = [val async for val in wrapper] assert results == [1, 2] continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"}) op.end_attempt_with_status.assert_not_called() @CrossSync.pytest async def test_unary_stream_interceptor_failure_mid_stream(self): """Test that interceptor handles failures mid-stream""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) + from grpc.aio import AioRpcError, Metadata instance = self._make_one() op = mock.Mock() @@ -324,38 +210,29 @@ async def test_unary_stream_interceptor_failure_mid_stream(self): op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op - exc = ValueError("test") + ActiveOperationMetric._active_operation_context.set(op) + exc = AioRpcError(0, Metadata(), Metadata(("a", "b"), ("c", "d"))) continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1], exc=exc)) - call = continuation.return_value - call.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) - call.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = await instance.intercept_unary_stream(continuation, details, request) - with pytest.raises(ValueError) as e: + with pytest.raises(AioRpcError) as e: [val async for val in wrapper] assert e.value == exc continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) - op.end_attempt_with_status.assert_called_once_with(exc) + op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"}) @CrossSync.pytest async def test_unary_stream_interceptor_failure_start_stream(self): """Test that interceptor handles failures at start of stream with RpcError with metadata""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) exc = RpcError("test") exc.trailing_metadata = CrossSync.Mock(return_value=[("a", "b")]) exc.initial_metadata = CrossSync.Mock(return_value=[("c", "d")]) @@ -363,36 +240,29 @@ async def test_unary_stream_interceptor_failure_start_stream(self): continuation = CrossSync.Mock() continuation.side_effect = exc details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: await instance.intercept_unary_stream(continuation, details, request) assert e.value == exc continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - 
op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) - op.end_attempt_with_status.assert_called_once_with(exc) + op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"}) @CrossSync.pytest async def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): """Test that interceptor handles failures at start of stream with RpcError with no metadata""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) exc = RpcError("test") continuation = CrossSync.Mock() continuation.side_effect = exc details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: await instance.intercept_unary_stream(continuation, details, request) @@ -400,28 +270,22 @@ async def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None op.add_response_metadata.assert_not_called() - op.end_attempt_with_status.assert_called_once_with(exc) @CrossSync.pytest async def test_unary_stream_interceptor_failure_start_stream_generic(self): """Test that interceptor handles failures at start of stream with generic exception""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) exc = ValueError("test") continuation = CrossSync.Mock() continuation.side_effect = exc details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(ValueError) as e: await instance.intercept_unary_stream(continuation, details, request) @@ -429,7 +293,6 @@ async def test_unary_stream_interceptor_failure_start_stream_generic(self): continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None op.add_response_metadata.assert_not_called() - op.end_attempt_with_status.assert_called_once_with(exc) @CrossSync.pytest @pytest.mark.parametrize( @@ -437,21 +300,16 @@ async def test_unary_stream_interceptor_failure_start_stream_generic(self): ) async def test_unary_unary_interceptor_start_operation(self, initial_state): """if called with a newly created operation, it should be started""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = initial_state - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) continuation = CrossSync.Mock() call = continuation.return_value call.trailing_metadata = CrossSync.Mock(return_value=[]) call.initial_metadata = CrossSync.Mock(return_value=[]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() await instance.intercept_unary_unary(continuation, details, request) op.start_attempt.assert_called_once() @@ -462,22 
+320,17 @@ async def test_unary_unary_interceptor_start_operation(self, initial_state):
     )
     async def test_unary_stream_interceptor_start_operation(self, initial_state):
         """if called with a newly created operation, it should be started"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = initial_state
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1, 2]))
         call = continuation.return_value
         call.trailing_metadata = CrossSync.Mock(return_value=[])
         call.initial_metadata = CrossSync.Mock(return_value=[])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         await instance.intercept_unary_stream(continuation, details, request)
         op.start_attempt.assert_called_once()
diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py
index 7d9b6671f..42aa96093 100644
--- a/tests/unit/data/_metrics/test_data_model.py
+++ b/tests/unit/data/_metrics/test_data_model.py
@@ -535,27 +535,6 @@ def test_end_with_status_w_exception(self):
         final_op = handlers[0].on_operation_complete.call_args[0][0]
         assert final_op.final_status == expected_status
 
-    def test_interceptor_metadata(self):
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
-        metric = self._make_one(mock.Mock())
-        key, value = metric.interceptor_metadata
-        assert key == OPERATION_INTERCEPTOR_METADATA_KEY
-        assert value == metric.uuid
-
-    def test_cancel(self):
-        """
-        cancel should call on_operation_cancelled on handlers
-        """
-        handlers = [mock.Mock(), mock.Mock()]
-        metric = self._make_one(mock.Mock(), handlers=handlers)
-        metric.cancel()
-        for h in handlers:
-            assert h.on_operation_cancelled.call_count == 1
-            assert h.on_operation_cancelled.call_args[0][0] == metric
-
     def test_end_with_status_with_default_cluster_zone(self):
         """
         ending the operation should use default cluster and zone if not set
diff --git a/tests/unit/data/_metrics/test_metrics_controller.py b/tests/unit/data/_metrics/test_metrics_controller.py
index 701af737b..66ebe56f6 100644
--- a/tests/unit/data/_metrics/test_metrics_controller.py
+++ b/tests/unit/data/_metrics/test_metrics_controller.py
@@ -21,19 +21,13 @@ def _make_one(self, *args, **kwargs):
             BigtableClientSideMetricsController,
         )
 
-        # add mock interceptor if called with no arguments
-        if not args and "interceptor" not in kwargs:
-            args = [mock.Mock()]
-
         return BigtableClientSideMetricsController(*args, **kwargs)
 
     def test_ctor_defaults(self):
         """
         should create instance with no handlers by default
         """
-        expected_interceptor = object()
-        instance = self._make_one(expected_interceptor)
-        assert instance.interceptor == expected_interceptor
+        instance = self._make_one()
         assert len(instance.handlers) == 0
 
     def ctor_custom_handlers(self):
@@ -92,22 +86,4 @@ def test_create_operation(self):
         assert op.is_streaming is expected_is_streaming
         assert op.zone is expected_zone
         assert len(op.handlers) == 1
-        assert op.handlers[0] is handler
-
-    def test_create_operation_registers_interceptor(self):
-        """
-        creating an operation should link the operation with the controller's interceptor,
-        and add the interceptor as a handler to the operation
-        """
-        from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (
-
BigtableMetricsInterceptor, - ) - - custom_handler = object() - controller = self._make_one( - BigtableMetricsInterceptor(), handlers=[custom_handler] - ) - op = controller.create_operation(object()) - assert custom_handler in op.handlers - assert op.uuid in controller.interceptor.operation_map - assert controller.interceptor.operation_map[op.uuid] == op + assert op.handlers[0] is handler \ No newline at end of file diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index afe741e57..47937a767 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -982,7 +982,6 @@ def test_ctor(self): assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(table)} assert isinstance(table._metrics, BigtableClientSideMetricsController) - assert table._metrics.interceptor == client._metrics_interceptor assert table.default_operation_timeout == expected_operation_timeout assert table.default_attempt_timeout == expected_attempt_timeout assert ( @@ -1251,7 +1250,6 @@ def test_ctor(self): assert instance_key in client._active_instances assert client._instance_owners[instance_key] == {id(view)} assert isinstance(view._metrics, BigtableClientSideMetricsController) - assert view._metrics.interceptor == client._metrics_interceptor assert view.default_operation_timeout == expected_operation_timeout assert view.default_attempt_timeout == expected_attempt_timeout assert ( diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py index 283814e27..c4efcc5b9 100644 --- a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py +++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py @@ -18,6 +18,7 @@ import pytest from grpc import RpcError from grpc import ClientCallDetails +from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.data_model import OperationState from google.cloud.bigtable.data._cross_sync import CrossSync @@ -52,82 +53,8 @@ def _get_target_class(): def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) - def test_ctor(self): - instance = self._make_one() - assert instance.operation_map == {} - - def test_register_operation(self): - """adding a new operation should register it in operation_map""" - from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric - from google.cloud.bigtable.data._metrics.data_model import OperationType - - instance = self._make_one() - op = ActiveOperationMetric(OperationType.READ_ROWS) - instance.register_operation(op) - assert instance.operation_map[op.uuid] == op - assert instance in op.handlers - - def test_on_operation_comple_mock(self): - """completing or cancelling an operation should call on_operation_complete on interceptor""" - from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric - from google.cloud.bigtable.data._metrics.data_model import OperationType - - instance = self._make_one() - instance.on_operation_complete = mock.Mock() - op = ActiveOperationMetric(OperationType.READ_ROWS) - instance.register_operation(op) - op.end_with_success() - assert instance.on_operation_complete.call_count == 1 - op.cancel() - assert instance.on_operation_complete.call_count == 2 - - def test_on_operation_complete(self): - """completing an operation should remove it from the operation map""" - from 
google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
-        from google.cloud.bigtable.data._metrics.data_model import OperationType
-
-        instance = self._make_one()
-        op = ActiveOperationMetric(OperationType.READ_ROWS)
-        instance.register_operation(op)
-        op.end_with_success()
-        instance.on_operation_complete(op)
-        assert op.uuid not in instance.operation_map
-
-    def test_on_operation_cancelled(self):
-        """completing an operation should remove it from the operation map"""
-        from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric
-        from google.cloud.bigtable.data._metrics.data_model import OperationType
-
-        instance = self._make_one()
-        op = ActiveOperationMetric(OperationType.READ_ROWS)
-        instance.register_operation(op)
-        op.cancel()
-        assert op.uuid not in instance.operation_map
-
-    def test_strip_operation_id_metadata(self):
-        """After operation id is detected in metadata, the field should be stripped out before calling continuation"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
-        instance = self._make_one()
-        op = mock.Mock()
-        op.uuid = "test-uuid"
-        op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
-        continuation = CrossSync._Sync_Impl.Mock()
-        details = ClientCallDetails()
-        details.metadata = [
-            (OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid),
-            ("other_key", "other_value"),
-        ]
-        instance.intercept_unary_unary(continuation, details, mock.Mock())
-        assert details.metadata == [("other_key", "other_value")]
-        assert continuation.call_count == 1
-        assert continuation.call_args[0][0].metadata == [("other_key", "other_value")]
-
     def test_unary_unary_interceptor_op_not_found(self):
-        """Test that interceptor call cuntinuation if op is not found"""
+        """Test that interceptor calls the continuation if op is not found"""
         instance = self._make_one()
         continuation = CrossSync._Sync_Impl.Mock()
         details = ClientCallDetails()
@@ -138,104 +65,81 @@ def test_unary_unary_interceptor_op_not_found(self):
     def test_unary_unary_interceptor_success(self):
         """Test that interceptor handles successful unary-unary calls"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         continuation = CrossSync._Sync_Impl.Mock()
         call = continuation.return_value
         call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
         call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         result = instance.intercept_unary_unary(continuation, details, request)
         assert result == call
         continuation.assert_called_once_with(details, request)
-        op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
+        op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"})
         op.end_attempt_with_status.assert_not_called()
 
     def test_unary_unary_interceptor_failure(self):
         """test a failed RpcError with metadata"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+
ActiveOperationMetric._active_operation_context.set(op)
         exc = RpcError("test")
         exc.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
         exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
         continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         with pytest.raises(RpcError) as e:
             instance.intercept_unary_unary(continuation, details, request)
         assert e.value == exc
         continuation.assert_called_once_with(details, request)
-        op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")])
-        op.end_attempt_with_status.assert_called_once_with(exc)
+        op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"})
 
     def test_unary_unary_interceptor_failure_no_metadata(self):
         """test with RpcError without metadata attached"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         exc = RpcError("test")
         continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
         call = continuation.return_value
         call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
         call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         with pytest.raises(RpcError) as e:
             instance.intercept_unary_unary(continuation, details, request)
         assert e.value == exc
         continuation.assert_called_once_with(details, request)
         op.add_response_metadata.assert_not_called()
-        op.end_attempt_with_status.assert_called_once_with(exc)
 
     def test_unary_unary_interceptor_failure_generic(self):
         """test generic exception"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
-        instance.operation_map[op.uuid] = op
+        ActiveOperationMetric._active_operation_context.set(op)
         exc = ValueError("test")
         continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
         call = continuation.return_value
         call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")])
         call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")])
         details = ClientCallDetails()
-        details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)]
         request = mock.Mock()
         with pytest.raises(ValueError) as e:
             instance.intercept_unary_unary(continuation, details, request)
         assert e.value == exc
         continuation.assert_called_once_with(details, request)
         op.add_response_metadata.assert_not_called()
-        op.end_attempt_with_status.assert_called_once_with(exc)
 
     def test_unary_stream_interceptor_op_not_found(self):
@@ -249,17 +153,13 @@ def test_unary_stream_interceptor_op_not_found(self):
     def test_unary_stream_interceptor_success(self):
         """Test that interceptor handles successful unary-stream calls"""
-        from google.cloud.bigtable.data._metrics.data_model import (
-            OPERATION_INTERCEPTOR_METADATA_KEY,
-        )
-
         instance = self._make_one()
         op = mock.Mock()
         op.uuid = "test-uuid"
         op.state = OperationState.ACTIVE_ATTEMPT
         op.start_time_ns = 0
         op.first_response_latency = None
-
instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) continuation = CrossSync._Sync_Impl.Mock( return_value=_make_mock_stream_call([1, 2]) ) @@ -267,21 +167,18 @@ def test_unary_stream_interceptor_success(self): call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = instance.intercept_unary_stream(continuation, details, request) results = [val for val in wrapper] assert results == [1, 2] continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) + op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"}) op.end_attempt_with_status.assert_not_called() def test_unary_stream_interceptor_failure_mid_stream(self): """Test that interceptor handles failures mid-stream""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) + from grpc.aio import AioRpcError, Metadata instance = self._make_one() op = mock.Mock() @@ -289,73 +186,57 @@ def test_unary_stream_interceptor_failure_mid_stream(self): op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op - exc = ValueError("test") + ActiveOperationMetric._active_operation_context.set(op) + exc = AioRpcError(0, Metadata(), Metadata(("a", "b"), ("c", "d"))) continuation = CrossSync._Sync_Impl.Mock( return_value=_make_mock_stream_call([1], exc=exc) ) - call = continuation.return_value - call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) - call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() wrapper = instance.intercept_unary_stream(continuation, details, request) - with pytest.raises(ValueError) as e: + with pytest.raises(AioRpcError) as e: [val for val in wrapper] assert e.value == exc continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) - op.end_attempt_with_status.assert_called_once_with(exc) + op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"}) def test_unary_stream_interceptor_failure_start_stream(self): """Test that interceptor handles failures at start of stream with RpcError with metadata""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) exc = RpcError("test") exc.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[("a", "b")]) exc.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[("c", "d")]) continuation = CrossSync._Sync_Impl.Mock() continuation.side_effect = exc details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: instance.intercept_unary_stream(continuation, details, request) 
assert e.value == exc continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None - op.add_response_metadata.assert_called_once_with([("a", "b"), ("c", "d")]) - op.end_attempt_with_status.assert_called_once_with(exc) + op.add_response_metadata.assert_called_once_with({"a": "b", "c": "d"}) def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): """Test that interceptor handles failures at start of stream with RpcError with no metadata""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) exc = RpcError("test") continuation = CrossSync._Sync_Impl.Mock() continuation.side_effect = exc details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(RpcError) as e: instance.intercept_unary_stream(continuation, details, request) @@ -363,26 +244,20 @@ def test_unary_stream_interceptor_failure_start_stream_no_metadata(self): continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None op.add_response_metadata.assert_not_called() - op.end_attempt_with_status.assert_called_once_with(exc) def test_unary_stream_interceptor_failure_start_stream_generic(self): """Test that interceptor handles failures at start of stream with generic exception""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = OperationState.ACTIVE_ATTEMPT op.start_time_ns = 0 op.first_response_latency = None - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) exc = ValueError("test") continuation = CrossSync._Sync_Impl.Mock() continuation.side_effect = exc details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() with pytest.raises(ValueError) as e: instance.intercept_unary_stream(continuation, details, request) @@ -390,28 +265,22 @@ def test_unary_stream_interceptor_failure_start_stream_generic(self): continuation.assert_called_once_with(details, request) assert op.first_response_latency_ns is not None op.add_response_metadata.assert_not_called() - op.end_attempt_with_status.assert_called_once_with(exc) @pytest.mark.parametrize( "initial_state", [OperationState.CREATED, OperationState.BETWEEN_ATTEMPTS] ) def test_unary_unary_interceptor_start_operation(self, initial_state): """if called with a newly created operation, it should be started""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = initial_state - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) continuation = CrossSync._Sync_Impl.Mock() call = continuation.return_value call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() instance.intercept_unary_unary(continuation, details, 
request) op.start_attempt.assert_called_once() @@ -421,15 +290,11 @@ def test_unary_unary_interceptor_start_operation(self, initial_state): ) def test_unary_stream_interceptor_start_operation(self, initial_state): """if called with a newly created operation, it should be started""" - from google.cloud.bigtable.data._metrics.data_model import ( - OPERATION_INTERCEPTOR_METADATA_KEY, - ) - instance = self._make_one() op = mock.Mock() op.uuid = "test-uuid" op.state = initial_state - instance.operation_map[op.uuid] = op + ActiveOperationMetric._active_operation_context.set(op) continuation = CrossSync._Sync_Impl.Mock( return_value=_make_mock_stream_call([1, 2]) ) @@ -437,7 +302,6 @@ def test_unary_stream_interceptor_start_operation(self, initial_state): call.trailing_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) call.initial_metadata = CrossSync._Sync_Impl.Mock(return_value=[]) details = ClientCallDetails() - details.metadata = [(OPERATION_INTERCEPTOR_METADATA_KEY, op.uuid)] request = mock.Mock() instance.intercept_unary_stream(continuation, details, request) op.start_attempt.assert_called_once() From ed9d3cf4ac5b34f8c496fcfca19a8c0a8c54c73e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 30 Sep 2025 19:18:15 -0700 Subject: [PATCH 57/60] fixed lint --- google/cloud/bigtable/data/_metrics/metrics_controller.py | 2 -- google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py | 2 +- tests/unit/data/_metrics/test_metrics_controller.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py index 8e7eb373d..a3ea65e82 100644 --- a/google/cloud/bigtable/data/_metrics/metrics_controller.py +++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py @@ -13,8 +13,6 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING - from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler from google.cloud.bigtable.data._metrics.data_model import OperationType diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py index 029b8e6a9..dcc17e591 100644 --- a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py +++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py @@ -124,4 +124,4 @@ def _streaming_generator_wrapper(operation, call): if call is not None: metadata = _get_metadata(encountered_exc or call) if metadata is not None: - operation.add_response_metadata(metadata) \ No newline at end of file + operation.add_response_metadata(metadata) diff --git a/tests/unit/data/_metrics/test_metrics_controller.py b/tests/unit/data/_metrics/test_metrics_controller.py index 66ebe56f6..7fdbaef07 100644 --- a/tests/unit/data/_metrics/test_metrics_controller.py +++ b/tests/unit/data/_metrics/test_metrics_controller.py @@ -86,4 +86,4 @@ def test_create_operation(self): assert op.is_streaming is expected_is_streaming assert op.zone is expected_zone assert len(op.handlers) == 1 - assert op.handlers[0] is handler \ No newline at end of file + assert op.handlers[0] is handler From bc6036edff3a3ff26b1025bb92041978eea5718f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 2 Oct 2025 15:08:04 -0700 Subject: [PATCH 58/60] added close to metric spec --- google/cloud/bigtable/data/_async/client.py | 1 + .../cloud/bigtable/data/_metrics/handlers/_base.py | 3 +++ .../bigtable/data/_metrics/metrics_controller.py | 7 +++++++ google/cloud/bigtable/data/_sync_autogen/client.py | 1 + tests/unit/data/_async/test_client.py | 10 ++++++++++ tests/unit/data/_sync_autogen/test_client.py | 13 +++++++++++++ 6 files changed, 35 insertions(+) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ace060a50..f8c7b287d 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -1685,6 +1685,7 @@ async def close(self): """ Called to close the Table instance and release any resources held by it. """ + self._metrics.close() if self._register_instance_future: self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py index 72f5aa550..bfd1dffab 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/_base.py +++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py @@ -33,3 +33,6 @@ def on_attempt_complete( self, attempt: CompletedAttemptMetric, op: ActiveOperationMetric ) -> None: pass + + def close(self): + pass \ No newline at end of file diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py index a3ea65e82..25a802337 100644 --- a/google/cloud/bigtable/data/_metrics/metrics_controller.py +++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py @@ -54,3 +54,10 @@ def create_operation( Creates a new operation and registers it with the subscribed handlers. """ return ActiveOperationMetric(op_type, **kwargs, handlers=self.handlers) + + def close(self): + """ + Close all handlers. 
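+
+        Called when the owning Table is closed. The base MetricsHandler.close()
+        is a no-op, so handlers that hold no external resources do not need to
+        override it.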
+ """ + for handler in self.handlers: + handler.close() \ No newline at end of file diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 28a11b91e..86993cf56 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -1414,6 +1414,7 @@ def read_modify_write_row( def close(self): """Called to close the Table instance and release any resources held by it.""" + self._metrics.close() if self._register_instance_future: self._register_instance_future.cancel() self.client._remove_instance_registration(self.instance_id, self) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0cb19f4fb..6a40d2188 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1464,6 +1464,16 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ # empty app_profile_id should send empty string assert "app_profile_id=" in routing_str + @CrossSync.pytest + async def test_close(self): + client = self._make_client() + table = self._make_one(client) + with mock.patch.object(table._metrics, "close", mock.Mock()) as metric_close_mock: + with mock.patch.object(client, "_remove_instance_registration") as remove_mock: + await table.close() + remove_mock.assert_called_once_with(table.instance_id, table) + metric_close_mock.assert_called_once() + @CrossSync.convert_class( "TestAuthorizedView", add_mapping_for_name="TestAuthorizedView" diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index 47937a767..42f5388ee 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -1173,6 +1173,19 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): else: assert "app_profile_id=" in routing_str + def test_close(self): + client = self._make_client() + table = self._make_one(client) + with mock.patch.object( + table._metrics, "close", mock.Mock() + ) as metric_close_mock: + with mock.patch.object( + client, "_remove_instance_registration" + ) as remove_mock: + table.close() + remove_mock.assert_called_once_with(table.instance_id, table) + metric_close_mock.assert_called_once() + @CrossSync._Sync_Impl.add_mapping_decorator("TestAuthorizedView") class TestAuthorizedView(CrossSync._Sync_Impl.TestTable): From 18ec330a7803a23915b9d36af5a6b68a429da975 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 2 Oct 2025 15:10:27 -0700 Subject: [PATCH 59/60] fixed lint --- google/cloud/bigtable/data/_metrics/handlers/_base.py | 2 +- google/cloud/bigtable/data/_metrics/metrics_controller.py | 2 +- tests/unit/data/_async/test_client.py | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_metrics/handlers/_base.py b/google/cloud/bigtable/data/_metrics/handlers/_base.py index bfd1dffab..884091fdd 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/_base.py +++ b/google/cloud/bigtable/data/_metrics/handlers/_base.py @@ -35,4 +35,4 @@ def on_attempt_complete( pass def close(self): - pass \ No newline at end of file + pass diff --git a/google/cloud/bigtable/data/_metrics/metrics_controller.py b/google/cloud/bigtable/data/_metrics/metrics_controller.py index 25a802337..e9815f201 100644 --- a/google/cloud/bigtable/data/_metrics/metrics_controller.py +++ b/google/cloud/bigtable/data/_metrics/metrics_controller.py @@ -60,4 +60,4 @@ 
def close(self): Close all handlers. """ for handler in self.handlers: - handler.close() \ No newline at end of file + handler.close() diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 6a40d2188..2cae7a08c 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1468,8 +1468,12 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ async def test_close(self): client = self._make_client() table = self._make_one(client) - with mock.patch.object(table._metrics, "close", mock.Mock()) as metric_close_mock: - with mock.patch.object(client, "_remove_instance_registration") as remove_mock: + with mock.patch.object( + table._metrics, "close", mock.Mock() + ) as metric_close_mock: + with mock.patch.object( + client, "_remove_instance_registration" + ) as remove_mock: await table.close() remove_mock.assert_called_once_with(table.instance_id, table) metric_close_mock.assert_called_once() From f1be54a9b81883300240ff5c89f3352d7682d949 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 2 Oct 2025 22:32:18 -0700 Subject: [PATCH 60/60] added test --- tests/unit/data/_metrics/test_metrics_controller.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/data/_metrics/test_metrics_controller.py b/tests/unit/data/_metrics/test_metrics_controller.py index 7fdbaef07..2f5eff700 100644 --- a/tests/unit/data/_metrics/test_metrics_controller.py +++ b/tests/unit/data/_metrics/test_metrics_controller.py @@ -87,3 +87,10 @@ def test_create_operation(self): assert op.zone is expected_zone assert len(op.handlers) == 1 assert op.handlers[0] is handler + + def test_close(self): + handlers = [mock.Mock() for _ in range(3)] + controller = self._make_one(handlers=handlers) + controller.close() + for handler in handlers: + handler.close.assert_called_once()