diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 1a2b117e4c..db18f44067 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -819,8 +819,9 @@ def connect( instance = client.instance(instance_id) database = None if database_id: + logger = kwargs.get("logger") database = instance.database( - database_id, pool=pool, database_role=database_role + database_id, pool=pool, database_role=database_role, logger=logger ) conn = Connection(instance, database, **kwargs) if pool is not None: diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 443b75ada7..117b649e1b 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging import unittest import grpc @@ -170,12 +170,15 @@ class MockServerTestBase(unittest.TestCase): spanner_service: SpannerServicer = None database_admin_service: DatabaseAdminServicer = None port: int = None + logger: logging.Logger = None def __init__(self, *args, **kwargs): super(MockServerTestBase, self).__init__(*args, **kwargs) self._client = None self._instance = None self._database = None + self.logger = logging.getLogger("MockServerTestBase") + self.logger.setLevel(logging.WARN) @classmethod def setup_class(cls): @@ -227,6 +230,7 @@ def database(self) -> Database: "test-database", pool=FixedSizePool(size=10), enable_interceptors_in_tests=True, + logger=self.logger, ) return self._database diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py index 413e0f6514..055d9d97b5 100644 --- a/tests/mockserver_tests/test_request_id_header.py +++ b/tests/mockserver_tests/test_request_id_header.py @@ -227,10 +227,6 @@ def test_database_execute_partitioned_dml_request_id(self): (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, exec_sql_seq, 1), ) ] - print(f"Filtered unary segments: {filtered_unary_segments}") - print(f"Want unary segments: {want_unary_segments}") - print(f"Got stream segments: {got_stream_segments}") - print(f"Want stream segments: {want_stream_segments}") assert all(seg in filtered_unary_segments for seg in want_unary_segments) assert got_stream_segments == want_stream_segments @@ -269,8 +265,6 @@ def test_unary_retryable_error(self): (1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, exec_sql_seq, 1), ) ] - print(f"Got stream segments: {got_stream_segments}") - print(f"Want stream segments: {want_stream_segments}") assert got_stream_segments == want_stream_segments def test_streaming_retryable_error(self): diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 7f4fb4c7f3..5fd2b74a17 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -59,7 +59,7 @@ def test_w_implicit(self, mock_client): self.assertIs(connection.database, database) instance.database.assert_called_once_with( - DATABASE, pool=None, database_role=None + DATABASE, pool=None, database_role=None, logger=None ) # Database constructs its own pool self.assertIsNotNone(connection.database._pool) @@ -107,7 +107,7 @@ def test_w_explicit(self, mock_client): self.assertIs(connection.database, database) instance.database.assert_called_once_with( - DATABASE, pool=pool, database_role=role + DATABASE, pool=pool, database_role=role, logger=None ) def test_w_credential_file_path(self, mock_client): diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 0bfab5bab9..6e8159425f 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -888,8 +888,9 @@ def database( pool=None, database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, database_role=None, + logger=None, ): - return _Database(database_id, pool, database_dialect, database_role) + return _Database(database_id, pool, database_dialect, database_role, logger) class _Database(object): @@ -899,8 +900,10 @@ def __init__( pool=None, database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, database_role=None, + logger=None, ): self.name = database_id self.pool = pool self.database_dialect = database_dialect self.database_role = database_role + self.logger = logger