|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | from collections import defaultdict |
| 16 | +import threading |
| 17 | + |
16 | 18 | from grpc_interceptor import ClientInterceptor |
17 | 19 | from google.api_core.exceptions import Aborted |
18 | 20 |
|
@@ -63,3 +65,30 @@ def reset(self): |
63 | 65 | self._method_to_abort = None |
64 | 66 | self._count = 0 |
65 | 67 | self._connection = None |
| 68 | + |
| 69 | + |
| 70 | +class XGoogRequestIDHeaderInterceptor(ClientInterceptor): |
| 71 | + def __init__(self): |
| 72 | + self._unary_req_segments = [] |
| 73 | + self._stream_req_segments = [] |
| 74 | + self.__lock = threading.Lock() |
| 75 | + |
| 76 | + def intercept(self, method, request_or_iterator, call_details): |
| 77 | + metadata = call_details.metadata |
| 78 | + x_goog_request_id = None |
| 79 | + for key, value in metadata: |
| 80 | + if key == "x-goog-spanner-request-id": |
| 81 | + x_goog_request_id = value |
| 82 | + break |
| 83 | + |
| 84 | + if not x_goog_request_id: |
| 85 | + raise Exception(f"Missing {x_goog_request_id}") |
| 86 | + |
| 87 | + streaming = hasattr(request_or_iterator, "__iter__", False) |
| 88 | + with self.__lock: |
| 89 | + if streaming: |
| 90 | + self._stream_req_segments.append(x_goog_request_id) |
| 91 | + else: |
| 92 | + self._unary_req_segments.append(x_goog_request_id) |
| 93 | + |
| 94 | + return method(request_or_iterator, call_details) |
0 commit comments