1515import unittest
1616
1717from google .cloud .spanner_admin_database_v1 .types import spanner_database_admin
18+ from google .cloud .spanner_dbapi import Connection
19+ from google .cloud .spanner_dbapi .parsed_statement import AutocommitDmlMode
1820from google .cloud .spanner_v1 .testing .mock_database_admin import DatabaseAdminServicer
1921from google .cloud .spanner_v1 .testing .mock_spanner import (
2022 start_mock_server ,
2931 FixedSizePool ,
3032 BatchCreateSessionsRequest ,
3133 ExecuteSqlRequest ,
34+ BeginTransactionRequest ,
35+ TransactionOptions ,
3236)
3337from google .cloud .spanner_v1 .database import Database
3438from google .cloud .spanner_v1 .instance import Instance
@@ -62,6 +66,10 @@ def tearDownClass(cls):
6266 TestBasics .server .stop (grace = None )
6367 TestBasics .server = None
6468
69+ def teardown_method (self , * args , ** kwargs ):
70+ TestBasics .spanner_service .clear_requests ()
71+ TestBasics .database_admin_service .clear_requests ()
72+
6573 def _add_select1_result (self ):
6674 result = result_set .ResultSet (
6775 dict (
@@ -88,6 +96,19 @@ def _add_select1_result(self):
8896 result .rows .extend (["1" ])
8997 TestBasics .spanner_service .mock_spanner .add_result ("select 1" , result )
9098
99+ def add_update_count (
100+ self ,
101+ sql : str ,
102+ count : int ,
103+ dml_mode : AutocommitDmlMode = AutocommitDmlMode .TRANSACTIONAL ,
104+ ):
105+ if dml_mode == AutocommitDmlMode .PARTITIONED_NON_ATOMIC :
106+ stats = dict (row_count_lower_bound = count )
107+ else :
108+ stats = dict (row_count_exact = count )
109+ result = result_set .ResultSet (dict (stats = result_set .ResultSetStats (stats )))
110+ TestBasics .spanner_service .mock_spanner .add_result (sql , result )
111+
91112 @property
92113 def client (self ) -> Client :
93114 if self ._client is None :
@@ -145,3 +166,27 @@ def test_create_table(self):
145166 )
146167 operation = database_admin_api .update_database_ddl (request )
147168 operation .result (1 )
169+
170+ # TODO: Move this to a separate class once the mock server test setup has
171+ # been re-factored to use a base class for the boiler plate code.
172+ def test_dbapi_partitioned_dml (self ):
173+ sql = "UPDATE singers SET foo='bar' WHERE active = true"
174+ self .add_update_count (sql , 100 , AutocommitDmlMode .PARTITIONED_NON_ATOMIC )
175+ connection = Connection (self .instance , self .database )
176+ connection .autocommit = True
177+ connection .set_autocommit_dml_mode (AutocommitDmlMode .PARTITIONED_NON_ATOMIC )
178+ with connection .cursor () as cursor :
179+ # Note: SQLAlchemy uses [] as the list of parameters for statements
180+ # with no parameters.
181+ cursor .execute (sql , [])
182+ self .assertEqual (100 , cursor .rowcount )
183+
184+ requests = self .spanner_service .requests
185+ self .assertEqual (3 , len (requests ), msg = requests )
186+ self .assertTrue (isinstance (requests [0 ], BatchCreateSessionsRequest ))
187+ self .assertTrue (isinstance (requests [1 ], BeginTransactionRequest ))
188+ self .assertTrue (isinstance (requests [2 ], ExecuteSqlRequest ))
189+ begin_request : BeginTransactionRequest = requests [1 ]
190+ self .assertEqual (
191+ TransactionOptions (dict (partitioned_dml = {})), begin_request .options
192+ )
0 commit comments