3
3
4
4
import asyncio
5
5
import gc
6
+ import hashlib
6
7
import importlib
7
8
import inspect
8
9
import json
9
10
import multiprocessing
10
11
import multiprocessing .forkserver as forkserver
11
12
import os
13
+ import secrets
12
14
import signal
13
15
import socket
14
16
import tempfile
@@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
1252
1254
class AuthenticationMiddleware :
1253
1255
"""
1254
1256
Pure ASGI middleware that authenticates each request by checking
1255
- if the Authorization header exists and equals "Bearer {api_key}".
1257
+ if the Authorization Bearer token exists and equals anyof " {api_key}".
1256
1258
1257
1259
Notes
1258
1260
-----
@@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
1263
1265
1264
1266
def __init__ (self , app : ASGIApp , tokens : list [str ]) -> None :
1265
1267
self .app = app
1266
- self .api_tokens = {f"Bearer { token } " for token in tokens }
1268
+ self .api_tokens = [
1269
+ hashlib .sha256 (t .encode ("utf-8" )).digest () for t in tokens
1270
+ ]
1271
+
1272
+ def verify_token (self , headers : Headers ) -> bool :
1273
+ authorization_header_value = headers .get ("Authorization" )
1274
+ if not authorization_header_value :
1275
+ return False
1276
+
1277
+ scheme , _ , param = authorization_header_value .partition (" " )
1278
+ if scheme .lower () != "bearer" :
1279
+ return False
1280
+
1281
+ param_hash = hashlib .sha256 (param .encode ("utf-8" )).digest ()
1282
+
1283
+ token_match = False
1284
+ for token_hash in self .api_tokens :
1285
+ token_match |= secrets .compare_digest (param_hash , token_hash )
1286
+
1287
+ return token_match
1267
1288
1268
1289
def __call__ (self , scope : Scope , receive : Receive ,
1269
1290
send : Send ) -> Awaitable [None ]:
@@ -1276,8 +1297,7 @@ def __call__(self, scope: Scope, receive: Receive,
1276
1297
url_path = URL (scope = scope ).path .removeprefix (root_path )
1277
1298
headers = Headers (scope = scope )
1278
1299
# Type narrow to satisfy mypy.
1279
- if url_path .startswith ("/v1" ) and headers .get (
1280
- "Authorization" ) not in self .api_tokens :
1300
+ if url_path .startswith ("/v1" ) and not self .verify_token (headers ):
1281
1301
response = JSONResponse (content = {"error" : "Unauthorized" },
1282
1302
status_code = 401 )
1283
1303
return response (scope , receive , send )
0 commit comments