Skip to content

Commit 3f5d902

Browse files
russellbNiuBlibing
andauthored
Validate API tokens in constant time (#25781)
Signed-off-by: rentianyue-jk <[email protected]> Signed-off-by: Russell Bryant <[email protected]> Co-authored-by: rentianyue-jk <[email protected]>
1 parent 27d7638 commit 3f5d902

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

vllm/entrypoints/openai/api_server.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
import asyncio
55
import gc
6+
import hashlib
67
import importlib
78
import inspect
89
import json
910
import multiprocessing
1011
import multiprocessing.forkserver as forkserver
1112
import os
13+
import secrets
1214
import signal
1315
import socket
1416
import tempfile
@@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
12521254
class AuthenticationMiddleware:
12531255
"""
12541256
Pure ASGI middleware that authenticates each request by checking
1255-
if the Authorization header exists and equals "Bearer {api_key}".
1257+
if the Authorization Bearer token exists and equals anyof "{api_key}".
12561258
12571259
Notes
12581260
-----
@@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
12631265

12641266
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
12651267
self.app = app
1266-
self.api_tokens = {f"Bearer {token}" for token in tokens}
1268+
self.api_tokens = [
1269+
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
1270+
]
1271+
1272+
def verify_token(self, headers: Headers) -> bool:
1273+
authorization_header_value = headers.get("Authorization")
1274+
if not authorization_header_value:
1275+
return False
1276+
1277+
scheme, _, param = authorization_header_value.partition(" ")
1278+
if scheme.lower() != "bearer":
1279+
return False
1280+
1281+
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
1282+
1283+
token_match = False
1284+
for token_hash in self.api_tokens:
1285+
token_match |= secrets.compare_digest(param_hash, token_hash)
1286+
1287+
return token_match
12671288

12681289
def __call__(self, scope: Scope, receive: Receive,
12691290
send: Send) -> Awaitable[None]:
@@ -1276,8 +1297,7 @@ def __call__(self, scope: Scope, receive: Receive,
12761297
url_path = URL(scope=scope).path.removeprefix(root_path)
12771298
headers = Headers(scope=scope)
12781299
# Type narrow to satisfy mypy.
1279-
if url_path.startswith("/v1") and headers.get(
1280-
"Authorization") not in self.api_tokens:
1300+
if url_path.startswith("/v1") and not self.verify_token(headers):
12811301
response = JSONResponse(content={"error": "Unauthorized"},
12821302
status_code=401)
12831303
return response(scope, receive, send)

0 commit comments

Comments
 (0)