|
41 | 41 | SpannerServicer, |
42 | 42 | start_mock_server, |
43 | 43 | ) |
| 44 | +from tests._helpers import is_multiplexed_enabled |
44 | 45 |
|
45 | 46 |
|
46 | 47 | # Creates an aborted status with the smallest possible retry delay. |
@@ -228,3 +229,109 @@ def database(self) -> Database: |
228 | 229 | enable_interceptors_in_tests=True, |
229 | 230 | ) |
230 | 231 | return self._database |
| 232 | + |
| 233 | + def assert_requests_sequence( |
| 234 | + self, |
| 235 | + requests, |
| 236 | + expected_types, |
| 237 | + transaction_type, |
| 238 | + allow_multiple_batch_create=True, |
| 239 | + ): |
| 240 | + """Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries. |
| 241 | +
|
| 242 | + Args: |
| 243 | + requests: List of requests from spanner_service.requests |
| 244 | + expected_types: List of expected request types (excluding session creation requests) |
| 245 | + transaction_type: TransactionType enum value to check multiplexed session status |
| 246 | + allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest |
| 247 | + """ |
| 248 | + from google.cloud.spanner_v1 import ( |
| 249 | + BatchCreateSessionsRequest, |
| 250 | + CreateSessionRequest, |
| 251 | + ) |
| 252 | + |
| 253 | + mux_enabled = is_multiplexed_enabled(transaction_type) |
| 254 | + idx = 0 |
| 255 | + # Skip all leading BatchCreateSessionsRequest (for retries) |
| 256 | + if allow_multiple_batch_create: |
| 257 | + while idx < len(requests) and isinstance( |
| 258 | + requests[idx], BatchCreateSessionsRequest |
| 259 | + ): |
| 260 | + idx += 1 |
| 261 | + # For multiplexed, optionally skip a CreateSessionRequest |
| 262 | + if ( |
| 263 | + mux_enabled |
| 264 | + and idx < len(requests) |
| 265 | + and isinstance(requests[idx], CreateSessionRequest) |
| 266 | + ): |
| 267 | + idx += 1 |
| 268 | + else: |
| 269 | + if mux_enabled: |
| 270 | + self.assertTrue( |
| 271 | + isinstance(requests[idx], BatchCreateSessionsRequest), |
| 272 | + f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}", |
| 273 | + ) |
| 274 | + idx += 1 |
| 275 | + self.assertTrue( |
| 276 | + isinstance(requests[idx], CreateSessionRequest), |
| 277 | + f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}", |
| 278 | + ) |
| 279 | + idx += 1 |
| 280 | + else: |
| 281 | + self.assertTrue( |
| 282 | + isinstance(requests[idx], BatchCreateSessionsRequest), |
| 283 | + f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}", |
| 284 | + ) |
| 285 | + idx += 1 |
| 286 | + # Check the rest of the expected request types |
| 287 | + for expected_type in expected_types: |
| 288 | + self.assertTrue( |
| 289 | + isinstance(requests[idx], expected_type), |
| 290 | + f"Expected {expected_type} at index {idx}, got {type(requests[idx])}", |
| 291 | + ) |
| 292 | + idx += 1 |
| 293 | + self.assertEqual( |
| 294 | + idx, len(requests), f"Expected {idx} requests, got {len(requests)}" |
| 295 | + ) |
| 296 | + |
| 297 | + def adjust_request_id_sequence(self, expected_segments, requests, transaction_type): |
| 298 | + """Adjust expected request ID sequence numbers based on actual session creation requests. |
| 299 | +
|
| 300 | + Args: |
| 301 | + expected_segments: List of expected (method, (sequence_numbers)) tuples |
| 302 | + requests: List of actual requests from spanner_service.requests |
| 303 | + transaction_type: TransactionType enum value to check multiplexed session status |
| 304 | +
|
| 305 | + Returns: |
| 306 | + List of adjusted expected segments with corrected sequence numbers |
| 307 | + """ |
| 308 | + from google.cloud.spanner_v1 import ( |
| 309 | + BatchCreateSessionsRequest, |
| 310 | + CreateSessionRequest, |
| 311 | + ExecuteSqlRequest, |
| 312 | + BeginTransactionRequest, |
| 313 | + ) |
| 314 | + |
| 315 | + # Count session creation requests that come before the first non-session request |
| 316 | + session_requests_before = 0 |
| 317 | + for req in requests: |
| 318 | + if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)): |
| 319 | + session_requests_before += 1 |
| 320 | + elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)): |
| 321 | + break |
| 322 | + |
| 323 | + # For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession) |
| 324 | + # For non-multiplexed, we expect 1 session request (BatchCreateSessions) |
| 325 | + mux_enabled = is_multiplexed_enabled(transaction_type) |
| 326 | + expected_session_requests = 2 if mux_enabled else 1 |
| 327 | + extra_session_requests = session_requests_before - expected_session_requests |
| 328 | + |
| 329 | + # Adjust sequence numbers based on extra session requests |
| 330 | + adjusted_segments = [] |
| 331 | + for method, seq_nums in expected_segments: |
| 332 | + # Adjust the sequence number (5th element in the tuple) |
| 333 | + adjusted_seq_nums = list(seq_nums) |
| 334 | + adjusted_seq_nums[4] += extra_session_requests |
| 335 | + adjusted_segments.append((method, tuple(adjusted_seq_nums))) |
| 336 | + |
| 337 | + return adjusted_segments |
0 commit comments