Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 148 additions & 144 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,148 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
},
)

async def _handle_authorization_code(
self, client_info: Any, token_request: AuthorizationCodeRequest
) -> TokenSuccessResponse | TokenErrorResponse:
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
if auth_code is None or auth_code.client_id != token_request.client_id:
# if code belongs to different client, pretend it doesn't exist
return TokenErrorResponse(
error="invalid_grant",
error_description="authorization code does not exist",
)

# make auth codes expire after a deadline
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
if auth_code.expires_at < time.time():
return TokenErrorResponse(
error="invalid_grant",
error_description="authorization code has expired",
)

# verify redirect_uri doesn't change between /authorize and /tokens
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
if auth_code.redirect_uri_provided_explicitly:
authorize_request_redirect_uri = auth_code.redirect_uri
else:
authorize_request_redirect_uri = None

# Convert both sides to strings for comparison to handle AnyUrl vs string issues
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
auth_redirect_str = (
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
)

if token_redirect_str != auth_redirect_str:
return TokenErrorResponse(
error="invalid_request",
error_description=("redirect_uri did not match the one used when creating auth code"),
)

# Verify PKCE code verifier
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")

if hashed_code_verifier != auth_code.code_challenge:
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
return TokenErrorResponse(
error="invalid_grant",
error_description="incorrect code_verifier",
)

try:
# Exchange authorization code for tokens
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
except TokenError as e:
return TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)

return TokenSuccessResponse(root=tokens)

async def _handle_client_credentials(
self, client_info: Any, token_request: ClientCredentialsRequest
) -> TokenSuccessResponse | TokenErrorResponse:
scopes = (
token_request.scope.split(" ")
if token_request.scope
else client_info.scope.split(" ")
if client_info.scope
else []
)
try:
tokens = await self.provider.exchange_client_credentials(client_info, scopes)
except TokenError as e:
return TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)

return TokenSuccessResponse(root=tokens)

async def _handle_token_exchange(
self, client_info: Any, token_request: TokenExchangeRequest
) -> TokenSuccessResponse | TokenErrorResponse:
scopes = token_request.scope.split(" ") if token_request.scope else []
try:
tokens = await self.provider.exchange_token(
client_info,
token_request.subject_token,
token_request.subject_token_type,
token_request.actor_token,
token_request.actor_token_type,
scopes,
token_request.audience,
token_request.resource,
)
except TokenError as e:
return TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)

return TokenSuccessResponse(root=tokens)

async def _handle_refresh_token(
self, client_info: Any, token_request: RefreshTokenRequest
) -> TokenSuccessResponse | TokenErrorResponse:
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
if refresh_token is None or refresh_token.client_id != token_request.client_id:
# if token belongs to a different client, pretend it doesn't exist
return TokenErrorResponse(
error="invalid_grant",
error_description="refresh token does not exist",
)

if refresh_token.expires_at and refresh_token.expires_at < time.time():
# if the refresh token has expired, pretend it doesn't exist
return TokenErrorResponse(
error="invalid_grant",
error_description="refresh token has expired",
)

# Parse scopes if provided
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes

for scope in scopes:
if scope not in refresh_token.scopes:
return TokenErrorResponse(
error="invalid_scope",
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
)

try:
# Exchange refresh token for new tokens
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
except TokenError as e:
return TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)

return TokenSuccessResponse(root=tokens)

async def handle(self, request: Request):
try:
form_data = await request.form()
Expand Down Expand Up @@ -146,155 +288,17 @@ async def handle(self, request: Request):
)
)

tokens: OAuthToken

match token_request:
case AuthorizationCodeRequest():
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
if auth_code is None or auth_code.client_id != token_request.client_id:
# if code belongs to different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="authorization code does not exist",
)
)

# make auth codes expire after a deadline
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
if auth_code.expires_at < time.time():
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="authorization code has expired",
)
)

# verify redirect_uri doesn't change between /authorize and /tokens
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
if auth_code.redirect_uri_provided_explicitly:
authorize_request_redirect_uri = auth_code.redirect_uri
else:
authorize_request_redirect_uri = None

# Convert both sides to strings for comparison to handle AnyUrl vs string issues
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
auth_redirect_str = (
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
)

if token_redirect_str != auth_redirect_str:
return self.response(
TokenErrorResponse(
error="invalid_request",
error_description=("redirect_uri did not match the one used when creating auth code"),
)
)

# Verify PKCE code verifier
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")

if hashed_code_verifier != auth_code.code_challenge:
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="incorrect code_verifier",
)
)

try:
# Exchange authorization code for tokens
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)
result = await self._handle_authorization_code(client_info, token_request)

case ClientCredentialsRequest():
scopes = (
token_request.scope.split(" ")
if token_request.scope
else client_info.scope.split(" ")
if client_info.scope
else []
)
try:
tokens = await self.provider.exchange_client_credentials(client_info, scopes)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)
result = await self._handle_client_credentials(client_info, token_request)

case TokenExchangeRequest():
scopes = token_request.scope.split(" ") if token_request.scope else []
try:
tokens = await self.provider.exchange_token(
client_info,
token_request.subject_token,
token_request.subject_token_type,
token_request.actor_token,
token_request.actor_token_type,
scopes,
token_request.audience,
token_request.resource,
)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)
result = await self._handle_token_exchange(client_info, token_request)

case RefreshTokenRequest():
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
if refresh_token is None or refresh_token.client_id != token_request.client_id:
# if token belongs to a different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="refresh token does not exist",
)
)

if refresh_token.expires_at and refresh_token.expires_at < time.time():
# if the refresh token has expired, pretend it doesn't exist
return self.response(
TokenErrorResponse(
error="invalid_grant",
error_description="refresh token has expired",
)
)

# Parse scopes if provided
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes

for scope in scopes:
if scope not in refresh_token.scopes:
return self.response(
TokenErrorResponse(
error="invalid_scope",
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
)
)

try:
# Exchange refresh token for new tokens
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
except TokenError as e:
return self.response(
TokenErrorResponse(
error=e.error,
error_description=e.error_description,
)
)

return self.response(TokenSuccessResponse(root=tokens))
result = await self._handle_refresh_token(client_info, token_request)

return self.response(result)