diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8ba7e81ef5f..0deead1447aa 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -3,12 +3,14 @@ import asyncio import gc +import hashlib import importlib import inspect import json import multiprocessing import multiprocessing.forkserver as forkserver import os +import secrets import signal import socket import tempfile @@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: class AuthenticationMiddleware: """ Pure ASGI middleware that authenticates each request by checking - if the Authorization header exists and equals "Bearer {api_key}". + if the Authorization Bearer token exists and equals anyof "{api_key}". Notes ----- @@ -1263,7 +1265,26 @@ class AuthenticationMiddleware: def __init__(self, app: ASGIApp, tokens: list[str]) -> None: self.app = app - self.api_tokens = {f"Bearer {token}" for token in tokens} + self.api_tokens = [ + hashlib.sha256(t.encode("utf-8")).digest() for t in tokens + ] + + def verify_token(self, headers: Headers) -> bool: + authorization_header_value = headers.get("Authorization") + if not authorization_header_value: + return False + + scheme, _, param = authorization_header_value.partition(" ") + if scheme.lower() != "bearer": + return False + + param_hash = hashlib.sha256(param.encode("utf-8")).digest() + + token_match = False + for token_hash in self.api_tokens: + token_match |= secrets.compare_digest(param_hash, token_hash) + + return token_match def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: @@ -1276,8 +1297,7 @@ def __call__(self, scope: Scope, receive: Receive, url_path = URL(scope=scope).path.removeprefix(root_path) headers = Headers(scope=scope) # Type narrow to satisfy mypy. - if url_path.startswith("/v1") and headers.get( - "Authorization") not in self.api_tokens: + if url_path.startswith("/v1") and not self.verify_token(headers): response = JSONResponse(content={"error": "Unauthorized"}, status_code=401) return response(scope, receive, send)