Skip to content

Commit 94de7b1

Browse files
committed
Fix MCP server mypy typing
1 parent 7a9441e commit 94de7b1

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

clickadvisor/mcp_server/server.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import json
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, cast
77

88
from mcp.server import Server
99
from mcp.server.stdio import stdio_server
@@ -13,12 +13,13 @@
1313
from clickadvisor.core.models import QueryContext
1414
from clickadvisor.core.pipeline import AnalysisPipeline
1515
from clickadvisor.core.version import detect_version
16+
from clickadvisor.rules.base import Rule
1617
from clickadvisor.rules.registry import get_applicable_rules
1718

1819
server = Server("clickadvisor")
1920

2021

21-
@server.list_tools()
22+
@server.list_tools() # type: ignore[no-untyped-call, untyped-decorator]
2223
async def list_tools() -> list[Tool]:
2324
return [
2425
Tool(
@@ -121,7 +122,7 @@ async def list_tools() -> list[Tool]:
121122
]
122123

123124

124-
@server.list_prompts()
125+
@server.list_prompts() # type: ignore[no-untyped-call, untyped-decorator]
125126
async def list_prompts() -> list[Prompt]:
126127
return [
127128
Prompt(
@@ -156,7 +157,7 @@ async def list_prompts() -> list[Prompt]:
156157
]
157158

158159

159-
@server.get_prompt()
160+
@server.get_prompt() # type: ignore[no-untyped-call, untyped-decorator]
160161
async def get_prompt(name: str, arguments: dict[str, Any]) -> GetPromptResult:
161162
sql = str(arguments.get("sql", ""))
162163
ch_version = str(arguments.get("ch_version", ""))
@@ -193,7 +194,7 @@ async def get_prompt(name: str, arguments: dict[str, Any]) -> GetPromptResult:
193194
)
194195

195196

196-
@server.call_tool()
197+
@server.call_tool() # type: ignore[untyped-decorator]
197198
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
198199
if name == "analyze_query":
199200
return await _analyze_query(arguments)
@@ -220,7 +221,7 @@ async def _analyze_query(arguments: dict[str, Any]) -> list[TextContent]:
220221
ch_version=str(ch_version) if ch_version is not None else None,
221222
)
222223

223-
rules = get_applicable_rules(context.ch_version)
224+
rules = _get_applicable_rules(context.ch_version)
224225
retrieval_advisor = _build_retrieval_advisor()
225226
pipeline = AnalysisPipeline(
226227
rules=rules,
@@ -242,7 +243,7 @@ async def _analyze_query_json(arguments: dict[str, Any]) -> list[TextContent]:
242243
schema_ddl=str(schema_ddl) if schema_ddl is not None else None,
243244
ch_version=str(ch_version) if ch_version is not None else None,
244245
)
245-
rules = get_applicable_rules(context.ch_version)
246+
rules = _get_applicable_rules(context.ch_version)
246247
pipeline = AnalysisPipeline(rules=rules)
247248
report = pipeline.run(context)
248249

@@ -271,7 +272,7 @@ async def _analyze_query_json(arguments: dict[str, Any]) -> list[TextContent]:
271272

272273
async def _list_rules(arguments: dict[str, Any]) -> list[TextContent]:
273274
tier_filter = str(arguments.get("tier", "all"))
274-
rules = get_applicable_rules(None)
275+
rules = _get_applicable_rules(None)
275276

276277
lines = ["# Правила оптимизации ClickAdvisor\n"]
277278
tier_order = {"1A": 0, "1B": 1, "1C": 2, "detector": 3, "rag": 4}
@@ -338,6 +339,10 @@ def _build_retrieval_advisor() -> Any | None:
338339
return None
339340

340341

342+
def _get_applicable_rules(ch_version: str | None) -> list[Rule]:
343+
return cast(list[Rule], get_applicable_rules(ch_version))
344+
345+
341346
async def main() -> None:
342347
async with stdio_server() as (read_stream, write_stream):
343348
await server.run(read_stream, write_stream, server.create_initialization_options())

0 commit comments

Comments
 (0)