feat: add /ws

This commit is contained in:
Sun-ZhenXing
2025-07-21 00:09:53 +08:00
parent bcb5994356
commit 48f2efef7a
10 changed files with 347 additions and 51 deletions

View File

@@ -2,11 +2,12 @@ import argparse
import sys
from .__about__ import __module_name__, __version__
from .app import MCP_MAP
from .config import settings
def main():
from .app import MCP_MAP
parser = argparse.ArgumentParser(description="MCP Server")
parser.add_argument(
@@ -14,6 +15,13 @@ def main():
action="store_true",
help="Run the server with STDIO (default: False)",
)
parser.add_argument(
"--mcp",
type=str,
default=settings.default_mcp,
choices=list(MCP_MAP.keys()),
help=f"Select the MCP to run in STDIO mode (default: {settings.default_mcp})",
)
parser.add_argument(
"--host",
default=settings.default_host,
@@ -44,7 +52,10 @@ def main():
sys.exit(0)
if args.stdio:
mcp = MCP_MAP[settings.default_mcp]
mcp = MCP_MAP.get(args.mcp)
if mcp is None:
print(f"Error: MCP '{args.mcp}' not found.")
sys.exit(1)
mcp.run()
else:
import uvicorn

View File

@@ -1,39 +1,31 @@
from operator import add, mul, sub, truediv
from mcp.server.fastmcp import FastMCP
from mcp_template_python.lib.better_mcp import BetterFastMCP
from ..config import settings
mcp = FastMCP("math", settings=settings.instructions)
mcp = BetterFastMCP("math", settings=settings.instructions)
@mcp.tool()
async def add_nums(a: float, b: float) -> float:
"""
Adds two numbers.
"""
async def add_num(a: float, b: float) -> float:
"""Adds two numbers."""
return add(a, b)
@mcp.tool()
async def sub_nums(a: float, b: float) -> float:
"""
Subtracts the second number from the first.
"""
async def sub_num(a: float, b: float) -> float:
"""Subtracts the second number from the first."""
return sub(a, b)
@mcp.tool()
async def mul_nums(a: float, b: float) -> float:
"""
Multiplies two numbers.
"""
async def mul_num(a: float, b: float) -> float:
"""Multiplies two numbers."""
return mul(a, b)
@mcp.tool()
async def div_nums(a: float, b: float) -> float:
"""
Divides the first number by the second.
"""
async def div_num(a: float, b: float) -> float:
"""Divides the first number by the second."""
return truediv(a, b)

View File

@@ -6,12 +6,6 @@ class Settings(BaseSettings):
Configuration settings for the MCP template application.
"""
default_mcp: str = "math"
default_host: str = "127.0.0.1"
default_port: int = 3001
instructions: str | None = None
model_config = SettingsConfigDict(
env_prefix="MCP_",
env_file=".env",
@@ -19,5 +13,38 @@ class Settings(BaseSettings):
extra="allow",
)
app_title: str = "MCP Template Application"
"""Title of the MCP application, defaults to 'MCP Template Application'."""
app_description: str = "A template application for MCP using FastAPI."
"""Description of the MCP application, defaults to 'A template application for MCP using FastAPI.'"""
default_mcp: str = "math"
"""Default MCP to be used by the application."""
default_host: str = "127.0.0.1"
"""Default host for the MCP server, defaults to 127.0.0.1."""
default_port: int = 3001
"""Default port for the MCP server, defaults to 3001."""
instructions: str | None = None
"""Instructions to be used by the MCP server, defaults to None."""
enable_helpers_router: bool = True
"""Enable the helpers router for the MCP server."""
enable_sse: bool = True
"""Enable Server-Sent Events (SSE) for the MCP server."""
enable_streamable_http: bool = True
"""Enable streamable HTTP for the MCP server."""
enable_websocket: bool = False
"""Enable WebSocket for the MCP server."""
websocket_path: str = "/ws"
"""Path for the WebSocket endpoint."""
settings = Settings()

View File

View File

@@ -0,0 +1,180 @@
import logging
from typing import Literal
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import (
BearerAuthBackend,
RequireAuthMiddleware,
)
from mcp.server.fastmcp import FastMCP
from mcp.server.websocket import websocket_server
from mcp.types import ToolAnnotations
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount, Route
from starlette.websockets import WebSocket
from ..config import settings
logger = logging.getLogger(__name__)
class BetterFastMCP(FastMCP):
def run(
self,
transport: Literal["stdio", "sse", "streamable-http", "ws"] = "stdio",
mount_path: str | None = None,
) -> None:
import anyio
if transport == "ws":
anyio.run(self.run_ws_async)
else:
super().run(transport=transport, mount_path=mount_path)
async def run_ws_async(self) -> None:
"""Run the server using WebSocket transport."""
import uvicorn
starlette_app = self.ws_app()
config = uvicorn.Config(
app=starlette_app,
host=self.settings.host,
port=self.settings.port,
log_level=self.settings.log_level.lower(),
)
server = uvicorn.Server(config)
await server.serve()
def ws_app(self) -> Starlette:
"""Return an instance of the Websocket server app."""
async def handle_ws(websocket: WebSocket):
async with websocket_server(
websocket.scope, websocket.receive, websocket.send
) as (ws_read_stream, ws_write_stream):
await self._mcp_server.run(
ws_read_stream,
ws_write_stream,
self._mcp_server.create_initialization_options(),
raise_exceptions=self.settings.debug,
)
# Create routes
routes: list[Route | Mount] = []
middleware: list[Middleware] = []
required_scopes = []
# Set up auth if configured
if self.settings.auth:
required_scopes = self.settings.auth.required_scopes or []
# Add auth middleware if token verifier is available
if self._token_verifier:
middleware = [
Middleware(
AuthenticationMiddleware,
backend=BearerAuthBackend(self._token_verifier),
),
Middleware(AuthContextMiddleware),
]
# Add auth endpoints if auth server provider is configured
if self._auth_server_provider:
from mcp.server.auth.routes import create_auth_routes
routes.extend(
create_auth_routes(
provider=self._auth_server_provider,
issuer_url=self.settings.auth.issuer_url,
service_documentation_url=self.settings.auth.service_documentation_url,
client_registration_options=self.settings.auth.client_registration_options,
revocation_options=self.settings.auth.revocation_options,
)
)
# Set up routes with or without auth
if self._token_verifier:
# Determine resource metadata URL
resource_metadata_url = None
if self.settings.auth and self.settings.auth.resource_server_url:
from pydantic import AnyHttpUrl
resource_metadata_url = AnyHttpUrl(
str(self.settings.auth.resource_server_url).rstrip("/")
+ "/.well-known/oauth-protected-resource"
)
routes.append(
Route(
settings.websocket_path,
endpoint=RequireAuthMiddleware(
handle_ws, required_scopes, resource_metadata_url
),
)
)
else:
# Auth is disabled, no wrapper needed
routes.append(
Route(
settings.websocket_path,
endpoint=handle_ws,
)
)
# Add protected resource metadata endpoint if configured as RS
if self.settings.auth and self.settings.auth.resource_server_url:
from mcp.server.auth.handlers.metadata import (
ProtectedResourceMetadataHandler,
)
from mcp.server.auth.routes import cors_middleware
from mcp.shared.auth import ProtectedResourceMetadata
protected_resource_metadata = ProtectedResourceMetadata(
resource=self.settings.auth.resource_server_url,
authorization_servers=[self.settings.auth.issuer_url],
scopes_supported=self.settings.auth.required_scopes,
)
routes.append(
Route(
"/.well-known/oauth-protected-resource",
endpoint=cors_middleware(
ProtectedResourceMetadataHandler(
protected_resource_metadata
).handle,
["GET", "OPTIONS"],
),
methods=["GET", "OPTIONS"],
)
)
routes.extend(self._custom_starlette_routes)
return Starlette(
debug=self.settings.debug,
routes=routes,
middleware=middleware,
lifespan=lambda app: self.session_manager.run(),
)
def better_tool(
self,
name: str | None = None,
title: str | None = None,
description: str | None = None,
annotations: ToolAnnotations | None = None,
structured_output: bool | None = None,
):
"""Decorator to register a tool.
TODO: Implement a better tool function decorator.
"""
# tool_mcp = self._tool_manager._tools
# existing = tool_mcp.get(name)
# if existing:
# if self._tool_manager.warn_on_duplicate_tools:
# logger.warning(f"Tool already exists: {tool.name}")
# return existing
# self._tools[tool.name] = tool
# return tool

View File

@@ -2,7 +2,9 @@ import contextlib
from fastapi import FastAPI
from .__about__ import __version__
from .app import MCP_MAP
from .config import settings
from .routers.helpers import router as helpers_router
@@ -14,13 +16,21 @@ async def lifespan(app: FastAPI):
yield
app = FastAPI(lifespan=lifespan)
app = FastAPI(
title=settings.app_title,
description=settings.app_description,
version=__version__,
lifespan=lifespan,
)
@app.get("/")
async def root():
"""Root endpoint."""
return {"message": "Welcome!"}
return {
"message": "Welcome!",
"tools": list(MCP_MAP.keys()),
}
@app.get("/health")
@@ -28,12 +38,16 @@ async def health():
"""Check the health of the server and list available tools."""
return {
"status": "healthy",
"tools": list(MCP_MAP.keys()),
}
app.include_router(helpers_router)
if settings.enable_helpers_router:
app.include_router(helpers_router)
for name, mcp in MCP_MAP.items():
app.mount(f"/{name}/compatible", mcp.sse_app())
app.mount(f"/{name}", mcp.streamable_http_app())
if settings.enable_sse:
app.mount(f"/{name}/compatible", mcp.sse_app())
if settings.enable_streamable_http:
app.mount(f"/{name}", mcp.streamable_http_app())
if settings.enable_websocket:
app.mount(f"/{name}/websocket", mcp.ws_app())