diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py index 7a27e6e..9bad8bc 100644 --- a/doris_mcp_server/main.py +++ b/doris_mcp_server/main.py @@ -518,6 +518,7 @@ async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config. from starlette.routing import Route from starlette.responses import JSONResponse, Response from starlette.types import Scope + from starlette.middleware.cors import CORSMiddleware # Create session manager session_manager = StreamableHTTPSessionManager( @@ -600,6 +601,15 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: ], lifespan=lifespan, ) + + # Add CORS middleware allowing all origins + starlette_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) # Custom ASGI app that handles both /mcp and /mcp/ without redirects async def mcp_app(scope, receive, send): @@ -663,6 +673,18 @@ async def mcp_app(scope, receive, send): await response(scope, receive, send) return + # Handle CORS preflight OPTIONS requests + if method == "OPTIONS": + from starlette.responses import Response + request_headers = headers.get(b'access-control-request-headers', b'').decode('utf-8') + response = Response("", status_code=204) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET, POST, DELETE, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = request_headers or "*" + response.headers["Access-Control-Max-Age"] = "86400" + await response(scope, receive, send) + return + # Handle Dify compatibility for GET requests if method == "GET": accept_header = headers.get(b'accept', b'').decode('utf-8') @@ -687,7 +709,23 @@ async def mcp_app(scope, receive, send): scope["headers"] = new_headers self.logger.info(f"Modified Accept header to: {new_value}") - await session_manager.handle_request(scope, receive, send) + # Wrap send to inject CORS headers into session_manager responses + async def send_with_cors(message): + if message.get("type") == "http.response.start": + headers = list(message.get("headers", [])) + cors_headers = [ + (b"access-control-allow-origin", b"*"), + (b"access-control-allow-credentials", b"true"), + (b"access-control-expose-headers", b"*"), + ] + for name, value in cors_headers: + # Avoid duplicate headers + if not any(h[0] == name for h in headers): + headers.append((name, value)) + message["headers"] = headers + await send(message) + + await session_manager.handle_request(scope, receive, send_with_cors) return # 404 for other paths diff --git a/doris_mcp_server/multiworker_app.py b/doris_mcp_server/multiworker_app.py index b4368f6..41bbe3f 100644 --- a/doris_mcp_server/multiworker_app.py +++ b/doris_mcp_server/multiworker_app.py @@ -212,6 +212,7 @@ def permissive_check_generic(cls, params, elen): from starlette.applications import Starlette from starlette.routing import Route from starlette.responses import JSONResponse, Response +from starlette.middleware.cors import CORSMiddleware # Import Doris MCP components from .tools.tools_manager import DorisToolsManager @@ -614,6 +615,15 @@ async def mcp_asgi_app(scope, receive, send): lifespan=lifespan ) +# Add CORS middleware allowing all origins +basic_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + # Create main ASGI app that routes between basic app and MCP async def app(scope, receive, send): """Main ASGI app that routes requests"""