import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Literal

from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import (
    BearerAuthBackend,
    RequireAuthMiddleware,
)
from mcp.server.fastmcp import FastMCP
from mcp.server.websocket import websocket_server
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket

from ..config import settings

logger = logging.getLogger(__name__)


class BetterFastMCP(FastMCP):
    """FastMCP server that adds a ``"ws"`` (WebSocket) transport option."""

    def run(
        self,
        transport: Literal["stdio", "sse", "streamable-http", "ws"] = "stdio",
        mount_path: str | None = None,
    ) -> None:
        """Run the server on the selected transport.

        ``"ws"`` is served by :meth:`run_ws_async`; every other transport is
        delegated unchanged to ``FastMCP.run``.

        Args:
            transport: Transport to serve on.
            mount_path: Forwarded to ``FastMCP.run`` for non-WebSocket
                transports; unused for ``"ws"``.
        """
        import anyio

        if transport == "ws":
            anyio.run(self.run_ws_async)
        else:
            super().run(transport=transport, mount_path=mount_path)

    async def run_ws_async(self) -> None:
        """Run the server using WebSocket transport."""
        import uvicorn

        starlette_app = self.ws_app()

        config = uvicorn.Config(
            app=starlette_app,
            host=self.settings.host,
            port=self.settings.port,
            log_level=self.settings.log_level.lower(),
        )
        server = uvicorn.Server(config)
        await server.serve()

    def ws_app(self) -> Starlette:
        """Return an instance of the WebSocket server app.

        The MCP WebSocket endpoint is mounted at ``settings.websocket_path``
        (module-level project config). When auth is configured, the bearer
        authentication middleware, the auth-server routes and the
        protected-resource metadata route are added as well.
        """

        async def handle_ws_asgi(scope: Scope, receive: Receive, send: Send) -> None:
            # Raw ASGI entry point: this is the call shape an ASGI wrapper
            # such as RequireAuthMiddleware uses for the app it wraps.
            async with websocket_server(scope, receive, send) as (
                ws_read_stream,
                ws_write_stream,
            ):
                await self._mcp_server.run(
                    ws_read_stream,
                    ws_write_stream,
                    self._mcp_server.create_initialization_options(),
                    raise_exceptions=self.settings.debug,
                )

        async def handle_ws(websocket: WebSocket) -> None:
            # Adapter for WebSocketRoute, which calls plain functions with a
            # WebSocket session object rather than (scope, receive, send).
            await handle_ws_asgi(websocket.scope, websocket.receive, websocket.send)

        # Create routes
        routes: list[Route | Mount | WebSocketRoute] = []
        middleware: list[Middleware] = []
        required_scopes = []

        # Set up auth if configured
        if self.settings.auth:
            required_scopes = self.settings.auth.required_scopes or []

            # Add auth middleware if token verifier is available
            if self._token_verifier:
                middleware = [
                    Middleware(
                        AuthenticationMiddleware,
                        backend=BearerAuthBackend(self._token_verifier),
                    ),
                    Middleware(AuthContextMiddleware),
                ]

            # Add auth endpoints if auth server provider is configured
            if self._auth_server_provider:
                from mcp.server.auth.routes import create_auth_routes

                routes.extend(
                    create_auth_routes(
                        provider=self._auth_server_provider,
                        issuer_url=self.settings.auth.issuer_url,
                        service_documentation_url=self.settings.auth.service_documentation_url,
                        client_registration_options=self.settings.auth.client_registration_options,
                        revocation_options=self.settings.auth.revocation_options,
                    )
                )

        # Set up routes with or without auth.
        # The endpoint must be a WebSocketRoute: a plain Route only matches
        # HTTP scopes, so a WebSocket handshake would never reach the handler.
        if self._token_verifier:
            # Determine resource metadata URL
            resource_metadata_url = None
            if self.settings.auth and self.settings.auth.resource_server_url:
                from pydantic import AnyHttpUrl

                resource_metadata_url = AnyHttpUrl(
                    str(self.settings.auth.resource_server_url).rstrip("/")
                    + "/.well-known/oauth-protected-resource"
                )

            # NOTE(review): RequireAuthMiddleware appears to reject with an
            # HTTP-style response; confirm that this yields a clean refusal
            # of the handshake on a WebSocket scope.
            routes.append(
                WebSocketRoute(
                    settings.websocket_path,
                    endpoint=RequireAuthMiddleware(
                        handle_ws_asgi, required_scopes, resource_metadata_url
                    ),
                )
            )
        else:
            # Auth is disabled, no wrapper needed
            routes.append(
                WebSocketRoute(
                    settings.websocket_path,
                    endpoint=handle_ws,
                )
            )

        # Add protected resource metadata endpoint if configured as RS
        if self.settings.auth and self.settings.auth.resource_server_url:
            from mcp.server.auth.handlers.metadata import (
                ProtectedResourceMetadataHandler,
            )
            from mcp.server.auth.routes import cors_middleware
            from mcp.shared.auth import ProtectedResourceMetadata

            protected_resource_metadata = ProtectedResourceMetadata(
                resource=self.settings.auth.resource_server_url,
                authorization_servers=[self.settings.auth.issuer_url],
                scopes_supported=self.settings.auth.required_scopes,
            )
            routes.append(
                Route(
                    "/.well-known/oauth-protected-resource",
                    endpoint=cors_middleware(
                        ProtectedResourceMetadataHandler(
                            protected_resource_metadata
                        ).handle,
                        ["GET", "OPTIONS"],
                    ),
                    methods=["GET", "OPTIONS"],
                )
            )

        routes.extend(self._custom_starlette_routes)

        @asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            # The WebSocket transport itself does not use the streamable-HTTP
            # session manager. Presumably FastMCP raises RuntimeError from
            # `session_manager` when streamable_http_app() was never called
            # (verify against the installed MCP SDK); in that case start
            # without it instead of failing application startup.
            try:
                manager = self.session_manager
            except RuntimeError:
                yield
                return
            async with manager.run():
                yield

        return Starlette(
            debug=self.settings.debug,
            routes=routes,
            middleware=middleware,
            lifespan=lifespan,
        )