Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import asyncio
import gc
import hashlib
import importlib
import inspect
import json
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import secrets
import signal
import socket
import tempfile
Expand Down Expand Up @@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
class AuthenticationMiddleware:
"""
Pure ASGI middleware that authenticates each request by checking
if the Authorization header exists and equals "Bearer {api_key}".
if the Authorization Bearer token exists and equals anyof "{api_key}".

Notes
-----
Expand All @@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:

def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app
self.api_tokens = {f"Bearer {token}" for token in tokens}
self.api_tokens = [
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
]

def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False

scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False

param_hash = hashlib.sha256(param.encode("utf-8")).digest()

token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)

return token_match

def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]:
Expand All @@ -1276,8 +1297,7 @@ def __call__(self, scope: Scope, receive: Receive,
url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope)
# Type narrow to satisfy mypy.
if url_path.startswith("/v1") and headers.get(
"Authorization") not in self.api_tokens:
if url_path.startswith("/v1") and not self.verify_token(headers):
response = JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return response(scope, receive, send)
Expand Down