Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,9 @@ def connect(
instance = client.instance(instance_id)
database = None
if database_id:
logger = kwargs.get("logger")
database = instance.database(
database_id, pool=pool, database_role=database_role
database_id, pool=pool, database_role=database_role, logger=logger
)
conn = Connection(instance, database, **kwargs)
if pool is not None:
Expand Down
6 changes: 5 additions & 1 deletion tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import unittest

import grpc
Expand Down Expand Up @@ -170,12 +170,15 @@ class MockServerTestBase(unittest.TestCase):
spanner_service: SpannerServicer = None
database_admin_service: DatabaseAdminServicer = None
port: int = None
logger: logging.Logger = None

def __init__(self, *args, **kwargs):
super(MockServerTestBase, self).__init__(*args, **kwargs)
self._client = None
self._instance = None
self._database = None
self.logger = logging.getLogger("MockServerTestBase")
self.logger.setLevel(logging.WARN)

@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -227,6 +230,7 @@ def database(self) -> Database:
"test-database",
pool=FixedSizePool(size=10),
enable_interceptors_in_tests=True,
logger=self.logger,
)
return self._database

Expand Down
6 changes: 0 additions & 6 deletions tests/mockserver_tests/test_request_id_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,6 @@ def test_database_execute_partitioned_dml_request_id(self):
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, exec_sql_seq, 1),
)
]
print(f"Filtered unary segments: {filtered_unary_segments}")
print(f"Want unary segments: {want_unary_segments}")
print(f"Got stream segments: {got_stream_segments}")
print(f"Want stream segments: {want_stream_segments}")
assert all(seg in filtered_unary_segments for seg in want_unary_segments)
assert got_stream_segments == want_stream_segments

Expand Down Expand Up @@ -269,8 +265,6 @@ def test_unary_retryable_error(self):
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, exec_sql_seq, 1),
)
]
print(f"Got stream segments: {got_stream_segments}")
print(f"Want stream segments: {want_stream_segments}")
assert got_stream_segments == want_stream_segments

def test_streaming_retryable_error(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/spanner_dbapi/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_w_implicit(self, mock_client):

self.assertIs(connection.database, database)
instance.database.assert_called_once_with(
DATABASE, pool=None, database_role=None
DATABASE, pool=None, database_role=None, logger=None
)
# Database constructs its own pool
self.assertIsNotNone(connection.database._pool)
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_w_explicit(self, mock_client):

self.assertIs(connection.database, database)
instance.database.assert_called_once_with(
DATABASE, pool=pool, database_role=role
DATABASE, pool=pool, database_role=role, logger=None
)

def test_w_credential_file_path(self, mock_client):
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,9 @@ def database(
pool=None,
database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
database_role=None,
logger=None,
):
return _Database(database_id, pool, database_dialect, database_role)
return _Database(database_id, pool, database_dialect, database_role, logger)


class _Database(object):
Expand All @@ -899,8 +900,10 @@ def __init__(
pool=None,
database_dialect=DatabaseDialect.GOOGLE_STANDARD_SQL,
database_role=None,
logger=None,
):
self.name = database_id
self.pool = pool
self.database_dialect = database_dialect
self.database_role = database_role
self.logger = logger