Files
mcp-template-python/src/mcp_template_python/lib/better_mcp.py
Sun-ZhenXing 48f2efef7a feat: add /ws
2025-07-21 00:09:53 +08:00

181 lines
6.3 KiB
Python

import logging
from typing import Literal
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import (
BearerAuthBackend,
RequireAuthMiddleware,
)
from mcp.server.fastmcp import FastMCP
from mcp.server.websocket import websocket_server
from mcp.types import ToolAnnotations
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount, Route
from starlette.websockets import WebSocket
from ..config import settings
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class BetterFastMCP(FastMCP):
    """``FastMCP`` subclass that adds a WebSocket (``"ws"``) transport."""

    def run(
        self,
        transport: Literal["stdio", "sse", "streamable-http", "ws"] = "stdio",
        mount_path: str | None = None,
    ) -> None:
        """Run the server, blocking until it exits.

        Args:
            transport: Transport to serve on. ``"ws"`` is handled by this
                class; any other value is forwarded to ``FastMCP.run``.
            mount_path: Forwarded unchanged to ``FastMCP.run``. It is not
                used by the WebSocket transport.
        """
        import anyio

        if transport == "ws":
            anyio.run(self.run_ws_async)
        else:
            super().run(transport=transport, mount_path=mount_path)

    async def run_ws_async(self) -> None:
        """Run the server using WebSocket transport.

        Builds the Starlette app from :meth:`ws_app` and serves it with
        uvicorn, using the host, port and log level from ``self.settings``.
        """
        import uvicorn

        starlette_app = self.ws_app()

        config = uvicorn.Config(
            app=starlette_app,
            host=self.settings.host,
            port=self.settings.port,
            log_level=self.settings.log_level.lower(),
        )
        server = uvicorn.Server(config)
        await server.serve()

    def ws_app(self) -> Starlette:
        """Return an instance of the Websocket server app.

        The app exposes the MCP WebSocket endpoint at
        ``settings.websocket_path`` (project configuration), plus the auth
        routes, the protected-resource metadata route and any custom routes,
        depending on how ``self.settings.auth`` is configured.

        Returns:
            A Starlette application ready to be served or mounted.
        """
        import contextlib

        from starlette.routing import WebSocketRoute
        from starlette.types import Receive, Scope, Send

        async def handle_ws_asgi(scope: Scope, receive: Receive, send: Send) -> None:
            # Raw ASGI entry point: this is the calling convention that
            # RequireAuthMiddleware uses for the app it wraps.
            async with websocket_server(scope, receive, send) as (
                ws_read_stream,
                ws_write_stream,
            ):
                await self._mcp_server.run(
                    ws_read_stream,
                    ws_write_stream,
                    self._mcp_server.create_initialization_options(),
                    raise_exceptions=self.settings.debug,
                )

        async def handle_ws(websocket: WebSocket) -> None:
            # Function-style endpoint for WebSocketRoute, which passes a
            # single WebSocket object to plain functions.
            await handle_ws_asgi(websocket.scope, websocket.receive, websocket.send)

        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette):
            # The WebSocket transport does not need the streamable-HTTP
            # session manager, and the `session_manager` property raises
            # RuntimeError when streamable_http_app() was never called.
            # Only enter it when it already exists, so a plain `run("ws")`
            # can start up.
            if getattr(self, "_session_manager", None) is None:
                yield
            else:
                async with self.session_manager.run():
                    yield

        # Create routes
        routes: list[Route | Mount | WebSocketRoute] = []
        middleware: list[Middleware] = []
        required_scopes = []

        # Set up auth if configured
        if self.settings.auth:
            required_scopes = self.settings.auth.required_scopes or []

            # Add auth middleware if token verifier is available
            if self._token_verifier:
                middleware = [
                    Middleware(
                        AuthenticationMiddleware,
                        backend=BearerAuthBackend(self._token_verifier),
                    ),
                    Middleware(AuthContextMiddleware),
                ]

            # Add auth endpoints if auth server provider is configured
            if self._auth_server_provider:
                from mcp.server.auth.routes import create_auth_routes

                routes.extend(
                    create_auth_routes(
                        provider=self._auth_server_provider,
                        issuer_url=self.settings.auth.issuer_url,
                        service_documentation_url=self.settings.auth.service_documentation_url,
                        client_registration_options=self.settings.auth.client_registration_options,
                        revocation_options=self.settings.auth.revocation_options,
                    )
                )

        # Set up routes with or without auth.
        # WebSocketRoute is required here: a plain Route only matches HTTP
        # requests, so WebSocket handshakes would never reach the handler.
        if self._token_verifier:
            # Determine resource metadata URL
            resource_metadata_url = None
            if self.settings.auth and self.settings.auth.resource_server_url:
                from pydantic import AnyHttpUrl

                resource_metadata_url = AnyHttpUrl(
                    str(self.settings.auth.resource_server_url).rstrip("/")
                    + "/.well-known/oauth-protected-resource"
                )

            # The middleware instance is not a plain function, so
            # WebSocketRoute treats it as an ASGI app; it must therefore wrap
            # the ASGI-style handler, not the WebSocket-object one.
            # NOTE(review): RequireAuthMiddleware was written for HTTP
            # endpoints - confirm how its rejection response behaves during a
            # WebSocket handshake on the ASGI server in use.
            routes.append(
                WebSocketRoute(
                    settings.websocket_path,
                    endpoint=RequireAuthMiddleware(
                        handle_ws_asgi, required_scopes, resource_metadata_url
                    ),
                )
            )
        else:
            # Auth is disabled, no wrapper needed
            routes.append(
                WebSocketRoute(
                    settings.websocket_path,
                    endpoint=handle_ws,
                )
            )

        # Add protected resource metadata endpoint if configured as RS
        if self.settings.auth and self.settings.auth.resource_server_url:
            from mcp.server.auth.handlers.metadata import (
                ProtectedResourceMetadataHandler,
            )
            from mcp.server.auth.routes import cors_middleware
            from mcp.shared.auth import ProtectedResourceMetadata

            protected_resource_metadata = ProtectedResourceMetadata(
                resource=self.settings.auth.resource_server_url,
                authorization_servers=[self.settings.auth.issuer_url],
                scopes_supported=self.settings.auth.required_scopes,
            )
            routes.append(
                Route(
                    "/.well-known/oauth-protected-resource",
                    endpoint=cors_middleware(
                        ProtectedResourceMetadataHandler(
                            protected_resource_metadata
                        ).handle,
                        ["GET", "OPTIONS"],
                    ),
                    methods=["GET", "OPTIONS"],
                )
            )

        routes.extend(self._custom_starlette_routes)

        return Starlette(
            debug=self.settings.debug,
            routes=routes,
            middleware=middleware,
            lifespan=lifespan,
        )

    def better_tool(
        self,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        annotations: ToolAnnotations | None = None,
        structured_output: bool | None = None,
    ):
        """Decorator to register a tool.

        Currently a thin alias for ``self.tool(...)``: every argument is
        forwarded unchanged, so ``@mcp.better_tool()`` is usable as a
        decorator instead of failing because the factory returned ``None``.

        TODO: Implement a better tool function decorator (e.g. custom
        handling when a tool with the same name is already registered).

        Returns:
            The decorator produced by ``self.tool``.
        """
        return self.tool(
            name=name,
            title=title,
            description=description,
            annotations=annotations,
            structured_output=structured_output,
        )