mirror of
https://github.com/Sun-ZhenXing/mcp-template-python.git
synced 2026-02-04 10:13:31 +00:00
feat: add /ws
This commit is contained in:
@@ -2,11 +2,12 @@ import argparse
|
||||
import sys
|
||||
|
||||
from .__about__ import __module_name__, __version__
|
||||
from .app import MCP_MAP
|
||||
from .config import settings
|
||||
|
||||
|
||||
def main():
|
||||
from .app import MCP_MAP
|
||||
|
||||
parser = argparse.ArgumentParser(description="MCP Server")
|
||||
|
||||
parser.add_argument(
|
||||
@@ -14,6 +15,13 @@ def main():
|
||||
action="store_true",
|
||||
help="Run the server with STDIO (default: False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mcp",
|
||||
type=str,
|
||||
default=settings.default_mcp,
|
||||
choices=list(MCP_MAP.keys()),
|
||||
help=f"Select the MCP to run in STDIO mode (default: {settings.default_mcp})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default=settings.default_host,
|
||||
@@ -44,7 +52,10 @@ def main():
|
||||
sys.exit(0)
|
||||
|
||||
if args.stdio:
|
||||
mcp = MCP_MAP[settings.default_mcp]
|
||||
mcp = MCP_MAP.get(args.mcp)
|
||||
if mcp is None:
|
||||
print(f"Error: MCP '{args.mcp}' not found.")
|
||||
sys.exit(1)
|
||||
mcp.run()
|
||||
else:
|
||||
import uvicorn
|
||||
|
||||
@@ -1,39 +1,31 @@
|
||||
from operator import add, mul, sub, truediv
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp_template_python.lib.better_mcp import BetterFastMCP
|
||||
|
||||
from ..config import settings
|
||||
|
||||
mcp = FastMCP("math", settings=settings.instructions)
|
||||
mcp = BetterFastMCP("math", settings=settings.instructions)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def add_nums(a: float, b: float) -> float:
|
||||
"""
|
||||
Adds two numbers.
|
||||
"""
|
||||
async def add_num(a: float, b: float) -> float:
|
||||
"""Adds two numbers."""
|
||||
return add(a, b)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def sub_nums(a: float, b: float) -> float:
|
||||
"""
|
||||
Subtracts the second number from the first.
|
||||
"""
|
||||
async def sub_num(a: float, b: float) -> float:
|
||||
"""Subtracts the second number from the first."""
|
||||
return sub(a, b)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def mul_nums(a: float, b: float) -> float:
|
||||
"""
|
||||
Multiplies two numbers.
|
||||
"""
|
||||
async def mul_num(a: float, b: float) -> float:
|
||||
"""Multiplies two numbers."""
|
||||
return mul(a, b)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def div_nums(a: float, b: float) -> float:
|
||||
"""
|
||||
Divides the first number by the second.
|
||||
"""
|
||||
async def div_num(a: float, b: float) -> float:
|
||||
"""Divides the first number by the second."""
|
||||
return truediv(a, b)
|
||||
|
||||
@@ -6,12 +6,6 @@ class Settings(BaseSettings):
|
||||
Configuration settings for the MCP template application.
|
||||
"""
|
||||
|
||||
default_mcp: str = "math"
|
||||
default_host: str = "127.0.0.1"
|
||||
default_port: int = 3001
|
||||
|
||||
instructions: str | None = None
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="MCP_",
|
||||
env_file=".env",
|
||||
@@ -19,5 +13,38 @@ class Settings(BaseSettings):
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
app_title: str = "MCP Template Application"
|
||||
"""Title of the MCP application, defaults to 'MCP Template Application'."""
|
||||
|
||||
app_description: str = "A template application for MCP using FastAPI."
|
||||
"""Description of the MCP application, defaults to 'A template application for MCP using FastAPI.'"""
|
||||
|
||||
default_mcp: str = "math"
|
||||
"""Default MCP to be used by the application."""
|
||||
|
||||
default_host: str = "127.0.0.1"
|
||||
"""Default host for the MCP server, defaults to 127.0.0.1."""
|
||||
|
||||
default_port: int = 3001
|
||||
"""Default port for the MCP server, defaults to 3001."""
|
||||
|
||||
instructions: str | None = None
|
||||
"""Instructions to be used by the MCP server, defaults to None."""
|
||||
|
||||
enable_helpers_router: bool = True
|
||||
"""Enable the helpers router for the MCP server."""
|
||||
|
||||
enable_sse: bool = True
|
||||
"""Enable Server-Sent Events (SSE) for the MCP server."""
|
||||
|
||||
enable_streamable_http: bool = True
|
||||
"""Enable streamable HTTP for the MCP server."""
|
||||
|
||||
enable_websocket: bool = False
|
||||
"""Enable WebSocket for the MCP server."""
|
||||
|
||||
websocket_path: str = "/ws"
|
||||
"""Path for the WebSocket endpoint."""
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
0
src/mcp_template_python/lib/__init__.py
Normal file
0
src/mcp_template_python/lib/__init__.py
Normal file
180
src/mcp_template_python/lib/better_mcp.py
Normal file
180
src/mcp_template_python/lib/better_mcp.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
|
||||
from mcp.server.auth.middleware.bearer_auth import (
|
||||
BearerAuthBackend,
|
||||
RequireAuthMiddleware,
|
||||
)
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.websocket import websocket_server
|
||||
from mcp.types import ToolAnnotations
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.routing import Mount, Route
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from ..config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BetterFastMCP(FastMCP):
|
||||
def run(
|
||||
self,
|
||||
transport: Literal["stdio", "sse", "streamable-http", "ws"] = "stdio",
|
||||
mount_path: str | None = None,
|
||||
) -> None:
|
||||
import anyio
|
||||
|
||||
if transport == "ws":
|
||||
anyio.run(self.run_ws_async)
|
||||
else:
|
||||
super().run(transport=transport, mount_path=mount_path)
|
||||
|
||||
async def run_ws_async(self) -> None:
|
||||
"""Run the server using WebSocket transport."""
|
||||
import uvicorn
|
||||
|
||||
starlette_app = self.ws_app()
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=starlette_app,
|
||||
host=self.settings.host,
|
||||
port=self.settings.port,
|
||||
log_level=self.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def ws_app(self) -> Starlette:
|
||||
"""Return an instance of the Websocket server app."""
|
||||
|
||||
async def handle_ws(websocket: WebSocket):
|
||||
async with websocket_server(
|
||||
websocket.scope, websocket.receive, websocket.send
|
||||
) as (ws_read_stream, ws_write_stream):
|
||||
await self._mcp_server.run(
|
||||
ws_read_stream,
|
||||
ws_write_stream,
|
||||
self._mcp_server.create_initialization_options(),
|
||||
raise_exceptions=self.settings.debug,
|
||||
)
|
||||
|
||||
# Create routes
|
||||
routes: list[Route | Mount] = []
|
||||
middleware: list[Middleware] = []
|
||||
required_scopes = []
|
||||
|
||||
# Set up auth if configured
|
||||
if self.settings.auth:
|
||||
required_scopes = self.settings.auth.required_scopes or []
|
||||
|
||||
# Add auth middleware if token verifier is available
|
||||
if self._token_verifier:
|
||||
middleware = [
|
||||
Middleware(
|
||||
AuthenticationMiddleware,
|
||||
backend=BearerAuthBackend(self._token_verifier),
|
||||
),
|
||||
Middleware(AuthContextMiddleware),
|
||||
]
|
||||
|
||||
# Add auth endpoints if auth server provider is configured
|
||||
if self._auth_server_provider:
|
||||
from mcp.server.auth.routes import create_auth_routes
|
||||
|
||||
routes.extend(
|
||||
create_auth_routes(
|
||||
provider=self._auth_server_provider,
|
||||
issuer_url=self.settings.auth.issuer_url,
|
||||
service_documentation_url=self.settings.auth.service_documentation_url,
|
||||
client_registration_options=self.settings.auth.client_registration_options,
|
||||
revocation_options=self.settings.auth.revocation_options,
|
||||
)
|
||||
)
|
||||
|
||||
# Set up routes with or without auth
|
||||
if self._token_verifier:
|
||||
# Determine resource metadata URL
|
||||
resource_metadata_url = None
|
||||
if self.settings.auth and self.settings.auth.resource_server_url:
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
resource_metadata_url = AnyHttpUrl(
|
||||
str(self.settings.auth.resource_server_url).rstrip("/")
|
||||
+ "/.well-known/oauth-protected-resource"
|
||||
)
|
||||
|
||||
routes.append(
|
||||
Route(
|
||||
settings.websocket_path,
|
||||
endpoint=RequireAuthMiddleware(
|
||||
handle_ws, required_scopes, resource_metadata_url
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Auth is disabled, no wrapper needed
|
||||
routes.append(
|
||||
Route(
|
||||
settings.websocket_path,
|
||||
endpoint=handle_ws,
|
||||
)
|
||||
)
|
||||
|
||||
# Add protected resource metadata endpoint if configured as RS
|
||||
if self.settings.auth and self.settings.auth.resource_server_url:
|
||||
from mcp.server.auth.handlers.metadata import (
|
||||
ProtectedResourceMetadataHandler,
|
||||
)
|
||||
from mcp.server.auth.routes import cors_middleware
|
||||
from mcp.shared.auth import ProtectedResourceMetadata
|
||||
|
||||
protected_resource_metadata = ProtectedResourceMetadata(
|
||||
resource=self.settings.auth.resource_server_url,
|
||||
authorization_servers=[self.settings.auth.issuer_url],
|
||||
scopes_supported=self.settings.auth.required_scopes,
|
||||
)
|
||||
routes.append(
|
||||
Route(
|
||||
"/.well-known/oauth-protected-resource",
|
||||
endpoint=cors_middleware(
|
||||
ProtectedResourceMetadataHandler(
|
||||
protected_resource_metadata
|
||||
).handle,
|
||||
["GET", "OPTIONS"],
|
||||
),
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
routes.extend(self._custom_starlette_routes)
|
||||
|
||||
return Starlette(
|
||||
debug=self.settings.debug,
|
||||
routes=routes,
|
||||
middleware=middleware,
|
||||
lifespan=lambda app: self.session_manager.run(),
|
||||
)
|
||||
|
||||
def better_tool(
|
||||
self,
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
annotations: ToolAnnotations | None = None,
|
||||
structured_output: bool | None = None,
|
||||
):
|
||||
"""Decorator to register a tool.
|
||||
TODO: Implement a better tool function decorator.
|
||||
"""
|
||||
# tool_mcp = self._tool_manager._tools
|
||||
# existing = tool_mcp.get(name)
|
||||
# if existing:
|
||||
# if self._tool_manager.warn_on_duplicate_tools:
|
||||
# logger.warning(f"Tool already exists: {tool.name}")
|
||||
# return existing
|
||||
# self._tools[tool.name] = tool
|
||||
# return tool
|
||||
@@ -2,7 +2,9 @@ import contextlib
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .__about__ import __version__
|
||||
from .app import MCP_MAP
|
||||
from .config import settings
|
||||
from .routers.helpers import router as helpers_router
|
||||
|
||||
|
||||
@@ -14,13 +16,21 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(
|
||||
title=settings.app_title,
|
||||
description=settings.app_description,
|
||||
version=__version__,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {"message": "Welcome!"}
|
||||
return {
|
||||
"message": "Welcome!",
|
||||
"tools": list(MCP_MAP.keys()),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@@ -28,12 +38,16 @@ async def health():
|
||||
"""Check the health of the server and list available tools."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"tools": list(MCP_MAP.keys()),
|
||||
}
|
||||
|
||||
|
||||
app.include_router(helpers_router)
|
||||
if settings.enable_helpers_router:
|
||||
app.include_router(helpers_router)
|
||||
|
||||
for name, mcp in MCP_MAP.items():
|
||||
app.mount(f"/{name}/compatible", mcp.sse_app())
|
||||
app.mount(f"/{name}", mcp.streamable_http_app())
|
||||
if settings.enable_sse:
|
||||
app.mount(f"/{name}/compatible", mcp.sse_app())
|
||||
if settings.enable_streamable_http:
|
||||
app.mount(f"/{name}", mcp.streamable_http_app())
|
||||
if settings.enable_websocket:
|
||||
app.mount(f"/{name}/websocket", mcp.ws_app())
|
||||
|
||||
Reference in New Issue
Block a user