|
6 | 6 |
|
7 | 7 | import asyncio |
8 | 8 | import base64 |
| 9 | +import hashlib |
| 10 | +import hmac |
9 | 11 | import json |
10 | 12 | import logging |
| 13 | +from typing import Any |
11 | 14 |
|
12 | 15 | import click |
13 | 16 | from mcp.server import ServerRequestContext |
|
20 | 23 | Completion, |
21 | 24 | CompletionArgument, |
22 | 25 | CompletionContext, |
| 26 | + CreateMessageRequest, |
| 27 | + CreateMessageRequestParams, |
| 28 | + CreateMessageResult, |
| 29 | + ElicitRequest, |
| 30 | + ElicitRequestFormParams, |
| 31 | + ElicitResult, |
23 | 32 | EmbeddedResource, |
24 | 33 | EmptyResult, |
25 | 34 | ImageContent, |
| 35 | + InputRequest, |
| 36 | + InputRequiredResult, |
26 | 37 | JSONRPCMessage, |
| 38 | + ListRootsRequest, |
| 39 | + ListRootsResult, |
27 | 40 | PromptReference, |
28 | 41 | ResourceTemplateReference, |
29 | 42 | SamplingMessage, |
|
33 | 46 | TextResourceContents, |
34 | 47 | UnsubscribeRequestParams, |
35 | 48 | ) |
36 | | -from mcp.types.jsonrpc import MISSING_REQUIRED_CLIENT_CAPABILITY |
| 49 | +from mcp.types.jsonrpc import INVALID_PARAMS, MISSING_REQUIRED_CLIENT_CAPABILITY |
37 | 50 | from pydantic import BaseModel, Field |
38 | 51 |
|
39 | 52 | logger = logging.getLogger(__name__) |
@@ -333,6 +346,228 @@ async def test_missing_capability(ctx: Context) -> str: |
333 | 346 | return "Client declared sampling capability; proceeding." |
334 | 347 |
|
335 | 348 |
|
| 349 | +# SEP-2322 InputRequiredResult fixtures (multi-round-trip / ephemeral workflow) |
| 350 | + |
| 351 | +NAME_SCHEMA = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} |
| 352 | + |
| 353 | + |
| 354 | +def _name_elicitation(message: str = "What is your name?") -> ElicitRequest: |
| 355 | + return ElicitRequest(params=ElicitRequestFormParams(message=message, requested_schema=NAME_SCHEMA)) |
| 356 | + |
| 357 | + |
| 358 | +@mcp.tool() |
| 359 | +async def test_input_required_result_elicitation(ctx: Context) -> str | InputRequiredResult: |
| 360 | + """Tests InputRequiredResult with a single elicitation request""" |
| 361 | + responses = ctx.input_responses |
| 362 | + if responses and "user_name" in responses: |
| 363 | + answer = responses["user_name"] |
| 364 | + name = answer.content.get("name", "stranger") if isinstance(answer, ElicitResult) and answer.content else "?" |
| 365 | + return f"Hello, {name}!" |
| 366 | + return InputRequiredResult(input_requests={"user_name": _name_elicitation()}) |
| 367 | + |
| 368 | + |
| 369 | +@mcp.tool() |
| 370 | +async def test_input_required_result_sampling(ctx: Context) -> str | InputRequiredResult: |
| 371 | + """Tests InputRequiredResult with a single sampling request""" |
| 372 | + responses = ctx.input_responses |
| 373 | + if responses and "capital_question" in responses: |
| 374 | + answer = responses["capital_question"] |
| 375 | + text = answer.content.text if isinstance(answer, CreateMessageResult) and answer.content.type == "text" else "?" |
| 376 | + return f"Model said: {text}" |
| 377 | + return InputRequiredResult( |
| 378 | + input_requests={ |
| 379 | + "capital_question": CreateMessageRequest( |
| 380 | + params=CreateMessageRequestParams( |
| 381 | + messages=[ |
| 382 | + SamplingMessage( |
| 383 | + role="user", content=TextContent(type="text", text="What is the capital of France?") |
| 384 | + ) |
| 385 | + ], |
| 386 | + max_tokens=100, |
| 387 | + ) |
| 388 | + ) |
| 389 | + } |
| 390 | + ) |
| 391 | + |
| 392 | + |
| 393 | +@mcp.tool() |
| 394 | +async def test_input_required_result_list_roots(ctx: Context) -> str | InputRequiredResult: |
| 395 | + """Tests InputRequiredResult with a single roots/list request""" |
| 396 | + responses = ctx.input_responses |
| 397 | + if responses and "client_roots" in responses: |
| 398 | + answer = responses["client_roots"] |
| 399 | + count = len(answer.roots) if isinstance(answer, ListRootsResult) else 0 |
| 400 | + return f"Client exposed {count} root(s)." |
| 401 | + return InputRequiredResult(input_requests={"client_roots": ListRootsRequest()}) |
| 402 | + |
| 403 | + |
| 404 | +@mcp.tool() |
| 405 | +async def test_input_required_result_request_state(ctx: Context) -> str | InputRequiredResult: |
| 406 | + """Tests requestState round-tripping in the InputRequiredResult flow""" |
| 407 | + responses = ctx.input_responses |
| 408 | + if responses and "confirm" in responses and ctx.request_state == "request-state-nonce": |
| 409 | + return "state-ok: confirmation received" |
| 410 | + confirm = ElicitRequest( |
| 411 | + params=ElicitRequestFormParams( |
| 412 | + message="Please confirm", |
| 413 | + requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}, "required": ["ok"]}, |
| 414 | + ) |
| 415 | + ) |
| 416 | + return InputRequiredResult(input_requests={"confirm": confirm}, request_state="request-state-nonce") |
| 417 | + |
| 418 | + |
| 419 | +@mcp.tool() |
| 420 | +async def test_input_required_result_multiple_inputs(ctx: Context) -> str | InputRequiredResult: |
| 421 | + """Tests InputRequiredResult carrying elicitation, sampling and roots requests together""" |
| 422 | + responses = ctx.input_responses |
| 423 | + if responses and {"user_name", "greeting", "client_roots"} <= responses.keys(): |
| 424 | + return "All inputs received." |
| 425 | + return InputRequiredResult( |
| 426 | + input_requests={ |
| 427 | + "user_name": _name_elicitation(), |
| 428 | + "greeting": CreateMessageRequest( |
| 429 | + params=CreateMessageRequestParams( |
| 430 | + messages=[ |
| 431 | + SamplingMessage(role="user", content=TextContent(type="text", text="Generate a greeting")) |
| 432 | + ], |
| 433 | + max_tokens=50, |
| 434 | + ) |
| 435 | + ), |
| 436 | + "client_roots": ListRootsRequest(), |
| 437 | + }, |
| 438 | + request_state="multiple-inputs", |
| 439 | + ) |
| 440 | + |
| 441 | + |
| 442 | +@mcp.tool() |
| 443 | +async def test_input_required_result_multi_round(ctx: Context) -> str | InputRequiredResult: |
| 444 | + """Tests a three-round InputRequiredResult flow with evolving requestState""" |
| 445 | + state = json.loads(ctx.request_state) if ctx.request_state else {"round": 0} |
| 446 | + responses = ctx.input_responses or {} |
| 447 | + |
| 448 | + if state["round"] == 0: |
| 449 | + return InputRequiredResult( |
| 450 | + input_requests={"step1": _name_elicitation("Step 1: What is your name?")}, |
| 451 | + request_state=json.dumps({"round": 1}), |
| 452 | + ) |
| 453 | + |
| 454 | + if state["round"] == 1 and "step1" in responses: |
| 455 | + step1 = responses["step1"] |
| 456 | + name = step1.content.get("name") if isinstance(step1, ElicitResult) and step1.content else None |
| 457 | + color_schema = {"type": "object", "properties": {"color": {"type": "string"}}, "required": ["color"]} |
| 458 | + return InputRequiredResult( |
| 459 | + input_requests={ |
| 460 | + "step2": ElicitRequest( |
| 461 | + params=ElicitRequestFormParams( |
| 462 | + message="Step 2: What is your favorite color?", requested_schema=color_schema |
| 463 | + ) |
| 464 | + ) |
| 465 | + }, |
| 466 | + request_state=json.dumps({"round": 2, "name": name}), |
| 467 | + ) |
| 468 | + |
| 469 | + if state["round"] == 2 and "step2" in responses: |
| 470 | + step2 = responses["step2"] |
| 471 | + color = step2.content.get("color") if isinstance(step2, ElicitResult) and step2.content else None |
| 472 | + return f"{state.get('name')} likes {color}." |
| 473 | + |
| 474 | + # Missing or out-of-order response: re-request from the start. |
| 475 | + return InputRequiredResult( |
| 476 | + input_requests={"step1": _name_elicitation("Step 1: What is your name?")}, |
| 477 | + request_state=json.dumps({"round": 1}), |
| 478 | + ) |
| 479 | + |
| 480 | + |
| 481 | +# Fixed key for the conformance fixture; a real server would derive or rotate this. |
| 482 | +_STATE_HMAC_KEY = b"everything-server-fixture-key" |
| 483 | + |
| 484 | + |
| 485 | +def _seal_state(payload: str) -> str: |
| 486 | + encoded = base64.urlsafe_b64encode(payload.encode()).decode() |
| 487 | + sig = hmac.new(_STATE_HMAC_KEY, encoded.encode(), hashlib.sha256).hexdigest() |
| 488 | + return f"{encoded}.{sig}" |
| 489 | + |
| 490 | + |
| 491 | +def _unseal_state(state: str) -> str: |
| 492 | + encoded, _, sig = state.partition(".") |
| 493 | + expected = hmac.new(_STATE_HMAC_KEY, encoded.encode(), hashlib.sha256).hexdigest() |
| 494 | + if not sig or not hmac.compare_digest(sig, expected): |
| 495 | + raise MCPError(code=INVALID_PARAMS, message="requestState failed integrity verification") |
| 496 | + return base64.urlsafe_b64decode(encoded).decode() |
| 497 | + |
| 498 | + |
| 499 | +@mcp.tool() |
| 500 | +async def test_input_required_result_tampered_state(ctx: Context) -> str | InputRequiredResult: |
| 501 | + """Tests that the server rejects a requestState that fails HMAC verification""" |
| 502 | + if ctx.request_state is None: |
| 503 | + confirm = ElicitRequest( |
| 504 | + params=ElicitRequestFormParams( |
| 505 | + message="Please confirm", |
| 506 | + requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}, "required": ["ok"]}, |
| 507 | + ) |
| 508 | + ) |
| 509 | + return InputRequiredResult(input_requests={"confirm": confirm}, request_state=_seal_state("round-1")) |
| 510 | + payload = _unseal_state(ctx.request_state) |
| 511 | + return f"state-ok: {payload}" |
| 512 | + |
| 513 | + |
| 514 | +@mcp.tool() |
| 515 | +async def test_input_required_result_capabilities(ctx: Context) -> InputRequiredResult: |
| 516 | + """Tests that inputRequests only include methods the client declared support for""" |
| 517 | + caps = ctx.client_capabilities |
| 518 | + requests: dict[str, InputRequest] = {} |
| 519 | + if caps is None or caps.sampling is not None: |
| 520 | + requests["sample"] = CreateMessageRequest( |
| 521 | + params=CreateMessageRequestParams( |
| 522 | + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Say hello"))], |
| 523 | + max_tokens=50, |
| 524 | + ) |
| 525 | + ) |
| 526 | + if caps is None or caps.elicitation is not None: |
| 527 | + requests["ask"] = _name_elicitation() |
| 528 | + return InputRequiredResult(input_requests=requests, request_state="capability-gated") |
| 529 | + |
| 530 | + |
| 531 | +# SEP-1613 / SEP-2106 JSON Schema 2020-12 fixture: a tool whose inputSchema carries |
| 532 | +# the full set of 2020-12 keywords the conformance scenario asserts on. |
| 533 | + |
| 534 | +JSON_SCHEMA_2020_12_INPUT_SCHEMA: dict[str, Any] = { |
| 535 | + "$schema": "https://json-schema.org/draft/2020-12/schema", |
| 536 | + "type": "object", |
| 537 | + "$defs": { |
| 538 | + "address": { |
| 539 | + "$anchor": "addressDef", |
| 540 | + "type": "object", |
| 541 | + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, |
| 542 | + } |
| 543 | + }, |
| 544 | + "properties": { |
| 545 | + "name": {"type": "string"}, |
| 546 | + "address": {"$ref": "#/$defs/address"}, |
| 547 | + "contactMethod": {"type": "string", "enum": ["phone", "email"]}, |
| 548 | + "phone": {"type": "string"}, |
| 549 | + "email": {"type": "string"}, |
| 550 | + }, |
| 551 | + "allOf": [{"anyOf": [{"required": ["phone"]}, {"required": ["email"]}]}], |
| 552 | + "if": {"properties": {"contactMethod": {"const": "phone"}}, "required": ["contactMethod"]}, |
| 553 | + "then": {"required": ["phone"]}, |
| 554 | + "else": {"required": ["email"]}, |
| 555 | + "additionalProperties": False, |
| 556 | +} |
| 557 | + |
| 558 | + |
| 559 | +@mcp.tool(name="json_schema_2020_12_tool") |
| 560 | +def json_schema_2020_12_tool() -> str: |
| 561 | + """Tests JSON Schema 2020-12 keyword preservation in tools/list (inputSchema installed below).""" |
| 562 | + return "json_schema_2020_12_tool" |
| 563 | + |
| 564 | + |
| 565 | +# TODO(felix): replace with a public input_schema= override once MCPServer.tool() grows one. |
| 566 | +mcp._tool_manager._tools["json_schema_2020_12_tool"].parameters = ( # pyright: ignore[reportPrivateUsage] |
| 567 | + JSON_SCHEMA_2020_12_INPUT_SCHEMA |
| 568 | +) |
| 569 | + |
| 570 | + |
336 | 571 | @mcp.tool() |
337 | 572 | async def test_reconnection(ctx: Context) -> str: |
338 | 573 | """Tests SSE polling by closing stream mid-call (SEP-1699)""" |
|
0 commit comments