|
6 | 6 | import urllib |
7 | 7 | from typing import Any, Dict |
8 | 8 |
|
9 | | -from asgi_correlation_id import CorrelationIdMiddleware, correlation_id # noqa |
10 | 9 | from asgiref.typing import ( |
11 | 10 | ASGI3Application, |
12 | 11 | ASGIReceiveCallable, |
|
15 | 14 | HTTPScope, |
16 | 15 | ) |
17 | 16 |
|
18 | | -from ..logging import JsonLogFormatter |
| 17 | +from ..logging import JsonLogFormatter, get_or_generate_request_id, request_id_context |
| 18 | + |
| 19 | + |
| 20 | +class RequestIdMiddleware: |
| 21 | + def __init__( |
| 22 | + self, |
| 23 | + app: ASGI3Application, |
| 24 | + ) -> None: |
| 25 | + self.app = app |
| 26 | + |
| 27 | + async def __call__( |
| 28 | + self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable |
| 29 | + ) -> None: |
| 30 | + if scope["type"] != "http": |
| 31 | + return await self.app(scope, receive, send) |
| 32 | + |
| 33 | + headers = {} |
| 34 | + for name, value in scope["headers"]: |
| 35 | + header_key = name.decode("latin1").lower() |
| 36 | + header_val = value.decode("latin1") |
| 37 | + headers[header_key] = header_val |
| 38 | + |
| 39 | + rid = get_or_generate_request_id( |
| 40 | + headers, |
| 41 | + header_name=getattr( |
| 42 | + scope["app"].state, "DOCKERFLOW_REQUEST_ID_HEADER_NAME", None |
| 43 | + ), |
| 44 | + ) |
| 45 | + request_id_context.set(rid) |
| 46 | + |
| 47 | + await self.app(scope, receive, send) |
19 | 48 |
|
20 | 49 |
|
21 | 50 | class MozlogRequestSummaryLogger: |
@@ -75,7 +104,7 @@ def _format(self, scope: HTTPScope, info) -> Dict[str, Any]: |
75 | 104 | "code": info["response"]["status"], |
76 | 105 | "lang": info["request_headers"].get("accept-language"), |
77 | 106 | "t": int(request_duration_ms), |
78 | | - "rid": correlation_id.get(), |
| 107 | + "rid": request_id_context.get(), |
79 | 108 | } |
80 | 109 |
|
81 | 110 | if getattr(scope["app"].state, "DOCKERFLOW_SUMMARY_LOG_QUERYSTRING", False): |
|
0 commit comments