diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index e39b4ef1e..e5aac0efc 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -113,6 +113,148 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse): }, ) + async def _handle_authorization_code( + self, client_info: Any, token_request: AuthorizationCodeRequest + ) -> TokenSuccessResponse | TokenErrorResponse: + auth_code = await self.provider.load_authorization_code(client_info, token_request.code) + if auth_code is None or auth_code.client_id != token_request.client_id: + # if code belongs to different client, pretend it doesn't exist + return TokenErrorResponse( + error="invalid_grant", + error_description="authorization code does not exist", + ) + + # make auth codes expire after a deadline + # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 + if auth_code.expires_at < time.time(): + return TokenErrorResponse( + error="invalid_grant", + error_description="authorization code has expired", + ) + + # verify redirect_uri doesn't change between /authorize and /tokens + # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 + if auth_code.redirect_uri_provided_explicitly: + authorize_request_redirect_uri = auth_code.redirect_uri + else: + authorize_request_redirect_uri = None + + # Convert both sides to strings for comparison to handle AnyUrl vs string issues + token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None + auth_redirect_str = ( + str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None + ) + + if token_redirect_str != auth_redirect_str: + return TokenErrorResponse( + error="invalid_request", + error_description=("redirect_uri did not match the one used when creating auth code"), + ) + + # Verify PKCE code verifier + sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() + hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + + if hashed_code_verifier != auth_code.code_challenge: + # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 + return TokenErrorResponse( + error="invalid_grant", + error_description="incorrect code_verifier", + ) + + try: + # Exchange authorization code for tokens + tokens = await self.provider.exchange_authorization_code(client_info, auth_code) + except TokenError as e: + return TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + + return TokenSuccessResponse(root=tokens) + + async def _handle_client_credentials( + self, client_info: Any, token_request: ClientCredentialsRequest + ) -> TokenSuccessResponse | TokenErrorResponse: + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else client_info.scope.split(" ") + if client_info.scope + else [] + ) + try: + tokens = await self.provider.exchange_client_credentials(client_info, scopes) + except TokenError as e: + return TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + + return TokenSuccessResponse(root=tokens) + + async def _handle_token_exchange( + self, client_info: Any, token_request: TokenExchangeRequest + ) -> TokenSuccessResponse | TokenErrorResponse: + scopes = token_request.scope.split(" ") if token_request.scope else [] + try: + tokens = await self.provider.exchange_token( + client_info, + token_request.subject_token, + token_request.subject_token_type, + token_request.actor_token, + token_request.actor_token_type, + scopes, + token_request.audience, + token_request.resource, + ) + except TokenError as e: + return TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + + return TokenSuccessResponse(root=tokens) + + async def _handle_refresh_token( + self, client_info: Any, token_request: RefreshTokenRequest + ) -> TokenSuccessResponse | TokenErrorResponse: + refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) + if refresh_token is None or refresh_token.client_id != token_request.client_id: + # if token belongs to a different client, pretend it doesn't exist + return TokenErrorResponse( + error="invalid_grant", + error_description="refresh token does not exist", + ) + + if refresh_token.expires_at and refresh_token.expires_at < time.time(): + # if the refresh token has expired, pretend it doesn't exist + return TokenErrorResponse( + error="invalid_grant", + error_description="refresh token has expired", + ) + + # Parse scopes if provided + scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes + + for scope in scopes: + if scope not in refresh_token.scopes: + return TokenErrorResponse( + error="invalid_scope", + error_description=(f"cannot request scope `{scope}` not provided by refresh token"), + ) + + try: + # Exchange refresh token for new tokens + tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) + except TokenError as e: + return TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + + return TokenSuccessResponse(root=tokens) + async def handle(self, request: Request): try: form_data = await request.form() @@ -146,155 +288,17 @@ async def handle(self, request: Request): ) ) - tokens: OAuthToken - match token_request: case AuthorizationCodeRequest(): - auth_code = await self.provider.load_authorization_code(client_info, token_request.code) - if auth_code is None or auth_code.client_id != token_request.client_id: - # if code belongs to different client, pretend it doesn't exist - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="authorization code does not exist", - ) - ) - - # make auth codes expire after a deadline - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 - if auth_code.expires_at < time.time(): - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="authorization code has expired", - ) - ) - - # verify redirect_uri doesn't change between /authorize and /tokens - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 - if auth_code.redirect_uri_provided_explicitly: - authorize_request_redirect_uri = auth_code.redirect_uri - else: - authorize_request_redirect_uri = None - - # Convert both sides to strings for comparison to handle AnyUrl vs string issues - token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None - auth_redirect_str = ( - str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None - ) - - if token_redirect_str != auth_redirect_str: - return self.response( - TokenErrorResponse( - error="invalid_request", - error_description=("redirect_uri did not match the one used when creating auth code"), - ) - ) - - # Verify PKCE code verifier - sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") - - if hashed_code_verifier != auth_code.code_challenge: - # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="incorrect code_verifier", - ) - ) - - try: - # Exchange authorization code for tokens - tokens = await self.provider.exchange_authorization_code(client_info, auth_code) - except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + result = await self._handle_authorization_code(client_info, token_request) case ClientCredentialsRequest(): - scopes = ( - token_request.scope.split(" ") - if token_request.scope - else client_info.scope.split(" ") - if client_info.scope - else [] - ) - try: - tokens = await self.provider.exchange_client_credentials(client_info, scopes) - except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + result = await self._handle_client_credentials(client_info, token_request) case TokenExchangeRequest(): - scopes = token_request.scope.split(" ") if token_request.scope else [] - try: - tokens = await self.provider.exchange_token( - client_info, - token_request.subject_token, - token_request.subject_token_type, - token_request.actor_token, - token_request.actor_token_type, - scopes, - token_request.audience, - token_request.resource, - ) - except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) + result = await self._handle_token_exchange(client_info, token_request) case RefreshTokenRequest(): - refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) - if refresh_token is None or refresh_token.client_id != token_request.client_id: - # if token belongs to a different client, pretend it doesn't exist - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="refresh token does not exist", - ) - ) - - if refresh_token.expires_at and refresh_token.expires_at < time.time(): - # if the refresh token has expired, pretend it doesn't exist - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="refresh token has expired", - ) - ) - - # Parse scopes if provided - scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes - - for scope in scopes: - if scope not in refresh_token.scopes: - return self.response( - TokenErrorResponse( - error="invalid_scope", - error_description=(f"cannot request scope `{scope}` not provided by refresh token"), - ) - ) - - try: - # Exchange refresh token for new tokens - tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) - except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) - - return self.response(TokenSuccessResponse(root=tokens)) + result = await self._handle_refresh_token(client_info, token_request) + + return self.response(result)