diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 1c9ba19ca..b218b4ca1 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -7,67 +7,166 @@ on: push: branches: [ main ] +env: + JAVA_VERSION: '21' + JAVA_DISTRIBUTION: temurin + NODE_VERSION: '22' + + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} # Cancel only when the run is NOT on `main` branch cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: - run-conformance: - runs-on: ${{ matrix.os }} - name: Run Conformance Tests on ${{ matrix.os }} + server: + runs-on: ubuntu-latest + name: Conformance Server Tests + timeout-minutes: 20 + + steps: + - uses: actions/checkout@v6 + + - name: Set up JDK + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JAVA_VERSION }} + distribution: ${{ env.JAVA_DISTRIBUTION }} + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v5 + with: + cache-read-only: ${{ github.ref != 'refs/heads/main' }} + + - name: Build + run: ./gradlew :conformance-test:installDist + + - name: Start server + run: | + MCP_PORT=3001 conformance-test/build/install/conformance-test/bin/conformance-test & + for i in $(seq 1 30); do + if curl -s -o /dev/null http://localhost:3001/mcp; then + echo "Server is ready" + break + fi + sleep 1 + done + + - name: Run conformance tests + uses: modelcontextprotocol/conformance@v0.1.15 + with: + mode: server + url: http://localhost:3001/mcp + suite: active + node-version: ${{ env.NODE_VERSION }} + expected-failures: ./conformance-test/conformance-baseline.yml + + client: + runs-on: ubuntu-latest + name: "Conformance Client Tests: ${{ matrix.scenario }}" timeout-minutes: 20 - env: - JAVA_OPTS: "-Xmx8g -Dfile.encoding=UTF-8 -Djava.awt.headless=true -Dkotlin.daemon.jvm.options=-Xmx6g" strategy: fail-fast: false matrix: - include: - - os: ubuntu-latest - max-workers: 3 - - os: windows-latest - max-workers: 3 - - os: macos-latest - max-workers: 2 + scenario: + - initialize + - tools_call + - elicitation-sep1034-client-defaults + - sse-retry + + steps: + - uses: actions/checkout@v6 + + - name: Set up JDK + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JAVA_VERSION }} + distribution: ${{ env.JAVA_DISTRIBUTION }} + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v5 + with: + cache-read-only: ${{ github.ref != 'refs/heads/main' }} + + - name: Build + run: ./gradlew :conformance-test:installDist + + - name: Run conformance tests + uses: modelcontextprotocol/conformance@v0.1.15 + with: + mode: client + command: conformance-test/build/install/conformance-test/bin/conformance-client + scenario: ${{ matrix.scenario }} + node-version: ${{ env.NODE_VERSION }} + expected-failures: ./conformance-test/conformance-baseline.yml + + auth: + runs-on: ubuntu-latest + name: Conformance Auth Tests + timeout-minutes: 20 steps: - uses: actions/checkout@v6 - - name: Set up JDK 21 + - name: Set up JDK uses: actions/setup-java@v5 with: - java-version: '21' - distribution: 'temurin' + java-version: ${{ env.JAVA_VERSION }} + distribution: ${{ env.JAVA_DISTRIBUTION }} - - name: Setup Node.js - uses: actions/setup-node@v6 + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v5 with: - node-version: '22' # increase only after https://github.com/nodejs/node/issues/56645 will be fixed + cache-read-only: ${{ github.ref != 'refs/heads/main' }} - - name: Setup Conformance Tests - working-directory: conformance-test - run: |- - npm install -g @modelcontextprotocol/conformance@0.1.8 + - name: Build + run: ./gradlew :conformance-test:installDist + + - name: Run conformance tests + uses: modelcontextprotocol/conformance@v0.1.15 + with: + mode: client + command: conformance-test/build/install/conformance-test/bin/conformance-client + suite: auth + node-version: ${{ env.NODE_VERSION }} + expected-failures: ./conformance-test/conformance-baseline.yml + + auth-scenarios: + runs-on: ubuntu-latest + name: "Conformance Auth Scenario: ${{ matrix.scenario }}" + timeout-minutes: 20 + + strategy: + fail-fast: false + matrix: + scenario: + - auth/client-credentials-jwt + - auth/client-credentials-basic + - auth/cross-app-access-complete-flow + + steps: + - uses: actions/checkout@v6 + + - name: Set up JDK + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JAVA_VERSION }} + distribution: ${{ env.JAVA_DISTRIBUTION }} - name: Setup Gradle uses: gradle/actions/setup-gradle@v5 with: - add-job-summary: 'always' cache-read-only: ${{ github.ref != 'refs/heads/main' }} - gradle-home-cache-includes: | - caches - notifications - sdks - ../.konan/** - - - name: Run Conformance Tests - run: |- - ./gradlew :conformance-test:test --no-daemon --max-workers ${{ matrix.max-workers }} - - - name: Upload Conformance Results - if: always() - uses: actions/upload-artifact@v7 + + - name: Build + run: ./gradlew :conformance-test:installDist + + - name: Run conformance tests + uses: modelcontextprotocol/conformance@v0.1.15 with: - name: conformance-results-${{ matrix.os }} - path: conformance-test/results/ + mode: client + command: conformance-test/build/install/conformance-test/bin/conformance-client + scenario: ${{ matrix.scenario }} + node-version: ${{ env.NODE_VERSION }} + expected-failures: ./conformance-test/conformance-baseline.yml diff --git a/build.gradle.kts b/build.gradle.kts index f4ddd4451..72c55539b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -23,12 +23,15 @@ dependencies { subprojects { apply(plugin = "org.jlleitschuh.gradle.ktlint") apply(plugin = "org.jetbrains.kotlinx.kover") - apply(plugin = "dev.detekt") - detekt { - config = files("$rootDir/config/detekt/detekt.yml") - buildUponDefaultConfig = true - failOnSeverity.set(FailOnSeverity.Error) + if (name != "conformance-test" && name != "docs") { + apply(plugin = "dev.detekt") + + detekt { + config = files("$rootDir/config/detekt/detekt.yml") + buildUponDefaultConfig = true + failOnSeverity.set(FailOnSeverity.Error) + } } } diff --git a/conformance-test/.gitignore b/conformance-test/.gitignore new file mode 100644 index 000000000..6628455c0 --- /dev/null +++ b/conformance-test/.gitignore @@ -0,0 +1 @@ +/results/ diff --git a/conformance-test/README.md b/conformance-test/README.md new file mode 100644 index 000000000..d14fcfef3 --- /dev/null +++ b/conformance-test/README.md @@ -0,0 +1,130 @@ +# MCP Conformance Tests + +Conformance tests for the Kotlin MCP SDK. Uses the external +[`@modelcontextprotocol/conformance`](https://www.npmjs.com/package/@modelcontextprotocol/conformance) +runner (pinned to **0.1.15**) to validate compliance with the MCP specification. + +## Prerequisites + +- **JDK 17+** +- **Node.js 18+** and `npx` (for the conformance runner) +- **curl** (used to poll server readiness) + +## Quick Start + +Run **all** suites (server, client core, client auth) from the project root: + +```bash +./conformance-test/run-conformance.sh all +``` + +## Commands + +``` +./conformance-test/run-conformance.sh [extra-args...] +``` + +| Command | What it does | +|---------------|--------------------------------------------------------------------------------------| +| `list` | [List scenarios available in MCP Conformance Test Framework][list-scenarios-command] | +| `server` | Starts the Ktor conformance server, runs the server test suite against it | +| `client` | Runs the client test suite (`initialize`, `tools_call`, `elicitation`, `sse-retry`) | +| `client-auth` | Runs the client auth test suite (20 OAuth scenarios) | +| `all` | Runs all three suites sequentially | + +Any `[extra-args]` are forwarded to the conformance runner (e.g. `--verbose`). + +## What the Script Does + +1. **Builds** the module via `./gradlew :conformance-test:installDist` +2. For `server` — starts the conformance server on `localhost:3001`, polls until ready +3. Invokes `npx @modelcontextprotocol/conformance@0.1.15` with the appropriate arguments +4. Saves results to `conformance-test/results//` +5. Cleans up the server process on exit +6. Exits non-zero if any suite fails + +## Environment Variables + +| Variable | Default | Description | +|------------|---------|---------------------------------| +| `MCP_PORT` | `3001` | Port for the conformance server | + +## Project Structure + +``` +conformance-test/ +├── run-conformance.sh # Single entry point script +├── conformance-baseline.yml # Expected failures for known SDK limitations +└── src/main/kotlin/.../conformance/ + ├── ConformanceServer.kt # Ktor server entry point (StreamableHTTP, DNS rebinding, EventStore) + ├── ConformanceClient.kt # Scenario-based client entry point (MCP_CONFORMANCE_SCENARIO routing) + ├── ConformanceTools.kt # 18 tool registrations + ├── ConformanceResources.kt # 5 resource registrations (static, binary, template, watched, dynamic) + ├── ConformancePrompts.kt # 5 prompt registrations (simple, args, image, embedded, dynamic) + ├── ConformanceCompletions.kt # completion/complete handler + ├── InMemoryEventStore.kt # EventStore impl for SSE resumability (SEP-1699) + └── auth/ # OAuth client for 20 auth scenarios + ├── registration.kt # Scenario handler registration + ├── utils.kt # Shared utilities: JSON instance, constants, extractOrigin() + ├── discovery.kt # Protected Resource Metadata + AS Metadata discovery + ├── pkce.kt # PKCE code verifier/challenge generation + AS capability check + ├── tokenExchange.kt # Token endpoint interaction (exchange code, error handling) + ├── authCodeFlow.kt # Main Authorization Code flow handler (runAuthClient + interceptor) + ├── scopeHandling.kt # Scope selection strategy + step-up 403 handling + ├── clientRegistration.kt # Client registration logic (pre-reg, CIMD, dynamic) + ├── JWTScenario.kt # Client Credentials JWT scenario + ├── basicScenario.kt # Client Credentials Basic scenario + └── crossAppAccessScenario.kt # Cross-App Access (SEP-990) scenario +``` + +## Test Suites + +### Server Suite + +Tests the conformance server against all server scenarios: + +| Category | Scenarios | +|-------------|-------------------------------------------------------------------------------------------------------------------------------------| +| Lifecycle | initialize, ping | +| Tools | text, image, audio, embedded, multiple, progress, logging, error, sampling, elicitation, dynamic, reconnection, JSON Schema 2020-12 | +| Resources | list, read-text, read-binary, templates, subscribe, dynamic | +| Prompts | simple, with-args, with-image, with-embedded-resource, dynamic | +| Completions | complete | +| Security | DNS rebinding protection | + +### Client Core Suite + +| Scenario | Description | +|---------------------------------------|-----------------------------------------------| +| `initialize` | Connect, list tools, close | +| `tools_call` | Connect, call `add_numbers(a=5, b=3)`, close | +| `elicitation-sep1034-client-defaults` | Elicitation with `applyDefaults` capability | +| `sse-retry` | Call `test_reconnection`, verify reconnection | + +### Client Auth Suite + +17 OAuth Authorization Code scenarios + 2 Client Credentials scenarios (`jwt`, `basic`) + 1 Cross-App Access scenario = 20 total. + +> [!NOTE] +> Auth scenarios are implemented using Ktor's `HttpClient` plugins (`HttpSend` interceptor, +> `ktor-client-auth`) as a standalone OAuth client. They do not use the SDK's built-in auth support. + +## Known SDK Limitations + +8 scenarios are expected to fail due to current SDK limitations (tracked in [ +`conformance-baseline.yml`](conformance-baseline.yml). + +| Scenario | Suite | Root Cause | +|---------------------------------------|--------|--------------------------------------------------------------------------------------------------------------------------------------------------------| +| `tools-call-with-logging` | server | Notifications from tool handlers have no `relatedRequestId`; transport routes them to the standalone SSE stream instead of the request-specific stream | +| `tools-call-with-progress` | server | *(same as above)* | +| `tools-call-sampling` | server | *(same as above)* | +| `tools-call-elicitation` | server | *(same as above)* | +| `elicitation-sep1034-defaults` | server | *(same as above)* | +| `elicitation-sep1330-enums` | server | *(same as above)* | +| `resources-templates-read` | server | SDK does not implement `addResourceTemplate()` with URI pattern matching; resources are looked up by exact URI | +| `elicitation-sep1034-client-defaults` | client | SDK does not fill in `default` values from the elicitation request schema before sending the response | + +These failures reveal SDK gaps and are intentionally not fixed in this module. + +[list-scenarios-command]: https://github.com/modelcontextprotocol/conformance/tree/main?tab=readme-ov-file#list-available-scenarios diff --git a/conformance-test/build.gradle.kts b/conformance-test/build.gradle.kts index 1877d0974..a96f4fb56 100644 --- a/conformance-test/build.gradle.kts +++ b/conformance-test/build.gradle.kts @@ -1,39 +1,34 @@ -import org.gradle.api.tasks.testing.logging.TestExceptionFormat - plugins { kotlin("jvm") + application } -dependencies { - testImplementation(project(":kotlin-sdk")) - testImplementation(project(":test-utils")) - testImplementation(kotlin("test")) - testImplementation(libs.kotlin.logging) - testImplementation(libs.ktor.client.cio) - testImplementation(libs.ktor.server.cio) - testImplementation(libs.ktor.server.websockets) - testRuntimeOnly(libs.slf4j.simple) +application { + mainClass.set("io.modelcontextprotocol.kotlin.sdk.conformance.ConformanceServerKt") } -tasks.test { - useJUnitPlatform() +tasks.register("conformanceClientScripts") { + mainClass.set("io.modelcontextprotocol.kotlin.sdk.conformance.ConformanceClientKt") + applicationName = "conformance-client" + outputDir = tasks.named("startScripts").get().outputDir + classpath = tasks.named("jar").get().outputs.files + configurations.named("runtimeClasspath").get() +} - testLogging { - events("passed", "skipped", "failed") - showStandardStreams = true - showExceptions = true - showCauses = true - showStackTraces = true - exceptionFormat = TestExceptionFormat.FULL - } +tasks.named("installDist") { + dependsOn("conformanceClientScripts") +} - doFirst { - systemProperty("test.classpath", classpath.asPath) +tasks.named("clean") { + delete("results") +} - println("\n" + "=".repeat(60)) - println("MCP CONFORMANCE TESTS") - println("=".repeat(60)) - println("These tests validate compliance with the MCP specification.") - println("=".repeat(60) + "\n") - } +dependencies { + implementation(project(":kotlin-sdk")) + implementation(libs.ktor.server.cio) + implementation(libs.ktor.server.content.negotiation) + implementation(libs.ktor.serialization) + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.auth) + implementation(libs.kotlin.logging) + runtimeOnly(libs.slf4j.simple) } diff --git a/conformance-test/conformance-baseline.yml b/conformance-test/conformance-baseline.yml new file mode 100644 index 000000000..cc06a389b --- /dev/null +++ b/conformance-test/conformance-baseline.yml @@ -0,0 +1,13 @@ +# Conformance test baseline - expected failures +# Add entries here as tests are identified as known SDK limitations +server: + - tools-call-with-logging + - tools-call-with-progress + - tools-call-sampling + - tools-call-elicitation + - elicitation-sep1034-defaults + - elicitation-sep1330-enums + - resources-templates-read + +client: + - elicitation-sep1034-client-defaults diff --git a/conformance-test/detekt-baseline.xml b/conformance-test/detekt-baseline.xml deleted file mode 100644 index fe120f523..000000000 --- a/conformance-test/detekt-baseline.xml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - BracesOnWhenStatements:ConformanceServer.kt:HttpServerTransport$when - BracesOnWhenStatements:ConformanceTest.kt:ConformanceTest$when - ForbiddenComment:ConformanceTest.kt:ConformanceTest.Companion$// TODO: Fix - - diff --git a/conformance-test/run-conformance.sh b/conformance-test/run-conformance.sh new file mode 100755 index 000000000..8f15ac0d1 --- /dev/null +++ b/conformance-test/run-conformance.sh @@ -0,0 +1,194 @@ +#!/bin/bash +# Script to run MCP conformance tests for the Kotlin SDK. +# +# Usage: ./conformance-test/run-conformance.sh [extra-args...] +# Commands: server | client | client-auth | all + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" || exit 1; pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." || exit 1; pwd)" + +CONFORMANCE_VERSION="0.1.15" +PORT="${MCP_PORT:-3001}" +SERVER_URL="http://localhost:${PORT}/mcp" +RESULTS_DIR="$SCRIPT_DIR/results" +SERVER_DIST="$SCRIPT_DIR/build/install/conformance-test/bin/conformance-test" +CLIENT_DIST="$SCRIPT_DIR/build/install/conformance-test/bin/conformance-client" + +SERVER_PID="" + +# shellcheck disable=SC2317 +cleanup() { + if [ -n "$SERVER_PID" ] && kill -0 "$SERVER_PID" 2>/dev/null; then + echo "Stopping server (PID: $SERVER_PID)..." + kill "$SERVER_PID" 2>/dev/null || true + wait "$SERVER_PID" 2>/dev/null || true + fi +} +trap cleanup EXIT + +build() { + echo "Building conformance-test distributions..." + cd "$PROJECT_ROOT" || return 1 + ./gradlew :conformance-test:installDist --quiet + cd "$SCRIPT_DIR" || return 1 + echo "Build complete." +} + +start_server() { + echo "Starting conformance server on port $PORT..." + MCP_PORT="$PORT" "$SERVER_DIST" & + SERVER_PID=$! + + echo "Waiting for server to be ready..." + local retries=0 + local max_retries=30 + while ! curl -sf "$SERVER_URL" > /dev/null 2>&1; do + retries=$((retries + 1)) + if [ "$retries" -ge "$max_retries" ]; then + echo "ERROR: Server failed to start after $max_retries attempts" + return 1 + fi + sleep 0.5 + done + echo "Server is ready (PID: $SERVER_PID)." +} + +stop_server() { + if [ -n "$SERVER_PID" ] && kill -0 "$SERVER_PID" 2>/dev/null; then + echo "Stopping server (PID: $SERVER_PID)..." + kill "$SERVER_PID" 2>/dev/null || true + wait "$SERVER_PID" 2>/dev/null || true + SERVER_PID="" + fi +} + +run_list_scenarios() { + local output_dir="$RESULTS_DIR/list" + mkdir -p "$output_dir" + echo "" + echo "==========================================" + echo " List Available Scenarios" + echo "==========================================" + local rc=0 + npx "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION" list \ + "$@" > "$output_dir/scenarios.txt" || rc=$? + + cat "$output_dir/scenarios.txt" + return $rc +} + +run_server_suite() { + local output_dir="$RESULTS_DIR/server" + mkdir -p "$output_dir" + echo "" + echo "==========================================" + echo " Running SERVER conformance tests" + echo "==========================================" + start_server || return 1 + local rc=0 + npx "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION" server \ + --url "$SERVER_URL" \ + --output-dir "$output_dir" \ + --expected-failures "$SCRIPT_DIR/conformance-baseline.yml" \ + "$@" || rc=$? + stop_server + return $rc +} + +run_client_suite() { + local output_dir="$RESULTS_DIR/client" + mkdir -p "$output_dir" + echo "" + echo "==========================================" + echo " Running CLIENT conformance tests" + echo "==========================================" + local scenarios=("initialize" "tools_call" "elicitation-sep1034-client-defaults" "sse-retry") + local rc=0 + for scenario in "${scenarios[@]}"; do + npx "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION" client \ + --command "$CLIENT_DIST" \ + --scenario "$scenario" \ + --output-dir "$output_dir" \ + --expected-failures "$SCRIPT_DIR/conformance-baseline.yml" \ + "$@" || rc=$? + done + return $rc +} + +run_client_auth_suite() { + local output_dir="$RESULTS_DIR/client-auth" + mkdir -p "$output_dir" + echo "" + echo "==========================================" + echo " Running CLIENT (auth) conformance tests" + echo "==========================================" + local rc=0 + npx "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION" client \ + --command "$CLIENT_DIST" \ + --suite auth \ + --output-dir "$output_dir" \ + --expected-failures "$SCRIPT_DIR/conformance-baseline.yml" \ + "$@" || rc=$? + + local extra_scenarios=("auth/client-credentials-jwt" "auth/client-credentials-basic" "auth/cross-app-access-complete-flow") + for scenario in "${extra_scenarios[@]}"; do + npx "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION" client \ + --command "$CLIENT_DIST" \ + --scenario "$scenario" \ + --output-dir "$output_dir" \ + --expected-failures "$SCRIPT_DIR/conformance-baseline.yml" \ + "$@" || rc=$? + done + return $rc +} + +# ============================================================================ +# Main +# ============================================================================ + +COMMAND="${1:-}" +shift 2>/dev/null || true + +if [ -z "$COMMAND" ]; then + echo "Usage: $0 [extra-args...]" + echo "Commands: list | server | client | client-auth | all" + exit 1 +fi + +build + +EXIT_CODE=0 + +case "$COMMAND" in + list) + run_list_scenarios "$@" || EXIT_CODE=1 + ;; + server) + run_server_suite "$@" || EXIT_CODE=1 + ;; + client) + run_client_suite "$@" || EXIT_CODE=1 + ;; + client-auth) + run_client_auth_suite "$@" || EXIT_CODE=1 + ;; + all) + run_server_suite "$@" || EXIT_CODE=1 + run_client_suite "$@" || EXIT_CODE=1 + run_client_auth_suite "$@" || EXIT_CODE=1 + ;; + *) + echo "Unknown command: $COMMAND" + echo "Commands: list | server | client | client-auth | all" + exit 1 + ;; +esac + +echo "" +echo "==========================================" +echo " Results saved to: $RESULTS_DIR" +echo "==========================================" + +exit $EXIT_CODE diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt new file mode 100644 index 000000000..dba24ff2b --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt @@ -0,0 +1,199 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.conformance.auth.registerAuthScenarios +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.ElicitResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlin.system.exitProcess +import kotlin.time.Duration.Companion.seconds + +private val logger = KotlinLogging.logger {} + +typealias ScenarioHandler = suspend (serverUrl: String) -> Unit + +internal val scenarioHandlers = mutableMapOf() + +// ============================================================================ +// Main entry point +// ============================================================================ + +fun main(args: Array) { + val scenarioName = System.getenv("MCP_CONFORMANCE_SCENARIO") + val serverUrl = args.lastOrNull() + + // Register all scenario handlers + registerCoreScenarios() + registerAuthScenarios() + + if (scenarioName == null || serverUrl == null) { + logger.error { "Usage: MCP_CONFORMANCE_SCENARIO= conformance-client " } + logger.error { "\nThe MCP_CONFORMANCE_SCENARIO env var is set automatically by the conformance runner." } + logger.error { "\nAvailable scenarios:" } + for (name in scenarioHandlers.keys.sorted()) { + logger.error { " - $name" } + } + exitProcess(1) + } + + val handler = scenarioHandlers[scenarioName] + if (handler == null) { + logger.error { "Unknown scenario: $scenarioName" } + logger.error { "\nAvailable scenarios:" } + for (name in scenarioHandlers.keys.sorted()) { + logger.error { " - $name" } + } + exitProcess(1) + } + + try { + runBlocking { + handler(serverUrl) + } + exitProcess(0) + } catch (e: Exception) { + logger.error(e) { "Error: ${e.message}" } + exitProcess(1) + } +} + +// ============================================================================ +// Shared HTTP client factory +// ============================================================================ + +private fun createHttpClient(): HttpClient = HttpClient(CIO) { + install(SSE) + install(HttpTimeout) { + requestTimeoutMillis = 30.seconds.inWholeMilliseconds + } +} + +// ============================================================================ +// Basic scenarios (initialize, tools_call) +// ============================================================================ + +private suspend fun runBasicClient(serverUrl: String) { + createHttpClient().use { httpClient -> + val transport = StreamableHttpClientTransport(httpClient, serverUrl) + val client = Client( + clientInfo = Implementation("test-client", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + client.connect(transport) + client.close() + } +} + +private suspend fun runToolsCallClient(serverUrl: String) { + createHttpClient().use { httpClient -> + val transport = StreamableHttpClientTransport(httpClient, serverUrl) + val client = Client( + clientInfo = Implementation("test-client", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + client.connect(transport) + + val tools = client.listTools() + val addTool = tools.tools.find { it.name == "add_numbers" } + if (addTool != null) { + client.callTool( + CallToolRequest( + CallToolRequestParams( + name = "add_numbers", + arguments = buildJsonObject { + put("a", 5) + put("b", 3) + }, + ), + ), + ) + } + + client.close() + } +} + +// ============================================================================ +// Elicitation defaults scenario +// ============================================================================ + +private suspend fun runElicitationDefaultsClient(serverUrl: String) { + createHttpClient().use { httpClient -> + val transport = StreamableHttpClientTransport(httpClient, serverUrl) + val client = Client( + clientInfo = Implementation("elicitation-defaults-test-client", "1.0.0"), + options = ClientOptions( + capabilities = ClientCapabilities( + elicitation = ClientCapabilities.elicitation, + ), + ), + ) + + // Register elicitation handler that returns empty content — SDK should fill in defaults + client.setElicitationHandler { _ -> + ElicitResult( + action = ElicitResult.Action.Accept, + content = JsonObject(emptyMap()), + ) + } + + client.connect(transport) + + val tools = client.listTools() + val testTool = tools.tools.find { it.name == "test_client_elicitation_defaults" } + ?: error("Test tool not found: test_client_elicitation_defaults") + + client.callTool( + CallToolRequest(CallToolRequestParams(name = testTool.name)), + ) + + client.close() + } +} + +// ============================================================================ +// SSE retry scenario +// ============================================================================ + +private suspend fun runSSERetryClient(serverUrl: String) { + createHttpClient().use { httpClient -> + val transport = StreamableHttpClientTransport(httpClient, serverUrl) + val client = Client( + clientInfo = Implementation("sse-retry-test-client", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + client.connect(transport) + + client.listTools() + + client.callTool( + CallToolRequest(CallToolRequestParams(name = "test_reconnection")), + ) + + client.close() + } +} + +// ============================================================================ +// Register core scenarios +// ============================================================================ + +private fun registerCoreScenarios() { + scenarioHandlers["initialize"] = ::runBasicClient + scenarioHandlers["tools_call"] = ::runToolsCallClient + scenarioHandlers["elicitation-sep1034-client-defaults"] = ::runElicitationDefaultsClient + scenarioHandlers["sse-retry"] = ::runSSERetryClient +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceCompletions.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceCompletions.kt new file mode 100644 index 000000000..629713949 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceCompletions.kt @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.types.CompleteRequest +import io.modelcontextprotocol.kotlin.sdk.types.CompleteResult +import io.modelcontextprotocol.kotlin.sdk.types.Method + +fun Server.registerConformanceCompletions() { + onConnect { + val session = sessions.values.lastOrNull() ?: return@onConnect + session.setRequestHandler(Method.Defined.CompletionComplete) { _, _ -> + CompleteResult(CompleteResult.Completion(values = emptyList(), total = 0, hasMore = false)) + } + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformancePrompts.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformancePrompts.kt new file mode 100644 index 000000000..6c591e4fd --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformancePrompts.kt @@ -0,0 +1,114 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.types.EmbeddedResource +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.ImageContent +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import kotlinx.coroutines.delay +import kotlin.time.Duration.Companion.milliseconds + +fun Server.registerConformancePrompts() { + // 1. Simple prompt + addPrompt( + name = "test_simple_prompt", + description = "test_simple_prompt", + ) { + GetPromptResult( + messages = listOf( + PromptMessage(Role.User, TextContent("This is a simple prompt for testing.")), + ), + ) + } + + // 2. Prompt with arguments + addPrompt( + name = "test_prompt_with_arguments", + description = "test_prompt_with_arguments", + arguments = listOf( + PromptArgument(name = "arg1", description = "First test argument", required = true), + PromptArgument(name = "arg2", description = "Second test argument", required = true), + ), + ) { request -> + val arg1 = request.arguments?.get("arg1") ?: "" + val arg2 = request.arguments?.get("arg2") ?: "" + GetPromptResult( + messages = listOf( + PromptMessage( + Role.User, + TextContent("Prompt with arguments: arg1='$arg1', arg2='$arg2'"), + ), + ), + ) + } + + // 3. Prompt with image + addPrompt( + name = "test_prompt_with_image", + description = "test_prompt_with_image", + ) { + GetPromptResult( + messages = listOf( + PromptMessage(Role.User, ImageContent(data = PNG_BASE64, mimeType = "image/png")), + PromptMessage(Role.User, TextContent("Please analyze the image above.")), + ), + ) + } + + // 4. Prompt with embedded resource + addPrompt( + name = "test_prompt_with_embedded_resource", + description = "test_prompt_with_embedded_resource", + arguments = listOf( + PromptArgument(name = "resourceUri", description = "URI of the resource to embed", required = true), + ), + ) { request -> + val resourceUri = request.arguments?.get("resourceUri") ?: "test://embedded-resource" + GetPromptResult( + messages = listOf( + PromptMessage( + Role.User, + EmbeddedResource( + resource = TextResourceContents( + text = "Embedded resource content for testing", + uri = resourceUri, + mimeType = "text/plain", + ), + ), + ), + PromptMessage(Role.User, TextContent("Please process the embedded resource above.")), + ), + ) + } + + // 5. Dynamic prompt + val server = this + addPrompt( + name = "test_dynamic_prompt", + description = "test_dynamic_prompt", + ) { + // Add a temporary prompt, triggering listChanged + server.addPrompt( + name = "test_dynamic_prompt_temp", + description = "Temporary dynamic prompt", + ) { + GetPromptResult( + messages = listOf( + PromptMessage(Role.User, TextContent("Temporary prompt response")), + ), + ) + } + delay(100.milliseconds) + // Remove the temporary prompt, triggering listChanged again + server.removePrompt("test_dynamic_prompt_temp") + GetPromptResult( + messages = listOf( + PromptMessage(Role.User, TextContent("Dynamic prompt executed successfully")), + ), + ) + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceResources.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceResources.kt new file mode 100644 index 000000000..a09903771 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceResources.kt @@ -0,0 +1,123 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import kotlinx.coroutines.delay +import kotlin.time.Duration.Companion.milliseconds + +fun Server.registerConformanceResources() { + // 1. Static text resource + addResource( + uri = "test://static-text", + name = "static-text", + description = "A static text resource for testing", + mimeType = "text/plain", + ) { + ReadResourceResult( + listOf( + TextResourceContents( + text = "This is the content of the static text resource.", + uri = "test://static-text", + mimeType = "text/plain", + ), + ), + ) + } + + // 2. Static binary resource + addResource( + uri = "test://static-binary", + name = "static-binary", + description = "A static binary resource for testing", + mimeType = "image/png", + ) { + ReadResourceResult( + listOf( + BlobResourceContents( + blob = PNG_BASE64, + uri = "test://static-binary", + mimeType = "image/png", + ), + ), + ) + } + + // 3. Template resource + // Note: The SDK does not currently support addResourceTemplate(). + // Register as a static resource; template listing is handled separately. + addResource( + uri = "test://template/{id}/data", + name = "template", + description = "A template resource for testing", + mimeType = "application/json", + ) { request -> + ReadResourceResult( + listOf( + TextResourceContents( + text = "content for ${request.uri}", + uri = request.uri, + mimeType = "application/json", + ), + ), + ) + } + + // 4. Watched resource + addResource( + uri = "test://watched-resource", + name = "watched-resource", + description = "A watched resource for testing", + mimeType = "text/plain", + ) { + ReadResourceResult( + listOf( + TextResourceContents( + text = "Watched resource content.", + uri = "test://watched-resource", + mimeType = "text/plain", + ), + ), + ) + } + + // 5. Dynamic resource + val server = this + addResource( + uri = "test://dynamic-resource", + name = "dynamic-resource", + description = "A dynamic resource for testing", + mimeType = "text/plain", + ) { + // Add a temporary resource, triggering listChanged + server.addResource( + uri = "test://dynamic-resource-temp", + name = "dynamic-resource-temp", + description = "Temporary dynamic resource", + mimeType = "text/plain", + ) { + ReadResourceResult( + listOf( + TextResourceContents( + text = "Temporary resource content.", + uri = "test://dynamic-resource-temp", + mimeType = "text/plain", + ), + ), + ) + } + delay(100.milliseconds) + // Remove the temporary resource, triggering listChanged again + server.removeResource("test://dynamic-resource-temp") + ReadResourceResult( + listOf( + TextResourceContents( + text = "Dynamic resource content.", + uri = "test://dynamic-resource", + mimeType = "text/plain", + ), + ), + ) + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt new file mode 100644 index 000000000..b42077529 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt @@ -0,0 +1,47 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.install +import io.ktor.server.cio.CIO +import io.ktor.server.engine.embeddedServer +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities + +fun main() { + val port = System.getenv("MCP_PORT")?.toIntOrNull() ?: 3001 + embeddedServer(CIO, port = port) { + install(ContentNegotiation) { + json(McpJson) + } + mcpStreamableHttp( + enableDnsRebindingProtection = true, + allowedHosts = listOf("localhost", "127.0.0.1", "localhost:$port", "127.0.0.1:$port"), + eventStore = InMemoryEventStore(), + ) { + createConformanceServer() + } + }.start(wait = true) +} + +fun createConformanceServer(): Server = Server( + serverInfo = Implementation("mcp-kotlin-sdk-conformance", "0.1.0"), + options = ServerOptions( + ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = true), + resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + prompts = ServerCapabilities.Prompts(listChanged = true), + logging = ServerCapabilities.Logging, + completions = ServerCapabilities.Completions, + ), + ), +) { + registerConformanceTools() + registerConformanceResources() + registerConformancePrompts() + registerConformanceCompletions() +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt new file mode 100644 index 000000000..ebe7a5716 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt @@ -0,0 +1,604 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.types.AudioContent +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.EmbeddedResource +import io.modelcontextprotocol.kotlin.sdk.types.ImageContent +import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams +import io.modelcontextprotocol.kotlin.sdk.types.ProgressNotification +import io.modelcontextprotocol.kotlin.sdk.types.ProgressNotificationParams +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.SamplingMessage +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.delay +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.double +import kotlinx.serialization.json.jsonPrimitive +import kotlin.time.Duration.Companion.milliseconds + +// Minimal 1x1 PNG (base64) +internal const val PNG_BASE64 = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + +// Minimal WAV (base64) +internal const val WAV_BASE64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQIAAAA=" + +@Suppress("LongMethod") +fun Server.registerConformanceTools() { + // 1. Simple text + addTool( + name = "test_simple_text", + description = "test_simple_text", + ) { + CallToolResult(listOf(TextContent("Simple text content"))) + } + + // 2. Image content + addTool( + name = "test_image_content", + description = "test_image_content", + ) { + CallToolResult(listOf(ImageContent(data = PNG_BASE64, mimeType = "image/png"))) + } + + // 3. Audio content + addTool( + name = "test_audio_content", + description = "test_audio_content", + ) { + CallToolResult(listOf(AudioContent(data = WAV_BASE64, mimeType = "audio/wav"))) + } + + // 4. Embedded resource + addTool( + name = "test_embedded_resource", + description = "test_embedded_resource", + ) { + CallToolResult( + listOf( + EmbeddedResource( + resource = TextResourceContents( + text = "This is an embedded resource content.", + uri = "test://embedded-resource", + mimeType = "text/plain", + ), + ), + ), + ) + } + + // 5. Multiple content types + addTool( + name = "test_multiple_content_types", + description = "test_multiple_content_types", + ) { + CallToolResult( + listOf( + TextContent("Simple text content"), + ImageContent(data = PNG_BASE64, mimeType = "image/png"), + EmbeddedResource( + resource = TextResourceContents( + text = "This is an embedded resource content.", + uri = "test://embedded-resource", + mimeType = "text/plain", + ), + ), + ), + ) + } + + // 6. Progress tool + addTool( + name = "test_tool_with_progress", + description = "test_tool_with_progress", + ) { request -> + val progressToken = request.meta?.progressToken + if (progressToken != null) { + notification( + ProgressNotification( + ProgressNotificationParams( + progressToken, + 0.0, + 100.0, + "Completed step 0 of 100", + ), + ), + ) + delay(50.milliseconds) + notification( + ProgressNotification( + ProgressNotificationParams( + progressToken, + 50.0, + 100.0, + "Completed step 50 of 100", + ), + ), + ) + delay(50.milliseconds) + notification( + ProgressNotification( + ProgressNotificationParams( + progressToken, + 100.0, + 100.0, + "Completed step 100 of 100", + ), + ), + ) + } + CallToolResult(listOf(TextContent("Simple text content"))) + } + + // 7. Error handling + addTool( + name = "test_error_handling", + description = "test_error_handling", + ) { + throw Exception("This tool intentionally returns an error for testing") + } + + // 8. Sampling + addTool( + name = "test_sampling", + description = "test_sampling", + inputSchema = ToolSchema( + properties = buildJsonObject { + put("prompt", buildJsonObject { put("type", JsonPrimitive("string")) }) + }, + required = listOf("prompt"), + ), + ) { request -> + val prompt = request.arguments?.get("prompt")?.jsonPrimitive?.content ?: "Hello" + val result = createMessage( + CreateMessageRequest( + CreateMessageRequestParams( + maxTokens = 10000, + messages = listOf(SamplingMessage(Role.User, TextContent(prompt))), + ), + ), + ) + CallToolResult(listOf(TextContent(result.content.toString()))) + } + + // 9. Elicitation + addTool( + name = "test_elicitation", + description = "test_elicitation", + inputSchema = ToolSchema( + properties = buildJsonObject { + put("message", buildJsonObject { put("type", JsonPrimitive("string")) }) + }, + required = listOf("message"), + ), + ) { request -> + val message = request.arguments?.get("message")?.jsonPrimitive?.content ?: "Please provide input" + val schema = ElicitRequestParams.RequestedSchema( + properties = buildJsonObject { + put( + "username", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("User's response")) + }, + ) + put( + "email", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("User's email address")) + }, + ) + }, + required = listOf("username", "email"), + ) + val result = createElicitation(message, schema) + CallToolResult(listOf(TextContent("User response: "))) + } + + // 10. Elicitation SEP1034 (defaults) + addTool( + name = "test_elicitation_sep1034_defaults", + description = "test_elicitation_sep1034_defaults", + ) { + val schema = ElicitRequestParams.RequestedSchema( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("User name")) + put("default", JsonPrimitive("John Doe")) + }, + ) + put( + "age", + buildJsonObject { + put("type", JsonPrimitive("integer")) + put("description", JsonPrimitive("User age")) + put("default", JsonPrimitive(30)) + }, + ) + put( + "score", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("User score")) + put("default", JsonPrimitive(95.5)) + }, + ) + put( + "status", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("User status")) + put("default", JsonPrimitive("active")) + put( + "enum", + JsonArray( + listOf(JsonPrimitive("active"), JsonPrimitive("inactive"), JsonPrimitive("pending")), + ), + ) + }, + ) + put( + "verified", + buildJsonObject { + put("type", JsonPrimitive("boolean")) + put("description", JsonPrimitive("Verification status")) + put("default", JsonPrimitive(true)) + }, + ) + }, + required = listOf("name", "age", "score", "status", "verified"), + ) + val result = createElicitation( + "Please review and update the form fields with defaults", + schema, + ) + CallToolResult(listOf(TextContent(result.content.toString()))) + } + + // 11. Elicitation SEP1330 enums + addTool( + name = "test_elicitation_sep1330_enums", + description = "test_elicitation_sep1330_enums", + ) { + val schema = ElicitRequestParams.RequestedSchema( + properties = buildJsonObject { + // Untitled single-select + put( + "untitledSingle", + buildJsonObject { + put("type", JsonPrimitive("string")) + put( + "enum", + JsonArray( + listOf(JsonPrimitive("option1"), JsonPrimitive("option2"), JsonPrimitive("option3")), + ), + ) + }, + ) + // Titled single-select + put( + "titledSingle", + buildJsonObject { + put("type", JsonPrimitive("string")) + put( + "oneOf", + JsonArray( + listOf( + buildJsonObject { + put("const", JsonPrimitive("value1")) + put("title", JsonPrimitive("First Option")) + }, + buildJsonObject { + put("const", JsonPrimitive("value2")) + put("title", JsonPrimitive("Second Option")) + }, + buildJsonObject { + put("const", JsonPrimitive("value3")) + put("title", JsonPrimitive("Third Option")) + }, + ), + ), + ) + }, + ) + // Legacy titled (deprecated) + put( + "legacyEnum", + buildJsonObject { + put("type", JsonPrimitive("string")) + put( + "oneOf", + JsonArray( + listOf( + buildJsonObject { + put("const", JsonPrimitive("opt1")) + put("title", JsonPrimitive("Option One")) + }, + buildJsonObject { + put("const", JsonPrimitive("opt2")) + put("title", JsonPrimitive("Option Two")) + }, + buildJsonObject { + put("const", JsonPrimitive("opt3")) + put("title", JsonPrimitive("Option Three")) + }, + ), + ), + ) + }, + ) + // Untitled multi-select + put( + "untitledMulti", + buildJsonObject { + put("type", JsonPrimitive("array")) + put( + "items", + buildJsonObject { + put("type", JsonPrimitive("string")) + put( + "enum", + JsonArray( + listOf( + JsonPrimitive("option1"), + JsonPrimitive("option2"), + JsonPrimitive("option3"), + ), + ), + ) + }, + ) + }, + ) + // Titled multi-select + put( + "titledMulti", + buildJsonObject { + put("type", JsonPrimitive("array")) + put( + "items", + buildJsonObject { + put("type", JsonPrimitive("string")) + put( + "oneOf", + JsonArray( + listOf( + buildJsonObject { + put("const", JsonPrimitive("value1")) + put("title", JsonPrimitive("First Choice")) + }, + buildJsonObject { + put("const", JsonPrimitive("value2")) + put("title", JsonPrimitive("Second Choice")) + }, + buildJsonObject { + put("const", JsonPrimitive("value3")) + put("title", JsonPrimitive("Third Choice")) + }, + ), + ), + ) + }, + ) + }, + ) + }, + ) + val result = createElicitation( + "Please review and update the form fields with defaults", + schema, + ) + CallToolResult(listOf(TextContent(result.content.toString()))) + } + + // 12. Dynamic tool + val server = this + addTool( + name = "test_dynamic_tool", + description = "test_dynamic_tool", + ) { + // Add a temporary tool, triggering listChanged + server.addTool( + name = "test_dynamic_tool_temp", + description = "Temporary dynamic tool", + ) { + CallToolResult(listOf(TextContent("Temporary tool response"))) + } + delay(100.milliseconds) + // Remove the temporary tool, triggering listChanged again + server.removeTool("test_dynamic_tool_temp") + CallToolResult(listOf(TextContent("Dynamic tool executed successfully"))) + } + + // 13. Logging tool + addTool( + name = "test_tool_with_logging", + description = "test_tool_with_logging", + ) { + sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotificationParams( + level = LoggingLevel.Info, + data = JsonPrimitive("Tool execution started"), + logger = "conformance", + ), + ), + ) + delay(50.milliseconds) + sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotificationParams( + level = LoggingLevel.Info, + data = JsonPrimitive("Tool processing data"), + logger = "conformance", + ), + ), + ) + delay(50.milliseconds) + sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotificationParams( + level = LoggingLevel.Info, + data = JsonPrimitive("Tool execution completed"), + logger = "conformance", + ), + ), + ) + CallToolResult(listOf(TextContent("Simple text content"))) + } + + // 14. add_numbers — used by tools_call client scenario + addTool( + name = "add_numbers", + description = "Adds two numbers together", + inputSchema = ToolSchema( + properties = buildJsonObject { + put("a", buildJsonObject { put("type", JsonPrimitive("number")) }) + put("b", buildJsonObject { put("type", JsonPrimitive("number")) }) + }, + required = listOf("a", "b"), + ), + ) { request -> + val a = request.arguments?.get("a")?.jsonPrimitive?.double ?: 0.0 + val b = request.arguments?.get("b")?.jsonPrimitive?.double ?: 0.0 + val sum = a + b + CallToolResult(listOf(TextContent("The sum of $a and $b is $sum"))) + } + + // 15. test_reconnection — SEP-1699, closes SSE stream to test client reconnection + addTool( + name = "test_reconnection", + description = "Tests SSE stream disconnection and client reconnection (SEP-1699)", + ) { + // SDK limitation: cannot access the JSONRPC request ID from the tool handler + // to close the SSE stream. Return success text; this test may fail at the + // conformance runner level because the stream isn't actually closed. + delay(100.milliseconds) + CallToolResult( + listOf( + TextContent( + "Reconnection test completed successfully. " + + "If you received this, the client properly reconnected after stream closure.", + ), + ), + ) + } + + // 16. json_schema_2020_12_tool — SEP-1613 + addTool( + name = "json_schema_2020_12_tool", + description = "Tool with JSON Schema 2020-12 features for conformance testing (SEP-1613)", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + }, + ) + put( + "address", + buildJsonObject { + put("type", JsonPrimitive("object")) + put( + "properties", + buildJsonObject { + put("street", buildJsonObject { put("type", JsonPrimitive("string")) }) + put("city", buildJsonObject { put("type", JsonPrimitive("string")) }) + }, + ) + }, + ) + }, + ), + ) { request -> + CallToolResult( + listOf(TextContent("JSON Schema 2020-12 tool called with: ${request.arguments}")), + ) + } + + // 17. test_client_elicitation_defaults — used by elicitation-sep1034-client-defaults scenario + addTool( + name = "test_client_elicitation_defaults", + description = "test_client_elicitation_defaults", + ) { + val schema = ElicitRequestParams.RequestedSchema( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("User name")) + put("default", JsonPrimitive("John Doe")) + }, + ) + put( + "age", + buildJsonObject { + put("type", JsonPrimitive("integer")) + put("description", JsonPrimitive("User age")) + put("default", JsonPrimitive(30)) + }, + ) + put( + "score", + buildJsonObject { + put("type", JsonPrimitive("number")) + put("description", JsonPrimitive("User score")) + put("default", JsonPrimitive(95.5)) + }, + ) + put( + "status", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("User status")) + put("default", JsonPrimitive("active")) + put( + "enum", + JsonArray( + listOf(JsonPrimitive("active"), JsonPrimitive("inactive"), JsonPrimitive("pending")), + ), + ) + }, + ) + put( + "verified", + buildJsonObject { + put("type", JsonPrimitive("boolean")) + put("description", JsonPrimitive("Verification status")) + put("default", JsonPrimitive(true)) + }, + ) + }, + required = emptyList(), + ) + val result = createElicitation( + "Please review and update the form fields with defaults", + schema, + ) + CallToolResult(listOf(TextContent("Elicitation completed: action=${result.action}, content=${result.content}"))) + } + + // 18. test-tool — simple tool used by auth scenarios + addTool( + name = "test-tool", + description = "Simple test tool for auth scenarios", + ) { + CallToolResult(listOf(TextContent("Test tool executed successfully"))) + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/InMemoryEventStore.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/InMemoryEventStore.kt new file mode 100644 index 000000000..ca7abb3e8 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/InMemoryEventStore.kt @@ -0,0 +1,46 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance + +import io.modelcontextprotocol.kotlin.sdk.server.EventStore +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import java.util.concurrent.ConcurrentHashMap +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalUuidApi::class) +class InMemoryEventStore : EventStore { + + private val events = ConcurrentHashMap>() + private val streamEvents = ConcurrentHashMap>() + + override suspend fun storeEvent(streamId: String, message: JSONRPCMessage): String { + val eventId = "$streamId::${System.currentTimeMillis()}_${Uuid.random()}" + events[eventId] = message to streamId + streamEvents.getOrPut(streamId) { mutableListOf() }.add(eventId) + return eventId + } + + override suspend fun replayEventsAfter( + lastEventId: String, + sender: suspend (eventId: String, message: JSONRPCMessage) -> Unit, + ): String { + val streamId = getStreamIdForEventId(lastEventId) + ?: error("Unknown event ID: $lastEventId") + val eventIds = streamEvents[streamId] ?: return streamId + + var found = false + for (eventId in eventIds) { + if (!found) { + if (eventId == lastEventId) found = true + continue + } + val (message, _) = events[eventId] ?: continue + sender(eventId, message) + } + return streamId + } + + override suspend fun getStreamIdForEventId(eventId: String): String? { + val idx = eventId.indexOf("::") + return if (idx >= 0) eventId.substring(0, idx) else null + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/JWTScenario.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/JWTScenario.kt new file mode 100644 index 000000000..9a20681c3 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/JWTScenario.kt @@ -0,0 +1,189 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.forms.submitForm +import io.ktor.http.Parameters +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.put +import java.security.KeyFactory +import java.security.Signature +import java.security.spec.PKCS8EncodedKeySpec +import java.util.Base64 +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +// Client Credentials JWT scenario +internal suspend fun runClientCredentialsJwt(serverUrl: String) { + val ctx = conformanceContext() + val clientId = ctx.requiredString("client_id") + val privateKeyPem = ctx.requiredString("private_key_pem") + val signingAlgorithm = ctx["signing_algorithm"]?.jsonPrimitive?.content ?: "ES256" + + val httpClient = HttpClient(CIO) { + install(SSE) + followRedirects = false + } + + httpClient.use { client -> + val resourceMetadata = discoverResourceMetadata(client, serverUrl) + val authServer = resourceMetadata["authorization_servers"]?.jsonArray?.firstOrNull()?.jsonPrimitive?.content + ?: error("No authorization_servers in resource metadata") + val oauthMetadata = fetchOAuthMetadata(client, authServer) + val tokenEndpoint = oauthMetadata["token_endpoint"]?.jsonPrimitive?.content + ?: error("No token_endpoint in AS metadata") + val issuer = oauthMetadata["issuer"]?.jsonPrimitive?.content + ?: error("No issuer in AS metadata") + + // Create JWT client assertion + val assertion = createJwtAssertion(clientId, issuer, privateKeyPem, signingAlgorithm) + + // Exchange for token + val tokenResponse = client.submitForm( + url = tokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "client_credentials") + append("client_id", clientId) + append("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + append("client_assertion", assertion) + }, + ) + val accessToken = extractAccessToken(tokenResponse) + + withBearerToken(accessToken) { authedClient -> + val transport = StreamableHttpClientTransport(authedClient, serverUrl) + val client = Client( + clientInfo = Implementation("conformance-client-credentials-jwt", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + client.connect(transport) + client.listTools() + client.close() + } + } +} + +// JWT Assertion +@OptIn(ExperimentalUuidApi::class) +private fun createJwtAssertion( + clientId: String, + audience: String, + privateKeyPem: String, + algorithm: String, +): String { + val header = buildJsonObject { + put("alg", algorithm) + put("typ", "JWT") + }.toString() + + val now = System.currentTimeMillis() / 1000 + val payload = buildJsonObject { + put("iss", clientId) + put("sub", clientId) + put("aud", audience) + put("iat", now) + put("exp", now + 300) + put("jti", Uuid.random().toString()) + }.toString() + + val headerB64 = Base64.getUrlEncoder().withoutPadding().encodeToString(header.toByteArray()) + val payloadB64 = Base64.getUrlEncoder().withoutPadding().encodeToString(payload.toByteArray()) + val signingInput = "$headerB64.$payloadB64" + + val signature = signJwt(signingInput, privateKeyPem, algorithm) + return "$signingInput.$signature" +} + +private fun signJwt(input: String, privateKeyPem: String, algorithm: String): String { + val pemBody = privateKeyPem + .replace("-----BEGIN PRIVATE KEY-----", "") + .replace("-----END PRIVATE KEY-----", "") + .replace("-----BEGIN EC PRIVATE KEY-----", "") + .replace("-----END EC PRIVATE KEY-----", "") + .replace("-----BEGIN RSA PRIVATE KEY-----", "") + .replace("-----END RSA PRIVATE KEY-----", "") + .replace("\n", "") + .replace("\r", "") + .trim() + + val keyBytes = Base64.getDecoder().decode(pemBody) + val keySpec = PKCS8EncodedKeySpec(keyBytes) + + val (keyAlgorithm, signatureAlgorithm) = when (algorithm) { + "ES256" -> "EC" to "SHA256withECDSA" + "RS256" -> "RSA" to "SHA256withRSA" + else -> error("Unsupported signing algorithm: $algorithm") + } + + val keyFactory = KeyFactory.getInstance(keyAlgorithm) + val privateKey = keyFactory.generatePrivate(keySpec) + + val sig = Signature.getInstance(signatureAlgorithm) + sig.initSign(privateKey) + sig.update(input.toByteArray()) + val rawSignature = sig.sign() + + // For EC, convert DER to raw r||s format for JWS + val signatureBytes = if (keyAlgorithm == "EC") { + derToRawEcSignature(rawSignature) + } else { + rawSignature + } + + return Base64.getUrlEncoder().withoutPadding().encodeToString(signatureBytes) +} + +private fun derToRawEcSignature(der: ByteArray): ByteArray { + // DER format: 0x30 len 0x02 rLen r 0x02 sLen s + require(der.size >= 2) { "DER signature too short" } + + var offset = 2 // skip SEQUENCE tag and length + if (der[1].toInt() and 0x80 != 0) { + offset += (der[1].toInt() and 0x7f) + } + + // Read r + require(offset < der.size) { "DER signature truncated before r tag" } + check(der[offset] == 0x02.toByte()) { "Expected INTEGER tag for r" } + offset++ + require(offset < der.size) { "DER signature truncated before r length" } + val rLen = der[offset].toInt() and 0xff + offset++ + require(offset + rLen <= der.size) { "DER signature truncated in r value" } + val r = der.copyOfRange(offset, offset + rLen) + offset += rLen + + // Read s + require(offset < der.size) { "DER signature truncated before s tag" } + check(der[offset] == 0x02.toByte()) { "Expected INTEGER tag for s" } + offset++ + require(offset < der.size) { "DER signature truncated before s length" } + val sLen = der[offset].toInt() and 0xff + offset++ + require(offset + sLen <= der.size) { "DER signature truncated in s value" } + val s = der.copyOfRange(offset, offset + sLen) + + // Each component should be 32 bytes for P-256 + val componentLen = 32 + val result = ByteArray(componentLen * 2) + + // Copy r (may need padding or trimming of leading zero) + val rStart = if (r.size > componentLen) r.size - componentLen else 0 + val rDest = if (r.size < componentLen) componentLen - r.size else 0 + r.copyInto(result, rDest, rStart, r.size) + + // Copy s + val sStart = if (s.size > componentLen) s.size - componentLen else 0 + val sDest = componentLen + if (s.size < componentLen) componentLen - s.size else 0 + s.copyInto(result, sDest, sStart, s.size) + + return result +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/authCodeFlow.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/authCodeFlow.kt new file mode 100644 index 000000000..c59e8e003 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/authCodeFlow.kt @@ -0,0 +1,154 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpSend +import io.ktor.client.plugins.plugin +import io.ktor.client.plugins.sse.SSE +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonPrimitive +import java.util.UUID + +internal suspend fun runAuthClient(serverUrl: String) { + val httpClient = HttpClient(CIO) { + install(SSE) + followRedirects = false + } + + var accessToken: String? = null + var authAttempts = 0 + // Cache discovery and credentials across retries + var cachedDiscovery: DiscoveryResult? = null + var cachedCredentials: ClientCredentials? = null + + httpClient.plugin(HttpSend).intercept { request -> + // Add existing token if available + if (accessToken != null) { + request.headers.remove(HttpHeaders.Authorization) + request.headers.append(HttpHeaders.Authorization, "Bearer $accessToken") + } + + val response = execute(request) + val status = response.response.status + + // Determine if we need to (re-)authorize + val needsAuth = status == HttpStatusCode.Unauthorized + val wwwAuth = response.response.headers[HttpHeaders.WWWAuthenticate] ?: "" + val stepUpScope = if (status == HttpStatusCode.Forbidden) parseStepUpScope(wwwAuth) else null + val needsStepUp = stepUpScope != null + + if ((needsAuth || needsStepUp) && authAttempts < 3) { + authAttempts++ + + // Discover metadata (cache across retries) + if (cachedDiscovery == null) { + val resourceMetadataUrl = extractParam(wwwAuth, "resource_metadata") + cachedDiscovery = discoverOAuthMetadata(httpClient, serverUrl, resourceMetadataUrl) + } + val discovery: DiscoveryResult = cachedDiscovery + + // Validate PRM resource matches server URL (RFC 8707) + val discoveredResource = discovery.resourceUrl + if (discoveredResource != null) { + val normalizedResource = discoveredResource.trimEnd('/') + val normalizedServerUrl = serverUrl.trimEnd('/') + val matches = normalizedServerUrl == normalizedResource || + normalizedServerUrl.startsWith("$normalizedResource/") + require(matches) { + "PRM resource mismatch: resource='$discoveredResource' does not match server URL='$serverUrl'" + } + } + + val metadata = discovery.asMetadata + + val authEndpoint = metadata["authorization_endpoint"]?.jsonPrimitive?.content + ?: error("No authorization_endpoint in metadata") + val tokenEndpoint = metadata["token_endpoint"]?.jsonPrimitive?.content + ?: error("No token_endpoint in metadata") + + val tokenEndpointAuthMethods = metadata["token_endpoint_auth_methods_supported"] + ?.jsonArray?.map { it.jsonPrimitive.content } + ?: listOf("client_secret_post") + val tokenAuthMethod = tokenEndpointAuthMethods.firstOrNull() ?: "client_secret_post" + + // Verify PKCE support + verifyPkceSupport(metadata) + + // Resolve client credentials (cache across retries) + if (cachedCredentials == null) { + cachedCredentials = resolveClientCredentials(httpClient, metadata) + } + val creds: ClientCredentials = cachedCredentials + + // Determine scope + val scope = if (needsStepUp) { + stepUpScope + } else { + val wwwAuthScope = extractParam(wwwAuth, "scope") + selectScope(wwwAuthScope, discovery.scopesSupported) + } + + // PKCE + val codeVerifier = generateCodeVerifier() + val codeChallenge = generateCodeChallenge(codeVerifier) + + // CSRF state parameter + val state = UUID.randomUUID().toString() + + // Build authorization URL + val authUrl = buildAuthorizationUrl( + authEndpoint, + creds.clientId, + CALLBACK_URL, + codeChallenge, + scope, + discovery.resourceUrl, + state, + ) + + // Follow the authorization redirect to get auth code + val authCode = followAuthorizationRedirect(httpClient, authUrl, CALLBACK_URL, state) + + // Exchange code for tokens + accessToken = exchangeCodeForTokens( + httpClient, + tokenEndpoint, + authCode, + creds.clientId, + creds.clientSecret, + CALLBACK_URL, + codeVerifier, + tokenAuthMethod, + discovery.resourceUrl, + ) + + // Retry the original request with the token + request.headers.remove(HttpHeaders.Authorization) + request.headers.append(HttpHeaders.Authorization, "Bearer $accessToken") + execute(request) + } else { + response + } + } + + httpClient.use { client -> + val transport = StreamableHttpClientTransport(client, serverUrl) + val mcpClient = Client( + clientInfo = Implementation("test-auth-client", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + mcpClient.connect(transport) + mcpClient.listTools() + mcpClient.callTool(CallToolRequest(CallToolRequestParams(name = "test-tool"))) + mcpClient.close() + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/basicScenario.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/basicScenario.kt new file mode 100644 index 000000000..a0d6d864d --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/basicScenario.kt @@ -0,0 +1,54 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.forms.submitForm +import io.ktor.client.request.header +import io.ktor.http.HttpHeaders +import io.ktor.http.Parameters +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import java.util.Base64 + +// Client Credentials Basic scenario +internal suspend fun runClientCredentialsBasic(serverUrl: String) { + val ctx = conformanceContext() + val clientId = ctx.requiredString("client_id") + val clientSecret = ctx.requiredString("client_secret") + + val httpClient = HttpClient(CIO) { + install(SSE) + followRedirects = false + } + + httpClient.use { client -> + val tokenEndpoint = discoverTokenEndpoint(client, serverUrl) + + // Exchange credentials for token using Basic auth + val basicAuth = Base64.getEncoder().encodeToString("$clientId:$clientSecret".toByteArray()) + val tokenResponse = client.submitForm( + url = tokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "client_credentials") + }, + ) { + header(HttpHeaders.Authorization, "Basic $basicAuth") + } + val accessToken = extractAccessToken(tokenResponse) + + withBearerToken(accessToken) { authedClient -> + val transport = StreamableHttpClientTransport(authedClient, serverUrl) + val client = Client( + clientInfo = Implementation("conformance-client-credentials-basic", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + client.connect(transport) + client.listTools() + client.close() + } + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/clientRegistration.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/clientRegistration.kt new file mode 100644 index 000000000..8e84a1587 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/clientRegistration.kt @@ -0,0 +1,73 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.contentType +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.put + +internal data class ClientCredentials(val clientId: String, val clientSecret: String?) + +/** + * Resolve client credentials per spec priority: + * 1. Pre-registered (from MCP_CONFORMANCE_CONTEXT) + * 2. CIMD (client_id_metadata_document_supported) + * 3. Dynamic registration (registration_endpoint) + * 4. Error + */ +internal suspend fun resolveClientCredentials(httpClient: HttpClient, asMetadata: JsonObject): ClientCredentials { + // 1. Pre-registered + val contextJson = System.getenv("MCP_CONFORMANCE_CONTEXT") + if (contextJson != null) { + val ctx = json.parseToJsonElement(contextJson).jsonObject + val clientId = ctx["client_id"]?.jsonPrimitive?.content + if (clientId != null) { + val clientSecret = ctx["client_secret"]?.jsonPrimitive?.content + return ClientCredentials(clientId, clientSecret) + } + } + + // 2. CIMD + val cimdSupported = asMetadata["client_id_metadata_document_supported"] + ?.jsonPrimitive?.content?.toBoolean() ?: false + if (cimdSupported) { + return ClientCredentials(CIMD_CLIENT_METADATA_URL, null) + } + + // 3. Dynamic registration + val registrationEndpoint = asMetadata["registration_endpoint"]?.jsonPrimitive?.content + if (registrationEndpoint != null) { + return dynamicClientRegistration(httpClient, registrationEndpoint) + } + + error("No way to register client: no pre-registered credentials, CIMD not supported, and no registration_endpoint") +} + +private suspend fun dynamicClientRegistration( + httpClient: HttpClient, + registrationEndpoint: String, +): ClientCredentials { + val regBody = buildJsonObject { + put("client_name", "test-auth-client") + put("redirect_uris", buildJsonArray { add(kotlinx.serialization.json.JsonPrimitive(CALLBACK_URL)) }) + put("grant_types", buildJsonArray { add(kotlinx.serialization.json.JsonPrimitive("authorization_code")) }) + put("response_types", buildJsonArray { add(kotlinx.serialization.json.JsonPrimitive("code")) }) + put("token_endpoint_auth_method", "client_secret_post") + } + + val response = httpClient.post(registrationEndpoint) { + contentType(ContentType.Application.Json) + setBody(regBody.toString()) + } + val regJson = json.parseToJsonElement(response.bodyAsText()).jsonObject + val clientId = regJson["client_id"]?.jsonPrimitive?.content ?: error("No client_id in registration response") + val clientSecret = regJson["client_secret"]?.jsonPrimitive?.content + return ClientCredentials(clientId, clientSecret) +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/crossAppAccessScenario.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/crossAppAccessScenario.kt new file mode 100644 index 000000000..db4b7f0f6 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/crossAppAccessScenario.kt @@ -0,0 +1,82 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.forms.submitForm +import io.ktor.client.request.header +import io.ktor.http.HttpHeaders +import io.ktor.http.Parameters +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonPrimitive +import java.util.Base64 + +// SEP-990 Enterprise Managed OAuth: Cross-App Access complete flow +internal suspend fun runCrossAppAccess(serverUrl: String) { + val ctx = conformanceContext() + val clientId = ctx.requiredString("client_id") + val clientSecret = ctx.requiredString("client_secret") + val idpIdToken = ctx.requiredString("idp_id_token") + val idpTokenEndpoint = ctx.requiredString("idp_token_endpoint") + + val httpClient = HttpClient(CIO) { + install(SSE) + followRedirects = false + } + + httpClient.use { client -> + // Discover PRM + AS metadata + val resourceMeta = discoverResourceMetadata(client, serverUrl) + val resourceUrl = resourceMeta["resource"]?.jsonPrimitive?.content + ?: error("No resource in resource metadata") + val authServer = resourceMeta["authorization_servers"]?.jsonArray?.firstOrNull()?.jsonPrimitive?.content + ?: error("No authorization_servers in resource metadata") + val asMeta = fetchOAuthMetadata(client, authServer) + val tokenEndpoint = asMeta["token_endpoint"]?.jsonPrimitive?.content + ?: error("No token_endpoint in AS metadata") + + // RFC 8693 Token Exchange at IDP: exchange ID token for ID-JAG + val idpResponse = client.submitForm( + url = idpTokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + append("subject_token", idpIdToken) + append("subject_token_type", "urn:ietf:params:oauth:token-type:id_token") + append("requested_token_type", "urn:ietf:params:oauth:token-type:id-jag") + append("audience", authServer) + append("resource", resourceUrl) + }, + ) + val idJag = extractAccessToken(idpResponse) + + // RFC 7523 JWT Bearer Grant at AS with Basic auth + val basicAuth = Base64.getEncoder().encodeToString("$clientId:$clientSecret".toByteArray()) + val asResponse = client.submitForm( + url = tokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + append("assertion", idJag) + }, + ) { + header(HttpHeaders.Authorization, "Basic $basicAuth") + } + val accessToken = extractAccessToken(asResponse) + + // Use access token for MCP requests + withBearerToken(accessToken) { authedClient -> + val transport = StreamableHttpClientTransport(authedClient, serverUrl) + val mcpClient = Client( + clientInfo = Implementation("conformance-cross-app-access", "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + mcpClient.connect(transport) + mcpClient.listTools() + mcpClient.close() + } + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/discovery.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/discovery.kt new file mode 100644 index 000000000..73979db5f --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/discovery.kt @@ -0,0 +1,48 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.request.get +import io.ktor.client.statement.bodyAsText +import io.ktor.http.isSuccess +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +internal data class DiscoveryResult( + val asMetadata: JsonObject, + val resourceUrl: String?, + val scopesSupported: List?, +) + +internal suspend fun discoverOAuthMetadata( + httpClient: HttpClient, + serverUrl: String, + resourceMetadataUrl: String?, +): DiscoveryResult { + // Get resource metadata + val resourceMeta = if (resourceMetadataUrl != null) { + val resp = httpClient.get(resourceMetadataUrl) + if (!resp.status.isSuccess()) { + error("Failed to fetch resource metadata from $resourceMetadataUrl: ${resp.status}") + } + json.parseToJsonElement(resp.bodyAsText()).jsonObject + } else { + discoverResourceMetadata(httpClient, serverUrl) + } + + val resourceUrl = resourceMeta["resource"]?.jsonPrimitive?.content + val scopesSupported = resourceMeta["scopes_supported"] + ?.jsonArray?.map { it.jsonPrimitive.content } + val authServer = resourceMeta["authorization_servers"]?.jsonArray?.firstOrNull()?.jsonPrimitive?.content + + val oauthMeta = if (authServer != null) { + fetchOAuthMetadata(httpClient, authServer) + } else { + // Fallback: try well-known on server URL origin + val origin = extractOrigin(serverUrl) + fetchOAuthMetadata(httpClient, origin) + } + + return DiscoveryResult(oauthMeta, resourceUrl, scopesSupported) +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/pkce.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/pkce.kt new file mode 100644 index 000000000..e1c5419ed --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/pkce.kt @@ -0,0 +1,30 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonPrimitive +import java.security.MessageDigest +import java.util.Base64 + +internal fun generateCodeVerifier(): String { + val bytes = ByteArray(32) + java.security.SecureRandom().nextBytes(bytes) + return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes) +} + +internal fun generateCodeChallenge(verifier: String): String { + val digest = MessageDigest.getInstance("SHA-256").digest(verifier.toByteArray(Charsets.US_ASCII)) + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest) +} + +/** + * Verify that the AS metadata advertises S256 in code_challenge_methods_supported. + * Abort if PKCE S256 is not supported. + */ +internal fun verifyPkceSupport(asMetadata: JsonObject) { + val methods = asMetadata["code_challenge_methods_supported"] + ?.jsonArray?.map { it.jsonPrimitive.content } + require(methods != null && "S256" in methods) { + "Authorization server does not support PKCE S256 (code_challenge_methods_supported: $methods)" + } +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/registration.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/registration.kt new file mode 100644 index 000000000..4964f20d2 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/registration.kt @@ -0,0 +1,32 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.modelcontextprotocol.kotlin.sdk.conformance.scenarioHandlers + +// Registration +fun registerAuthScenarios() { + val authScenarios = listOf( + "auth/metadata-default", + "auth/metadata-var1", + "auth/metadata-var2", + "auth/metadata-var3", + "auth/basic-cimd", + "auth/scope-from-www-authenticate", + "auth/scope-from-scopes-supported", + "auth/scope-omitted-when-undefined", + "auth/scope-step-up", + "auth/scope-retry-limit", + "auth/token-endpoint-auth-basic", + "auth/token-endpoint-auth-post", + "auth/token-endpoint-auth-none", + "auth/resource-mismatch", + "auth/pre-registration", + "auth/2025-03-26-oauth-metadata-backcompat", + "auth/2025-03-26-oauth-endpoint-fallback", + ) + for (name in authScenarios) { + scenarioHandlers[name] = ::runAuthClient + } + scenarioHandlers["auth/client-credentials-jwt"] = ::runClientCredentialsJwt + scenarioHandlers["auth/client-credentials-basic"] = ::runClientCredentialsBasic + scenarioHandlers["auth/cross-app-access-complete-flow"] = ::runCrossAppAccess +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/scopeHandling.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/scopeHandling.kt new file mode 100644 index 000000000..33b69ceb0 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/scopeHandling.kt @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +internal fun extractParam(wwwAuth: String, param: String): String? { + val regex = Regex("""$param="([^"]+)"""") + return regex.find(wwwAuth)?.groupValues?.get(1) +} + +/** + * Select scope per MCP spec priority: + * 1. scope from WWW-Authenticate header + * 2. scopes_supported from Protected Resource Metadata (space-joined) + * 3. null (omit scope entirely) + */ +internal fun selectScope(wwwAuthScope: String?, scopesSupported: List?): String? { + if (wwwAuthScope != null) return wwwAuthScope + if (!scopesSupported.isNullOrEmpty()) return scopesSupported.joinToString(" ") + return null +} + +/** + * Detect 403 with error="insufficient_scope" and extract the new scope. + * Returns the scope string if step-up is needed, null otherwise. + */ +internal fun parseStepUpScope(wwwAuth: String?): String? { + if (wwwAuth == null) return null + val error = extractParam(wwwAuth, "error") + if (error != "insufficient_scope") return null + return extractParam(wwwAuth, "scope") +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/tokenExchange.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/tokenExchange.kt new file mode 100644 index 000000000..0c79bb97d --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/tokenExchange.kt @@ -0,0 +1,174 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.request.forms.submitForm +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.Parameters +import io.ktor.http.isSuccess +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import java.net.URI +import java.net.URLEncoder +import java.util.Base64 + +internal fun buildAuthorizationUrl( + authEndpoint: String, + clientId: String, + redirectUri: String, + codeChallenge: String, + scope: String?, + resource: String?, + state: String, +): String { + val params = buildString { + append("response_type=code") + append("&client_id=${URLEncoder.encode(clientId, "UTF-8")}") + append("&redirect_uri=${URLEncoder.encode(redirectUri, "UTF-8")}") + append("&code_challenge=${URLEncoder.encode(codeChallenge, "UTF-8")}") + append("&code_challenge_method=S256") + append("&state=${URLEncoder.encode(state, "UTF-8")}") + if (scope != null) { + append("&scope=${URLEncoder.encode(scope, "UTF-8")}") + } + if (resource != null) { + append("&resource=${URLEncoder.encode(resource, "UTF-8")}") + } + } + return if (authEndpoint.contains("?")) "$authEndpoint&$params" else "$authEndpoint?$params" +} + +internal suspend fun followAuthorizationRedirect( + httpClient: HttpClient, + authUrl: String, + expectedCallbackUrl: String, + expectedState: String, +): String { + val response = httpClient.get(authUrl) + + if (response.status == HttpStatusCode.Found || + response.status == HttpStatusCode.MovedPermanently || + response.status == HttpStatusCode.TemporaryRedirect || + response.status == HttpStatusCode.SeeOther + ) { + val location = response.headers[HttpHeaders.Location] + ?: error("No Location header in redirect response") + + require(location.startsWith(expectedCallbackUrl)) { + "Redirect location does not match expected callback URL" + } + + val uri = URI(location) + val queryParams = uri.query?.split("&")?.mapNotNull { + val parts = it.split("=", limit = 2) + if (parts.size == 2) parts[0] to java.net.URLDecoder.decode(parts[1], "UTF-8") else null + }?.toMap() ?: emptyMap() + + val returnedState = queryParams["state"] + require(returnedState == expectedState) { + "State parameter mismatch in authorization redirect" + } + + return queryParams["code"] ?: error("No authorization code in redirect response") + } + + error("Expected redirect from auth endpoint, got ${response.status}") +} + +internal suspend fun exchangeCodeForTokens( + httpClient: HttpClient, + tokenEndpoint: String, + code: String, + clientId: String, + clientSecret: String?, + redirectUri: String, + codeVerifier: String, + tokenAuthMethod: String, + resource: String?, +): String { + val response = when (tokenAuthMethod) { + "client_secret_basic" -> { + val basicAuth = Base64.getEncoder() + .encodeToString("$clientId:${clientSecret ?: ""}".toByteArray()) + httpClient.submitForm( + url = tokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "authorization_code") + append("code", code) + append("redirect_uri", redirectUri) + append("code_verifier", codeVerifier) + if (resource != null) { + append("resource", resource) + } + }, + ) { + header(HttpHeaders.Authorization, "Basic $basicAuth") + } + } + + "none" -> { + httpClient.submitForm( + url = tokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "authorization_code") + append("code", code) + append("client_id", clientId) + append("redirect_uri", redirectUri) + append("code_verifier", codeVerifier) + if (resource != null) { + append("resource", resource) + } + }, + ) + } + + else -> { + // client_secret_post (default) + httpClient.submitForm( + url = tokenEndpoint, + formParameters = Parameters.build { + append("grant_type", "authorization_code") + append("code", code) + append("client_id", clientId) + if (clientSecret != null) { + append("client_secret", clientSecret) + } + append("redirect_uri", redirectUri) + append("code_verifier", codeVerifier) + if (resource != null) { + append("resource", resource) + } + }, + ) + } + } + + // Check HTTP status + if (!response.status.isSuccess()) { + val body = response.bodyAsText() + val errorDetail = try { + val obj = json.parseToJsonElement(body).jsonObject + val err = obj["error"]?.jsonPrimitive?.content ?: "unknown" + val desc = obj["error_description"]?.jsonPrimitive?.content + if (desc != null) "$err: $desc" else err + } catch (_: Exception) { + body + } + error("Token exchange failed (${response.status}): $errorDetail") + } + + val tokenJson = json.parseToJsonElement(response.bodyAsText()).jsonObject + + // Check for error field in response body (some servers return 200 with error) + val errorField = tokenJson["error"]?.jsonPrimitive?.content + if (errorField != null) { + val desc = tokenJson["error_description"]?.jsonPrimitive?.content + error("Token exchange error: $errorField${if (desc != null) " - $desc" else ""}") + } + + return tokenJson["access_token"]?.jsonPrimitive?.content + ?: error("No access_token in token response") +} diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/utils.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/utils.kt new file mode 100644 index 000000000..4bc75a0e3 --- /dev/null +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/auth/utils.kt @@ -0,0 +1,120 @@ +package io.modelcontextprotocol.kotlin.sdk.conformance.auth + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpSend +import io.ktor.client.plugins.plugin +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.get +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpHeaders +import io.ktor.http.isSuccess +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonArray +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import java.net.URI +import kotlin.text.ifEmpty + +internal val json = Json { ignoreUnknownKeys = true } + +internal fun conformanceContext(): JsonObject { + val contextJson = System.getenv("MCP_CONFORMANCE_CONTEXT") + ?: error("MCP_CONFORMANCE_CONTEXT not set") + return json.parseToJsonElement(contextJson).jsonObject +} + +internal fun JsonObject.requiredString(key: String): String = this[key]?.jsonPrimitive?.content ?: error("Missing $key") + +internal const val CIMD_CLIENT_METADATA_URL = "https://conformance-test.local/client-metadata.json" +internal const val CALLBACK_URL = "http://localhost:3000/callback" + +internal fun extractOrigin(url: String): String { + val uri = URI(url) + return "${uri.scheme}://${uri.host}${if (uri.port > 0) ":${uri.port}" else ""}" +} + +internal suspend fun discoverResourceMetadata(httpClient: HttpClient, serverUrl: String): JsonObject { + val origin = extractOrigin(serverUrl) + val path = URI(serverUrl).path.ifEmpty { "/" } + + // Try RFC 9728 format first: /.well-known/oauth-protected-resource/ + val wellKnownUrl = "$origin/.well-known/oauth-protected-resource$path" + val response = httpClient.get(wellKnownUrl) + if (response.status.isSuccess()) { + return json.parseToJsonElement(response.bodyAsText()).jsonObject + } + + // Fallback: try root + val fallbackUrl = "$origin/.well-known/oauth-protected-resource" + val fallbackResponse = httpClient.get(fallbackUrl) + if (!fallbackResponse.status.isSuccess()) { + error( + "Failed to discover resource metadata at $wellKnownUrl (${response.status}) and $fallbackUrl (${fallbackResponse.status})", + ) + } + return json.parseToJsonElement(fallbackResponse.bodyAsText()).jsonObject +} + +internal suspend fun fetchOAuthMetadata(httpClient: HttpClient, authServerUrl: String): JsonObject { + val origin = extractOrigin(authServerUrl) + val path = URI(authServerUrl).path.ifEmpty { "/" } + + // RFC 8414 §3: /.well-known/oauth-authorization-server/ + val oauthUrl = "$origin/.well-known/oauth-authorization-server$path" + val oauthResponse = httpClient.get(oauthUrl) + if (oauthResponse.status.isSuccess()) { + return json.parseToJsonElement(oauthResponse.bodyAsText()).jsonObject + } + + // OIDC Discovery with path insertion: /.well-known/openid-configuration/ + val oidcPathUrl = "$origin/.well-known/openid-configuration$path" + val oidcPathResponse = httpClient.get(oidcPathUrl) + if (oidcPathResponse.status.isSuccess()) { + return json.parseToJsonElement(oidcPathResponse.bodyAsText()).jsonObject + } + + // Fallback: OpenID Connect discovery (issuer + /.well-known/openid-configuration) + val oidcUrl = "$authServerUrl/.well-known/openid-configuration" + val oidcResponse = httpClient.get(oidcUrl) + if (oidcResponse.status.isSuccess()) { + return json.parseToJsonElement(oidcResponse.bodyAsText()).jsonObject + } + + error( + "Failed to fetch OAuth metadata from $oauthUrl (${oauthResponse.status}) and $oidcUrl (${oidcResponse.status})", + ) +} + +internal suspend fun discoverTokenEndpoint(httpClient: HttpClient, serverUrl: String): String { + val resourceMetadata = discoverResourceMetadata(httpClient, serverUrl) + val authServer = resourceMetadata["authorization_servers"]?.jsonArray?.firstOrNull()?.jsonPrimitive?.content + ?: error("No authorization_servers in resource metadata") + + val oauthMetadata = fetchOAuthMetadata(httpClient, authServer) + return oauthMetadata["token_endpoint"]?.jsonPrimitive?.content + ?: error("No token_endpoint") +} + +internal suspend fun extractAccessToken(tokenResponse: HttpResponse): String { + if (!tokenResponse.status.isSuccess()) { + error("Token request failed: ${tokenResponse.status}") + } + val tokenJson = json.parseToJsonElement(tokenResponse.bodyAsText()).jsonObject + return tokenJson["access_token"]?.jsonPrimitive?.content + ?: error("No access_token in token response") +} + +internal suspend fun withBearerToken(accessToken: String, block: suspend (HttpClient) -> T): T { + val client = HttpClient(CIO) { + install(SSE) + } + client.plugin(HttpSend).intercept { request -> + request.headers.remove(HttpHeaders.Authorization) + request.headers.append(HttpHeaders.Authorization, "Bearer $accessToken") + execute(request) + } + return client.use { block(it) } +} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt deleted file mode 100644 index 09ea50bf9..000000000 --- a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceClient.kt +++ /dev/null @@ -1,85 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.conformance - -import io.github.oshai.kotlinlogging.KotlinLogging -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.sse.SSE -import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport -import io.modelcontextprotocol.kotlin.sdk.shared.Transport -import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest -import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams -import io.modelcontextprotocol.kotlin.sdk.types.Implementation -import kotlinx.coroutines.runBlocking -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.buildJsonObject - -private val logger = KotlinLogging.logger {} - -fun main(args: Array) { - require(args.isNotEmpty()) { - "Server URL must be provided as an argument" - } - - val serverUrl = args.last() - logger.info { "Connecting to test server at: $serverUrl" } - - val httpClient = HttpClient(CIO) { - install(SSE) - } - val transport: Transport = StreamableHttpClientTransport(httpClient, serverUrl) - - val client = Client( - clientInfo = Implementation( - name = "kotlin-conformance-client", - version = "1.0.0", - ), - ) - - var exitCode = 0 - - runBlocking { - try { - client.connect(transport) - logger.info { "✅ Connected to server successfully" } - - try { - val tools = client.listTools() - logger.info { "Available tools: ${tools.tools.map { it.name }}" } - - if (tools.tools.isNotEmpty()) { - val toolName = tools.tools.first().name - logger.info { "Calling tool: $toolName" } - - val result = client.callTool( - CallToolRequest( - params = CallToolRequestParams( - name = toolName, - arguments = buildJsonObject { - put("input", JsonPrimitive("test")) - }, - ), - ), - ) - logger.info { "Tool result: ${result.content}" } - } - } catch (e: Exception) { - logger.debug(e) { "Error during tool operations (may be expected for some scenarios)" } - } - - logger.info { "✅ Client operations completed successfully" } - } catch (e: Exception) { - logger.error(e) { "❌ Client failed" } - exitCode = 1 - } finally { - try { - transport.close() - } catch (e: Exception) { - logger.warn(e) { "Error closing transport" } - } - httpClient.close() - } - } - - kotlin.system.exitProcess(exitCode) -} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt deleted file mode 100644 index b3331ddc9..000000000 --- a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceServer.kt +++ /dev/null @@ -1,426 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.conformance - -import io.github.oshai.kotlinlogging.KotlinLogging -import io.ktor.http.ContentType -import io.ktor.http.HttpStatusCode -import io.ktor.server.application.ApplicationCall -import io.ktor.server.application.install -import io.ktor.server.cio.CIO -import io.ktor.server.engine.embeddedServer -import io.ktor.server.request.header -import io.ktor.server.request.receiveText -import io.ktor.server.response.header -import io.ktor.server.response.respond -import io.ktor.server.response.respondText -import io.ktor.server.response.respondTextWriter -import io.ktor.server.routing.delete -import io.ktor.server.routing.get -import io.ktor.server.routing.post -import io.ktor.server.routing.routing -import io.ktor.server.websocket.WebSockets -import io.ktor.server.websocket.webSocket -import io.modelcontextprotocol.kotlin.sdk.server.Server -import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions -import io.modelcontextprotocol.kotlin.sdk.server.WebSocketMcpServerTransport -import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport -import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions -import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult -import io.modelcontextprotocol.kotlin.sdk.types.Implementation -import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError -import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage -import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest -import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse -import io.modelcontextprotocol.kotlin.sdk.types.McpJson -import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage -import io.modelcontextprotocol.kotlin.sdk.types.RPCError -import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult -import io.modelcontextprotocol.kotlin.sdk.types.RequestId -import io.modelcontextprotocol.kotlin.sdk.types.Role -import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.types.TextContent -import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents -import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.launch -import kotlinx.coroutines.withTimeoutOrNull -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonNull -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.decodeFromJsonElement -import kotlinx.serialization.json.jsonPrimitive -import kotlinx.serialization.json.put -import java.util.UUID -import java.util.concurrent.ConcurrentHashMap - -private val logger = KotlinLogging.logger {} -private val serverTransports = ConcurrentHashMap() -private val jsonFormat = Json { ignoreUnknownKeys = true } - -private const val SESSION_CREATION_TIMEOUT_MS = 2000L -private const val REQUEST_TIMEOUT_MS = 10_000L -private const val MESSAGE_QUEUE_CAPACITY = 256 - -private fun isInitializeRequest(json: JsonElement): Boolean = - json is JsonObject && json["method"]?.jsonPrimitive?.contentOrNull == "initialize" - -@Suppress("CyclomaticComplexMethod", "LongMethod") -fun main(args: Array) { - val port = args.getOrNull(0)?.toIntOrNull() ?: 3000 - - logger.info { "Starting MCP Conformance Server on port $port" } - - embeddedServer(CIO, port = port, host = "127.0.0.1") { - install(WebSockets) - - routing { - webSocket("/ws") { - logger.info { "WebSocket connection established" } - val transport = WebSocketMcpServerTransport(this) - val server = createConformanceServer() - - try { - server.createSession(transport) - } catch (e: Exception) { - logger.error(e) { "Error in WebSocket session" } - throw e - } - } - - get("/mcp") { - val sessionId = call.request.header("mcp-session-id") - ?: run { - call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header") - return@get - } - val transport = serverTransports[sessionId] - ?: run { - call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id") - return@get - } - transport.stream(call) - } - - post("/mcp") { - val sessionId = call.request.header("mcp-session-id") - val requestBody = call.receiveText() - - logger.debug { "Received request with sessionId: $sessionId" } - logger.trace { "Request body: $requestBody" } - - val jsonElement = try { - jsonFormat.parseToJsonElement(requestBody) - } catch (e: Exception) { - logger.error(e) { "Failed to parse request body as JSON" } - call.respond( - HttpStatusCode.BadRequest, - jsonFormat.encodeToString( - JsonObject.serializer(), - buildJsonObject { - put("jsonrpc", "2.0") - put( - "error", - buildJsonObject { - put("code", -32700) - put("message", "Parse error: ${e.message}") - }, - ) - put("id", JsonNull) - }, - ), - ) - return@post - } - - val transport = sessionId?.let { serverTransports[it] } - if (transport != null) { - logger.debug { "Using existing transport for session: $sessionId" } - transport.handleRequest(call, jsonElement) - } else { - if (isInitializeRequest(jsonElement)) { - val newSessionId = UUID.randomUUID().toString() - logger.info { "Creating new session with ID: $newSessionId" } - - val newTransport = HttpServerTransport(newSessionId) - serverTransports[newSessionId] = newTransport - - val mcpServer = createConformanceServer() - call.response.header("mcp-session-id", newSessionId) - - val sessionReady = CompletableDeferred() - CoroutineScope(Dispatchers.IO).launch { - try { - mcpServer.createSession(newTransport) - sessionReady.complete(Unit) - } catch (e: Exception) { - logger.error(e) { "Failed to create session" } - serverTransports.remove(newSessionId) - newTransport.close() - sessionReady.completeExceptionally(e) - } - } - - val sessionCreated = withTimeoutOrNull(SESSION_CREATION_TIMEOUT_MS) { - sessionReady.await() - } - - if (sessionCreated == null) { - logger.error { "Session creation timed out" } - serverTransports.remove(newSessionId) - call.respond( - HttpStatusCode.InternalServerError, - jsonFormat.encodeToString( - JsonObject.serializer(), - buildJsonObject { - put("jsonrpc", "2.0") - put( - "error", - buildJsonObject { - put("code", -32000) - put("message", "Session creation timed out") - }, - ) - put("id", JsonNull) - }, - ), - ) - return@post - } - - newTransport.handleRequest(call, jsonElement) - } else { - logger.warn { "Invalid request: no session ID or not an initialization request" } - call.respond( - HttpStatusCode.BadRequest, - jsonFormat.encodeToString( - JsonObject.serializer(), - buildJsonObject { - put("jsonrpc", "2.0") - put( - "error", - buildJsonObject { - put("code", -32000) - put("message", "Bad Request: No valid session ID provided") - }, - ) - put("id", JsonNull) - }, - ), - ) - } - } - } - - delete("/mcp") { - val sessionId = call.request.header("mcp-session-id") - val transport = sessionId?.let { serverTransports[it] } - if (transport != null) { - logger.info { "Terminating session: $sessionId" } - serverTransports.remove(sessionId) - transport.close() - call.respond(HttpStatusCode.OK) - } else { - logger.warn { "Invalid session termination request: $sessionId" } - call.respond(HttpStatusCode.BadRequest, "Invalid or missing session ID") - } - } - } - }.start(wait = true) -} - -@Suppress("LongMethod") -private fun createConformanceServer(): Server { - val server = Server( - Implementation( - name = "kotlin-conformance-server", - version = "1.0.0", - ), - ServerOptions( - capabilities = ServerCapabilities( - tools = ServerCapabilities.Tools(listChanged = true), - resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), - prompts = ServerCapabilities.Prompts(listChanged = true), - ), - ), - ) - - server.addTool( - name = "test-tool", - description = "A test tool for conformance testing", - inputSchema = ToolSchema( - properties = buildJsonObject { - put( - "input", - buildJsonObject { - put("type", "string") - put("description", "Test input parameter") - }, - ) - }, - required = listOf("input"), - ), - ) { request -> - val input = (request.params.arguments?.get("input") as? JsonPrimitive)?.content ?: "no input" - CallToolResult( - content = listOf(TextContent("Tool executed with input: $input")), - ) - } - - server.addResource( - uri = "test://test-resource", - name = "Test Resource", - description = "A test resource for conformance testing", - mimeType = "text/plain", - ) { request -> - ReadResourceResult( - contents = listOf( - TextResourceContents("Test resource content", request.params.uri, "text/plain"), - ), - ) - } - - server.addPrompt( - name = "test-prompt", - description = "A test prompt for conformance testing", - arguments = listOf( - PromptArgument( - name = "arg", - description = "Test argument", - required = false, - ), - ), - ) { - GetPromptResult( - messages = listOf( - PromptMessage( - role = Role.User, - content = TextContent("Test prompt content"), - ), - ), - description = "Test prompt description", - ) - } - - return server -} - -private class HttpServerTransport(private val sessionId: String) : AbstractTransport() { - private val logger = KotlinLogging.logger {} - private val pendingResponses = ConcurrentHashMap>() - private val messageQueue = Channel(MESSAGE_QUEUE_CAPACITY) - - suspend fun stream(call: ApplicationCall) { - logger.debug { "Starting SSE stream for session $sessionId" } - call.response.apply { - header("Cache-Control", "no-cache") - header("Connection", "keep-alive") - } - call.respondTextWriter(ContentType.Text.EventStream) { - try { - while (true) { - val msg = messageQueue.receiveCatching().getOrNull() ?: break - write("event: message\ndata: ${McpJson.encodeToString(msg)}\n\n") - flush() - } - } catch (e: Exception) { - logger.warn(e) { "SSE stream terminated for session $sessionId" } - } finally { - logger.debug { "SSE stream closed for session $sessionId" } - } - } - } - - suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) { - try { - val message = McpJson.decodeFromJsonElement(requestBody) - logger.debug { "Handling ${message::class.simpleName}: $requestBody" } - - when (message) { - is JSONRPCRequest -> { - val idKey = when (val id = message.id) { - is RequestId.NumberId -> id.value.toString() - is RequestId.StringId -> id.value - } - val responseDeferred = CompletableDeferred() - pendingResponses[idKey] = responseDeferred - - _onMessage.invoke(message) - - val response = withTimeoutOrNull(REQUEST_TIMEOUT_MS) { responseDeferred.await() } - if (response != null) { - call.respondText(McpJson.encodeToString(response), ContentType.Application.Json) - } else { - pendingResponses.remove(idKey) - logger.warn { "Timeout for request $idKey" } - call.respondText( - McpJson.encodeToString( - JSONRPCError( - message.id, - RPCError(RPCError.ErrorCode.REQUEST_TIMEOUT, "Request timed out"), - ), - ), - ContentType.Application.Json, - ) - } - } - - else -> { - call.respond(HttpStatusCode.Accepted) - } - } - } catch (e: CancellationException) { - throw e - } catch (e: Exception) { - logger.error(e) { "Error handling request" } - if (!call.response.isCommitted) { - call.respondText( - McpJson.encodeToString( - JSONRPCError( - RequestId(0), - RPCError(RPCError.ErrorCode.INTERNAL_ERROR, "Internal error: ${e.message}"), - ), - ), - ContentType.Application.Json, - HttpStatusCode.InternalServerError, - ) - } - } - } - - override suspend fun start() { - logger.debug { "Started transport for session $sessionId" } - } - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - when (message) { - is JSONRPCResponse -> { - val idKey = when (val id = message.id) { - is RequestId.NumberId -> id.value.toString() - is RequestId.StringId -> id.value - } - pendingResponses.remove(idKey)?.complete(message) ?: run { - logger.warn { "No pending response for ID $idKey, queueing" } - messageQueue.send(message) - } - } - - else -> messageQueue.send(message) - } - } - - override suspend fun close() { - logger.debug { "Closing transport for session $sessionId" } - messageQueue.close() - pendingResponses.clear() - invokeOnCloseCallback() - } -} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTest.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTest.kt deleted file mode 100644 index 41c9c3cd3..000000000 --- a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTest.kt +++ /dev/null @@ -1,343 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.conformance - -import io.github.oshai.kotlinlogging.KotlinLogging -import io.modelcontextprotocol.kotlin.test.utils.NPX -import io.modelcontextprotocol.kotlin.test.utils.findFreePort -import io.modelcontextprotocol.kotlin.test.utils.startLogging -import org.junit.jupiter.api.AfterAll -import org.junit.jupiter.api.BeforeAll -import org.junit.jupiter.api.DynamicTest -import org.junit.jupiter.api.TestFactory -import org.junit.jupiter.api.TestInstance -import java.io.BufferedReader -import java.io.InputStreamReader -import java.lang.management.ManagementFactory -import java.net.HttpURLConnection -import java.net.URI -import java.util.concurrent.TimeUnit -import kotlin.io.path.createTempFile -import kotlin.properties.Delegates -import kotlin.test.fail - -private val logger = KotlinLogging.logger {} - -val processStderrLogger = KotlinLogging.logger(name = "stderr") -val processStdoutLogger = KotlinLogging.logger(name = "stdout") - -private const val CONFORMANCE_VERSION = "0.1.8" - -enum class TransportType { - SSE, - WEBSOCKET, -} - -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -class ConformanceTest { - - private var serverProcess: Process? = null - private var serverPort: Int by Delegates.notNull() - private val serverErrorOutput = mutableListOf() - private val maxErrorLines = 500 - - companion object { - private val SERVER_SCENARIOS = listOf( - "server-initialize", - "tools-list", - "tools-call-simple-text", - "resources-list", - "prompts-list", - // TODO: Fix - // - resources-read-text - // - prompts-get-simple - ) - - private val CLIENT_SCENARIOS = listOf( - "initialize", - // TODO: Fix - // "tools-call", - ) - - private val SERVER_TRANSPORT_TYPES = listOf( - TransportType.SSE, - // TODO: Fix -// TransportType.WEBSOCKET, - ) - - private val CLIENT_TRANSPORT_TYPES = listOf( - TransportType.SSE, - TransportType.WEBSOCKET, - ) - - private const val DEFAULT_TEST_TIMEOUT_SECONDS = 30L - private const val DEFAULT_SERVER_STARTUP_TIMEOUT_SECONDS = 10 - private const val INITIAL_BACKOFF_MS = 50L - private const val MAX_BACKOFF_MS = 500L - private const val BACKOFF_MULTIPLIER = 1.5 - private const val CONNECTION_TIMEOUT_MS = 500 - private const val GRACEFUL_SHUTDOWN_SECONDS = 5L - private const val FORCE_SHUTDOWN_SECONDS = 2L - - private fun getRuntimeClasspath(): String = ManagementFactory.getRuntimeMXBean().classPath - - private fun getTestClasspath(): String = System.getProperty("test.classpath") ?: getRuntimeClasspath() - - private fun waitForServerReady( - url: String, - timeoutSeconds: Int = DEFAULT_SERVER_STARTUP_TIMEOUT_SECONDS, - ): Boolean { - val deadline = System.currentTimeMillis() + (timeoutSeconds * 1000) - var lastError: Exception? = null - var backoffMs = INITIAL_BACKOFF_MS - - while (System.currentTimeMillis() < deadline) { - try { - val connection = URI(url).toURL().openConnection() as HttpURLConnection - connection.requestMethod = "GET" - connection.connectTimeout = CONNECTION_TIMEOUT_MS - connection.readTimeout = CONNECTION_TIMEOUT_MS - connection.connect() - - val responseCode = connection.responseCode - connection.disconnect() - logger.debug { "Server responded with code: $responseCode" } - return true - } catch (e: Exception) { - lastError = e - Thread.sleep(backoffMs) - backoffMs = (backoffMs * BACKOFF_MULTIPLIER).toLong().coerceAtMost(MAX_BACKOFF_MS) - } - } - - logger.error { "Server did not start within $timeoutSeconds seconds. Last error: ${lastError?.message}" } - return false - } - } - - @BeforeAll - fun startServer() { - serverPort = findFreePort() - val serverUrl = "http://127.0.0.1:$serverPort/mcp" - - logger.info { "Starting conformance test server on port $serverPort" } - - val processBuilder = ProcessBuilder( - "java", - "-cp", - getRuntimeClasspath(), - "io.modelcontextprotocol.kotlin.sdk.conformance.ConformanceServerKt", - serverPort.toString(), - ) - - val process = processBuilder.start() - serverProcess = process - - // capture stderr in the background - Thread { - try { - BufferedReader(InputStreamReader(process.errorStream)).use { reader -> - reader.lineSequence().forEach { line -> - synchronized(serverErrorOutput) { - if (serverErrorOutput.size >= maxErrorLines) { - serverErrorOutput.removeAt(0) - } - serverErrorOutput.add(line) - } - logger.debug { "Server stderr: $line" } - } - } - } catch (e: Exception) { - logger.trace(e) { "Error reading server stderr" } - } - }.apply { - name = "server-stderr-reader" - isDaemon = true - }.start() - - logger.info { "Waiting for server to start..." } - val serverReady = waitForServerReady(serverUrl) - - if (!serverReady) { - val errorInfo = synchronized(serverErrorOutput) { - if (serverErrorOutput.isNotEmpty()) { - "\n\nServer error output:\n${serverErrorOutput.joinToString("\n")}" - } else { - "" - } - } - serverProcess?.destroyForcibly() - throw IllegalStateException( - "Server failed to start within $DEFAULT_SERVER_STARTUP_TIMEOUT_SECONDS seconds. " + - "Check if port $serverPort is available.$errorInfo", - ) - } - - logger.info { "Server started successfully at $serverUrl" } - } - - @AfterAll - fun stopServer() { - serverProcess?.also { process -> - logger.info { "Stopping conformance test server (PID: ${process.pid()})" } - - try { - process.destroy() - val terminated = process.waitFor(GRACEFUL_SHUTDOWN_SECONDS, TimeUnit.SECONDS) - - if (!terminated) { - logger.warn { "Server did not terminate gracefully, forcing shutdown..." } - process.destroyForcibly() - process.waitFor(FORCE_SHUTDOWN_SECONDS, TimeUnit.SECONDS) - } else { - logger.info { "Server stopped gracefully" } - } - } catch (e: Exception) { - logger.error(e) { "Error stopping server process" } - } finally { - serverProcess = null - } - } ?: logger.debug { "No server process to stop" } - } - - @TestFactory - fun `MCP Server Conformance Tests`(): List = SERVER_TRANSPORT_TYPES.flatMap { transportType -> - SERVER_SCENARIOS.map { scenario -> - DynamicTest.dynamicTest("Server [$transportType]: $scenario") { - runServerConformanceTest(scenario, transportType) - } - } - } - - @TestFactory - fun `MCP Client Conformance Tests`(): List = CLIENT_TRANSPORT_TYPES.flatMap { transportType -> - CLIENT_SCENARIOS.map { scenario -> - DynamicTest.dynamicTest("Client [$transportType]: $scenario") { - runClientConformanceTest(scenario, transportType) - } - } - } - - private fun runServerConformanceTest(scenario: String, transportType: TransportType) { - val serverUrl = when (transportType) { - TransportType.SSE -> { - "http://127.0.0.1:$serverPort/mcp" - } - - TransportType.WEBSOCKET -> { - "ws://127.0.0.1:$serverPort/ws" - } - } - - val processBuilder = ProcessBuilder( - NPX, - "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION", - "server", - "--url", - serverUrl, - "--scenario", - scenario, - ) - - runConformanceTest("server", scenario, processBuilder, transportType) - } - - private fun runClientConformanceTest(scenario: String, transportType: TransportType) { - val testClasspath = getTestClasspath() - - // Create an argfile to avoid Windows command line length limits - val argFile = createTempFile(suffix = ".args").toFile() - argFile.deleteOnExit() - - val mainClass = when (transportType) { - TransportType.SSE -> { - argFile.writeText( - buildString { - appendLine("-cp") - appendLine(testClasspath) - appendLine("io.modelcontextprotocol.kotlin.sdk.conformance.ConformanceClientKt") - }, - ) - "http://127.0.0.1:$serverPort/mcp" - } - - TransportType.WEBSOCKET -> { - argFile.writeText( - buildString { - appendLine("-cp") - appendLine(testClasspath) - appendLine("io.modelcontextprotocol.kotlin.sdk.conformance.WebSocketConformanceClientKt") - }, - ) - "ws://127.0.0.1:$serverPort/ws" - } - } - - val clientCommand = listOf( - "java", - "@${argFile.absolutePath}", - mainClass, - ) - - val processBuilder = ProcessBuilder( - NPX, - "@modelcontextprotocol/conformance@$CONFORMANCE_VERSION", - "client", - "--command", - clientCommand.joinToString(" "), - "--scenario", - scenario, - ) - - runConformanceTest("client", scenario, processBuilder, transportType) - } - - private fun runConformanceTest( - type: String, - scenario: String, - processBuilder: ProcessBuilder, - transportType: TransportType, - ) { - val capitalizedType = type.replaceFirstChar { it.uppercase() } - logger.info { "Running $type conformance test [$transportType]: $scenario" } - - val timeoutSeconds = - System.getenv("CONFORMANCE_TEST_TIMEOUT_SECONDS")?.toLongOrNull() ?: DEFAULT_TEST_TIMEOUT_SECONDS - - val process = processBuilder.start() - - process.errorStream.startLogging( - logger = processStderrLogger, - name = "test(PID=${process.pid()})", - ) - process.inputStream.startLogging( - logger = processStdoutLogger, - name = "test(PID=${process.pid()})", - ) - - val completed = process.waitFor(timeoutSeconds, TimeUnit.SECONDS) - - if (!completed) { - logger.error { - "$capitalizedType conformance test [$transportType] '$scenario' timed out after $timeoutSeconds seconds" - } - process.destroyForcibly() - throw AssertionError( - "❌ $capitalizedType conformance test [$transportType] '$scenario' " + - "timed out after $timeoutSeconds seconds", - ) - } - - when (val exitCode = process.exitValue()) { - 0 -> logger.info { "✅ $capitalizedType conformance test [$transportType] '$scenario' passed!" } - - else -> { - logger.error { - "$capitalizedType conformance test [$transportType] '$scenario' failed with exit code: $exitCode" - } - fail( - "❌ $capitalizedType conformance test [$transportType] '$scenario' " + - "failed (exit code: $exitCode). Check test output above for details.", - ) - } - } - } -} diff --git a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/WebSocketConformanceClient.kt b/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/WebSocketConformanceClient.kt deleted file mode 100644 index f385dddbc..000000000 --- a/conformance-test/src/test/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/WebSocketConformanceClient.kt +++ /dev/null @@ -1,105 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.conformance - -import io.github.oshai.kotlinlogging.KotlinLogging -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.websocket.WebSockets -import io.ktor.client.plugins.websocket.webSocket -import io.ktor.websocket.WebSocketSession -import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.shared.MCP_SUBPROTOCOL -import io.modelcontextprotocol.kotlin.sdk.shared.WebSocketMcpTransport -import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest -import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams -import io.modelcontextprotocol.kotlin.sdk.types.Implementation -import kotlinx.coroutines.runBlocking -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.buildJsonObject - -private val logger = KotlinLogging.logger {} - -class WebSocketClientTransport(override val session: WebSocketSession) : WebSocketMcpTransport() { - override suspend fun initializeSession() { - logger.debug { "WebSocket client session initialized" } - } -} - -@Suppress("LongMethod") -fun main(args: Array) { - require(args.isNotEmpty()) { - "Server WebSocket URL must be provided as an argument" - } - - val serverUrl = args.last() - logger.info { "Connecting to WebSocket test server at: $serverUrl" } - - val httpClient = HttpClient(CIO) { - install(WebSockets) - } - - var exitCode = 0 - - runBlocking { - try { - httpClient.webSocket(serverUrl, request = { - headers.append("Sec-WebSocket-Protocol", MCP_SUBPROTOCOL) - }) { - val transport = WebSocketClientTransport(this) - - val client = Client( - clientInfo = Implementation( - name = "kotlin-conformance-client-websocket", - version = "1.0.0", - ), - ) - - try { - client.connect(transport) - logger.info { "✅ Connected to server successfully" } - - try { - val tools = client.listTools() - logger.info { "Available tools: ${tools.tools.map { it.name }}" } - - if (tools.tools.isNotEmpty()) { - val toolName = tools.tools.first().name - logger.info { "Calling tool: $toolName" } - - val result = client.callTool( - CallToolRequest( - params = CallToolRequestParams( - name = toolName, - arguments = buildJsonObject { - put("input", JsonPrimitive("test")) - }, - ), - ), - ) - logger.info { "Tool result: ${result.content}" } - } - } catch (e: Exception) { - logger.debug(e) { "Error during tool operations (may be expected for some scenarios)" } - } - - logger.info { "✅ Client operations completed successfully" } - } catch (e: Exception) { - logger.error(e) { "❌ Client failed" } - exitCode = 1 - } finally { - try { - transport.close() - } catch (e: Exception) { - logger.warn(e) { "Error closing transport" } - } - } - } - } catch (e: Exception) { - logger.error(e) { "❌ WebSocket connection failed" } - exitCode = 1 - } finally { - httpClient.close() - } - } - - kotlin.system.exitProcess(exitCode) -} diff --git a/conformance-test/src/test/resources/simplelogger.properties b/conformance-test/src/test/resources/simplelogger.properties deleted file mode 100644 index c6f3b90c7..000000000 --- a/conformance-test/src/test/resources/simplelogger.properties +++ /dev/null @@ -1,10 +0,0 @@ -# Level of logging for the ROOT logger: ERROR, WARN, INFO, DEBUG, TRACE (default is INFO) -org.slf4j.simpleLogger.defaultLogLevel=INFO -org.slf4j.simpleLogger.showThreadName=true -org.slf4j.simpleLogger.showDateTime=false - -# Log level for specific packages or classes -org.slf4j.simpleLogger.log.io.ktor.server=DEBUG -org.slf4j.simpleLogger.log.stdout=TRACE -org.slf4j.simpleLogger.log.stderr=TRACE -org.slf4j.simpleLogger.log.io.modelcontextprotocol=DEBUG diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index f86685d37..313c8f6a5 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -44,6 +44,7 @@ kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx- # Ktor ktor-client-apache5 = { group = "io.ktor", name = "ktor-client-apache5", version.ref = "ktor" } +ktor-client-auth = { group = "io.ktor", name = "ktor-client-auth", version.ref = "ktor" } ktor-client-core = { group = "io.ktor", name = "ktor-client-core", version.ref = "ktor" } ktor-client-logging = { group = "io.ktor", name = "ktor-client-logging", version.ref = "ktor" } ktor-server-content-negotiation = { group = "io.ktor", name = "ktor-server-content-negotiation", version.ref = "ktor" } diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index 268a2b8df..da0e752a7 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -62,6 +62,18 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/KtorClientKt { public static synthetic fun mcpSseTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; } +public final class io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions { + public synthetic fun (JJDIILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (JJDILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun equals (Ljava/lang/Object;)Z + public final fun getInitialReconnectionDelay-UwyO8pc ()J + public final fun getMaxReconnectionDelay-UwyO8pc ()J + public final fun getMaxRetries ()I + public final fun getReconnectionDelayMultiplier ()D + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + public final class io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport { public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -88,6 +100,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTranspor } public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport { + public fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getProtocolVersion ()Ljava/lang/String; @@ -106,8 +120,12 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpError } public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensionsKt { + public static final fun mcpStreamableHttp (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public static final fun mcpStreamableHttp-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun mcpStreamableHttp-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpStreamableHttpTransport (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; + public static synthetic fun mcpStreamableHttpTransport$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; public static final fun mcpStreamableHttpTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; public static synthetic fun mcpStreamableHttpTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions.kt new file mode 100644 index 000000000..95c5bdefa --- /dev/null +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions.kt @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +/** + * Options for controlling SSE reconnection behavior. + * + * @property initialReconnectionDelay The initial delay before the first reconnection attempt. + * @property maxReconnectionDelay The maximum delay between reconnection attempts. + * @property reconnectionDelayMultiplier The factor by which the delay grows on each attempt. + * @property maxRetries The maximum number of reconnection attempts per disconnect. + */ +public class ReconnectionOptions( + public val initialReconnectionDelay: Duration = 1.seconds, + public val maxReconnectionDelay: Duration = 30.seconds, + public val reconnectionDelayMultiplier: Double = 1.5, + public val maxRetries: Int = 2, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as ReconnectionOptions + + if (reconnectionDelayMultiplier != other.reconnectionDelayMultiplier) return false + if (maxRetries != other.maxRetries) return false + if (initialReconnectionDelay != other.initialReconnectionDelay) return false + if (maxReconnectionDelay != other.maxReconnectionDelay) return false + + return true + } + + override fun hashCode(): Int { + var result = reconnectionDelayMultiplier.hashCode() + result = 31 * result + maxRetries + result = 31 * result + initialReconnectionDelay.hashCode() + result = 31 * result + maxReconnectionDelay.hashCode() + return result + } + + override fun toString(): String = + "ReconnectionOptions(initialReconnectionDelay=$initialReconnectionDelay, maxReconnectionDelay=$maxReconnectionDelay, reconnectionDelayMultiplier=$reconnectionDelayMultiplier, maxRetries=$maxRetries)" +} diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index a3b41906e..80f37aedc 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -38,9 +38,13 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive import kotlinx.coroutines.launch -import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.math.pow import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds private const val MCP_SESSION_ID_HEADER = "mcp-session-id" private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" @@ -52,31 +56,58 @@ private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" public class StreamableHttpError(public val code: Int? = null, message: String? = null) : Exception("Streamable HTTP error: $message") +private sealed interface ConnectResult { + data class Success(val session: ClientSSESession) : ConnectResult + data object NonRetryable : ConnectResult + data object Failed : ConnectResult +} + /** * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It will connect to a server using HTTP POST for sending messages and HTTP GET with Server-Sent Events * for receiving messages. */ -@OptIn(ExperimentalAtomicApi::class) +@Suppress("TooManyFunctions") public class StreamableHttpClientTransport( private val client: HttpClient, private val url: String, - private val reconnectionTime: Duration? = null, + private val reconnectionOptions: ReconnectionOptions = ReconnectionOptions(), private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractClientTransport() { + @Deprecated( + "Use constructor with ReconnectionOptions", + replaceWith = ReplaceWith( + "StreamableHttpClientTransport(client, url, " + + "ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)", + "kotlin.time.Duration.Companion.seconds", + "io.modelcontextprotocol.kotlin.sdk.client.ReconnectionOptions", + ), + ) + public constructor( + client: HttpClient, + url: String, + reconnectionTime: Duration?, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, + ) : this(client, url, ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder) + override val logger: KLogger = KotlinLogging.logger {} public var sessionId: String? = null private set public var protocolVersion: String? = null - private var sseSession: ClientSSESession? = null private var sseJob: Job? = null private val scope by lazy { CoroutineScope(SupervisorJob() + Dispatchers.Default) } - private var lastEventId: String? = null + /** Result of an SSE stream collection. Reconnect when [hasPrimingEvent] is true and [receivedResponse] is false. */ + private data class SseStreamResult( + val hasPrimingEvent: Boolean, + val receivedResponse: Boolean, + val lastEventId: String? = null, + val serverRetryDelay: Duration? = null, + ) override suspend fun initialize() { logger.debug { "Client transport is starting..." } @@ -85,7 +116,7 @@ public class StreamableHttpClientTransport( /** * Sends a single message with optional resumption support */ - @Suppress("ReturnCount", "CyclomaticComplexMethod") + @Suppress("ReturnCount", "CyclomaticComplexMethod", "LongMethod", "TooGenericExceptionCaught", "ThrowsCount") override suspend fun performSend(message: JSONRPCMessage, options: TransportSendOptions?) { logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } @@ -133,18 +164,25 @@ public class StreamableHttpClientTransport( } } - ContentType.Text.EventStream -> handleInlineSse( - response, - onResumptionToken = options?.onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null, - ) + ContentType.Text.EventStream -> { + val replayMessageId = if (message is JSONRPCRequest) message.id else null + val result = handleInlineSse(response, replayMessageId, options?.onResumptionToken) + if (result.hasPrimingEvent && !result.receivedResponse) { + startSseSession( + resumptionToken = result.lastEventId, + replayMessageId = replayMessageId, + onResumptionToken = options?.onResumptionToken, + initialServerRetryDelay = result.serverRetryDelay, + ) + } + } else -> { val body = response.bodyAsText() if (response.contentType() == null && body.isBlank()) return val ct = response.contentType()?.toString() ?: "" - val error = StreamableHttpError(-1, "Unexpected content type: $$ct") + val error = StreamableHttpError(-1, "Unexpected content type: $ct") _onError(error) throw error } @@ -169,11 +207,6 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } - - // Try to terminate session if we have one - terminateSession() - - sseSession?.cancel() sseJob?.cancelAndJoin() scope.cancel() } @@ -201,55 +234,120 @@ public class StreamableHttpClientTransport( } sessionId = null - lastEventId = null logger.debug { "Session terminated successfully" } } - private suspend fun startSseSession( + private fun startSseSession( resumptionToken: String? = null, replayMessageId: RequestId? = null, onResumptionToken: ((String) -> Unit)? = null, + initialServerRetryDelay: Duration? = null, ) { - sseSession?.cancel() - sseJob?.cancelAndJoin() + // Cancel-and-replace: cancel() signals the previous job, join() inside + // the new coroutine ensures it completes before we start collecting. + // This is intentionally non-suspend to avoid blocking performSend. + val previousJob = sseJob + previousJob?.cancel() + sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { + previousJob?.join() + var lastEventId = resumptionToken + var serverRetryDelay = initialServerRetryDelay + var attempt = 0 + var needsDelay = initialServerRetryDelay != null + + @Suppress("LoopWithTooManyJumpStatements") + while (isActive) { + // Delay before (re)connection: skip only for first fresh SSE connection + if (needsDelay) { + delay(getNextReconnectionDelay(attempt, serverRetryDelay)) + } + needsDelay = true + + // Connect + val session = when (val cr = connectSse(lastEventId)) { + is ConnectResult.Success -> { + attempt = 0 + cr.session + } + + ConnectResult.NonRetryable -> return@launch + + ConnectResult.Failed -> { + // Give up after maxRetries consecutive failed connection attempts + if (++attempt >= reconnectionOptions.maxRetries) { + _onError(StreamableHttpError(null, "Maximum reconnection attempts exceeded")) + return@launch + } + continue + } + } + + // Collect + val result = collectSse(session, replayMessageId, onResumptionToken) + lastEventId = result.lastEventId ?: lastEventId + serverRetryDelay = result.serverRetryDelay ?: serverRetryDelay + if (result.receivedResponse) break + } + } + } + @Suppress("TooGenericExceptionCaught") + private suspend fun connectSse(lastEventId: String?): ConnectResult { logger.debug { "Client attempting to start SSE session at url: $url" } - try { - sseSession = client.sseSession( - urlString = url, - reconnectionTime = reconnectionTime, - ) { + return try { + val session = client.sseSession(urlString = url, showRetryEvents = true) { method = HttpMethod.Get applyCommonHeaders(this) - // sseSession will add ContentType.Text.EventStream automatically accept(ContentType.Application.Json) - (resumptionToken ?: lastEventId)?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } + lastEventId?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } requestBuilder() } logger.debug { "Client SSE session started successfully." } + ConnectResult.Success(session) + } catch (e: CancellationException) { + throw e } catch (e: SSEClientException) { - val responseStatus = e.response?.status - val responseContentType = e.response?.contentType() + if (isNonRetryableSseError(e)) { + ConnectResult.NonRetryable + } else { + logger.debug { "SSE connection failed: ${e.message}" } + ConnectResult.Failed + } + } catch (e: Exception) { + logger.debug { "SSE connection failed: ${e.message}" } + ConnectResult.Failed + } + } - // 405 means server doesn't support SSE at GET endpoint - this is expected and valid - if (responseStatus == HttpStatusCode.MethodNotAllowed) { - logger.info { "Server returned 405 for GET/SSE, stream disabled." } - return + private fun getNextReconnectionDelay(attempt: Int, serverRetryDelay: Duration?): Duration { + // Per SSE specification, the server-sent `retry` field sets the reconnection time + // for all subsequent attempts, taking priority over exponential backoff. + serverRetryDelay?.let { return it } + val delay = reconnectionOptions.initialReconnectionDelay * + reconnectionOptions.reconnectionDelayMultiplier.pow(attempt) + return delay.coerceAtMost(reconnectionOptions.maxReconnectionDelay) + } + + /** + * Checks if an SSE session error is non-retryable (404, 405, JSON-only). + * Returns `true` if non-retryable (should stop trying), `false` otherwise. + */ + private fun isNonRetryableSseError(e: SSEClientException): Boolean { + val responseStatus = e.response?.status + val responseContentType = e.response?.contentType() + + return when { + responseStatus == HttpStatusCode.NotFound || responseStatus == HttpStatusCode.MethodNotAllowed -> { + logger.info { "Server returned ${responseStatus.value} for GET/SSE, stream disabled." } + true } - // If server returns application/json, it means it doesn't support SSE for this session - // This is valid per spec - server can choose to only use JSON responses - if (responseContentType?.match(ContentType.Application.Json) == true) { + responseContentType?.match(ContentType.Application.Json) == true -> { logger.info { "Server returned application/json for GET/SSE, using JSON-only mode." } - return + true } - _onError(e) - throw e - } - - sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { - sseSession?.let { collectSse(it, replayMessageId, onResumptionToken) } + else -> false } } @@ -265,11 +363,17 @@ public class StreamableHttpClientTransport( session: ClientSSESession, replayMessageId: RequestId?, onResumptionToken: ((String) -> Unit)?, - ) { + ): SseStreamResult { + var hasPrimingEvent = false + var receivedResponse = false + var localLastEventId: String? = null + var localServerRetryDelay: Duration? = null try { session.incoming.collect { event -> + event.retry?.let { localServerRetryDelay = it.milliseconds } event.id?.let { - lastEventId = it + localLastEventId = it + hasPrimingEvent = true onResumptionToken?.invoke(it) } logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}, id=${event.id}" } @@ -278,6 +382,7 @@ public class StreamableHttpClientTransport( event.data?.takeIf { it.isNotEmpty() }?.let { json -> runCatching { McpJson.decodeFromString(json) } .onSuccess { msg -> + if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { _onMessage(msg.copy(id = replayMessageId)) } else { @@ -295,6 +400,7 @@ public class StreamableHttpClientTransport( } catch (t: Throwable) { _onError(t) } + return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } @Suppress("CyclomaticComplexMethod") @@ -302,17 +408,22 @@ public class StreamableHttpClientTransport( response: HttpResponse, replayMessageId: RequestId?, onResumptionToken: ((String) -> Unit)?, - ) { + ): SseStreamResult { logger.trace { "Handling inline SSE from POST response" } val channel = response.bodyAsChannel() + var hasPrimingEvent = false + var receivedResponse = false + var localLastEventId: String? = null + var localServerRetryDelay: Duration? = null val sb = StringBuilder() var id: String? = null var eventName: String? = null suspend fun dispatch(id: String?, eventName: String?, data: String) { id?.let { - lastEventId = it + localLastEventId = it + hasPrimingEvent = true onResumptionToken?.invoke(it) } if (data.isBlank()) { @@ -321,6 +432,7 @@ public class StreamableHttpClientTransport( if (eventName == null || eventName == "message") { runCatching { McpJson.decodeFromString(data) } .onSuccess { msg -> + if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { _onMessage(msg.copy(id = replayMessageId)) } else { @@ -351,9 +463,16 @@ public class StreamableHttpClientTransport( } when { line.startsWith("id:") -> id = line.substringAfter("id:").trim() + line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() + line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim()) + + line.startsWith("retry:") -> line.substringAfter("retry:").trim().toLongOrNull()?.let { + localServerRetryDelay = it.milliseconds + } } } + return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt index b64a22062..a618f1823 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -6,21 +6,67 @@ import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME import io.modelcontextprotocol.kotlin.sdk.types.Implementation import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds /** * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. * * @param url URL of the MCP server. - * @param reconnectionTime Optional duration to wait before attempting to reconnect. + * @param reconnectionOptions Options for controlling SSE reconnection behavior. * @param requestBuilder Optional lambda to configure the HTTP request. * @return A [StreamableHttpClientTransport] configured for MCP communication. */ public fun HttpClient.mcpStreamableHttpTransport( url: String, - reconnectionTime: Duration? = null, + reconnectionOptions: ReconnectionOptions = ReconnectionOptions(), requestBuilder: HttpRequestBuilder.() -> Unit = {}, ): StreamableHttpClientTransport = - StreamableHttpClientTransport(this, url, reconnectionTime, requestBuilder = requestBuilder) + StreamableHttpClientTransport(this, url, reconnectionOptions, requestBuilder = requestBuilder) + +/** + * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param reconnectionTime Optional duration to wait before attempting to reconnect. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A [StreamableHttpClientTransport] configured for MCP communication. + */ +@Deprecated( + "Use overload with ReconnectionOptions", + replaceWith = ReplaceWith( + "mcpStreamableHttpTransport(url, " + + "ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)", + ), +) +public fun HttpClient.mcpStreamableHttpTransport( + url: String, + reconnectionTime: Duration?, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): StreamableHttpClientTransport = StreamableHttpClientTransport( + this, + url, + ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), + requestBuilder = requestBuilder, +) + +/** + * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param reconnectionOptions Options for controlling SSE reconnection behavior. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A connected [Client] ready for MCP communication. + */ +public suspend fun HttpClient.mcpStreamableHttp( + url: String, + reconnectionOptions: ReconnectionOptions = ReconnectionOptions(), + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): Client { + val transport = mcpStreamableHttpTransport(url, reconnectionOptions, requestBuilder) + val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) + client.connect(transport) + return client +} /** * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. @@ -30,12 +76,23 @@ public fun HttpClient.mcpStreamableHttpTransport( * @param requestBuilder Optional lambda to configure the HTTP request. * @return A connected [Client] ready for MCP communication. */ +@Deprecated( + "Use overload with ReconnectionOptions", + replaceWith = ReplaceWith( + "mcpStreamableHttp(url, " + + "ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)", + ), +) public suspend fun HttpClient.mcpStreamableHttp( url: String, - reconnectionTime: Duration? = null, + reconnectionTime: Duration?, requestBuilder: HttpRequestBuilder.() -> Unit = {}, ): Client { - val transport = mcpStreamableHttpTransport(url, reconnectionTime, requestBuilder) + val transport = mcpStreamableHttpTransport( + url, + ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), + requestBuilder, + ) val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) client.connect(transport) return client diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt index e303326ff..55df3e06e 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt @@ -23,9 +23,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay @@ -657,6 +659,203 @@ class StreamableHttpClientTransportTest { receivedErrors shouldHaveSize 0 } + @Test + fun testInlineSseRetryParsing() = runTest { + val transport = createTransport { request -> + if (request.method == HttpMethod.Post) { + val sseContent = buildString { + appendLine("retry: 5000") + appendLine("id: ev-1") + appendLine("event: message") + appendLine("""data: {"jsonrpc":"2.0","id":"req-1","result":{"tools":[]}}""") + appendLine() + } + + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } else { + respond("", HttpStatusCode.OK) + } + } + + val receivedMessages = mutableListOf() + val responseReceived = CompletableDeferred() + + transport.onMessage { message -> + receivedMessages.add(message) + if (message is JSONRPCResponse && !responseReceived.isCompleted) { + responseReceived.complete(Unit) + } + } + + transport.start() + + transport.send( + JSONRPCRequest( + id = "req-1", + method = "test", + params = buildJsonObject { }, + ), + ) + + eventually { + responseReceived.await() + } + + receivedMessages shouldHaveSize 1 + val response = receivedMessages[0] as JSONRPCResponse + response.id shouldBe RequestId.StringId("req-1") + + transport.close() + } + + @Test + fun testInlineSseHasPrimingEventTracking() = runTest { + val transport = createTransport { request -> + if (request.method == HttpMethod.Post) { + val sseContent = buildString { + // Event with id = priming event + appendLine("id: priming-1") + appendLine("event: message") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/progress",""" + + """"params":{"progressToken":"t1","progress":50}}""", + ) + appendLine() + // Notification without id + appendLine("event: message") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/tools/list_changed"}""") + appendLine() + } + + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } else { + respond("", HttpStatusCode.OK) + } + } + + val receivedMessages = mutableListOf() + val twoMessagesReceived = CompletableDeferred() + + transport.onMessage { message -> + receivedMessages.add(message) + if (receivedMessages.size >= 2 && !twoMessagesReceived.isCompleted) { + twoMessagesReceived.complete(Unit) + } + } + + transport.start() + + transport.send( + JSONRPCRequest( + id = "test-1", + method = "test", + params = buildJsonObject { }, + ), + ) + + eventually { + twoMessagesReceived.await() + } + + receivedMessages shouldHaveSize 2 + // Both should be notifications (no JSONRPCResponse → POST-to-GET reconnect would be triggered) + receivedMessages[0].shouldBeInstanceOf() + receivedMessages[1].shouldBeInstanceOf() + + transport.close() + } + + @Test + fun testInlineSseResponseStopsReconnection() = runTest { + val transport = createTransport { request -> + if (request.method == HttpMethod.Post) { + val sseContent = buildString { + appendLine("id: ev-1") + appendLine("event: message") + appendLine("""data: {"jsonrpc":"2.0","id":"req-1","result":{"tools":[]}}""") + appendLine() + } + + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } else { + respond("", HttpStatusCode.OK) + } + } + + val receivedMessages = mutableListOf() + val responseReceived = CompletableDeferred() + + transport.onMessage { message -> + receivedMessages.add(message) + if (message is JSONRPCResponse && !responseReceived.isCompleted) { + responseReceived.complete(Unit) + } + } + + transport.start() + + transport.send( + JSONRPCRequest( + id = "req-1", + method = "tools/list", + params = buildJsonObject { }, + ), + ) + + eventually { + responseReceived.await() + } + + receivedMessages shouldHaveSize 1 + // Response received → no reconnection triggered (hasPrimingEvent=true, receivedResponse=true) + val response = receivedMessages[0] as JSONRPCResponse + response.id shouldBe RequestId.StringId("req-1") + + transport.close() + } + + @Suppress("DEPRECATION") + @Test + fun testDeprecatedConstructorStillWorks() = runTest { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.Accepted, + ) + } + val httpClient = HttpClient(mockEngine) { + install(SSE) + } + + val transport = + StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp", reconnectionTime = 2.seconds) + + transport.start() + transport.send(JSONRPCNotification(method = "test")) + transport.close() + } + private suspend fun setupTransportAndCollectMessages( transport: StreamableHttpClientTransport, ): Pair, MutableList> { diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt index 6513b469b..a1fe89e66 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.client import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.shouldBe import io.ktor.http.ContentType import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode @@ -203,7 +204,31 @@ internal class StreamableHttpClientTest : AbstractStreamableHttpClientTest() { meta = EmptyJsonObject, ) + client.close() + } + + @Test + fun `terminateSession sends DELETE request`() = runBlocking { + val client = Client( + clientInfo = Implementation(name = "client1", version = "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + val sessionId = Uuid.random().toString() + + mockMcp.onInitialize(clientName = "client1", sessionId = sessionId) + mockMcp.handleJSONRPCRequest( + jsonRpcMethod = "notifications/initialized", + expectedSessionId = sessionId, + sessionId = sessionId, + statusCode = HttpStatusCode.Accepted, + ) + mockMcp.handleSubscribeWithGet(sessionId) { emptyFlow() } + + connect(client) + mockMcp.mockUnsubscribeRequest(sessionId = sessionId) + (client.transport as StreamableHttpClientTransport).terminateSession() + (client.transport as StreamableHttpClientTransport).sessionId shouldBe null client.close() } @@ -257,8 +282,6 @@ internal class StreamableHttpClientTest : AbstractStreamableHttpClientTest() { buildJsonObject {} } - mockMcp.mockUnsubscribeRequest(sessionId = sessionId) - connect(client) delay(1.seconds) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt index 5e39d1ffb..8ddee0a9c 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt @@ -8,12 +8,13 @@ import kotlinx.serialization.json.JsonObject // Protocol Version Constants // ============================================================================ -public const val LATEST_PROTOCOL_VERSION: String = "2025-06-18" +public const val LATEST_PROTOCOL_VERSION: String = "2025-11-25" -public const val DEFAULT_NEGOTIATED_PROTOCOL_VERSION: String = "2025-03-26" +public const val DEFAULT_NEGOTIATED_PROTOCOL_VERSION: String = "2025-06-18" public val SUPPORTED_PROTOCOL_VERSIONS: List = listOf( LATEST_PROTOCOL_VERSION, + "2025-06-18", "2025-03-26", "2024-11-05", ) diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt index a00b26be4..4d5974d92 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt @@ -371,21 +371,23 @@ internal object ServerResultPolymorphicSerializer : * Polymorphic serializer for [JSONRPCMessage] types. * Determines the message type based on the presence of specific fields: * - "error" -> JSONRPCError - * - "result" -> JSONRPCResponse + * - "result" + "id" -> JSONRPCResponse + * - "result" -> JSONRPCEmptyMessage * - "method" + "id" -> JSONRPCRequest * - "method" -> JSONRPCNotification */ internal object JSONRPCMessagePolymorphicSerializer : JsonContentPolymorphicSerializer(JSONRPCMessage::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - val jsonObject = element.jsonObject + val jsonObj = element.jsonObject return when { - "error" in jsonObject -> JSONRPCError.serializer() - "result" in jsonObject -> JSONRPCResponse.serializer() - "method" in jsonObject && "id" in jsonObject -> JSONRPCRequest.serializer() - "method" in jsonObject -> JSONRPCNotification.serializer() - jsonObject.isEmpty() || jsonObject.keys == setOf("jsonrpc") -> JSONRPCEmptyMessage.serializer() - else -> throw SerializationException("Invalid JSONRPCMessage type: ${jsonObject.keys}") + "error" in jsonObj -> JSONRPCError.serializer() + "result" in jsonObj && "id" in jsonObj -> JSONRPCResponse.serializer() + "result" in jsonObj && jsonObj["result"]?.jsonObject?.isEmpty() == true -> JSONRPCEmptyMessage.serializer() + "method" in jsonObj && "id" in jsonObj -> JSONRPCRequest.serializer() + "method" in jsonObj -> JSONRPCNotification.serializer() + jsonObj.isEmpty() || jsonObj.keys == setOf("jsonrpc") -> JSONRPCEmptyMessage.serializer() + else -> throw SerializationException("Invalid JSONRPCMessage type: ${jsonObj.keys}") } } } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CommonTypeTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CommonTypeTest.kt index 44e88ac37..a29a19cde 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CommonTypeTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CommonTypeTest.kt @@ -14,13 +14,14 @@ class CommonTypeTest { @Test fun `should have correct latest protocol version`() { - LATEST_PROTOCOL_VERSION shouldBe "2025-06-18" + LATEST_PROTOCOL_VERSION shouldBe "2025-11-25" } @Test fun `should have correct supported protocol versions`() { SUPPORTED_PROTOCOL_VERSIONS shouldContainExactlyInAnyOrder listOf( LATEST_PROTOCOL_VERSION, + "2025-06-18", "2025-03-26", "2024-11-05", ) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 580fcc0e5..9074c4c26 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -45,6 +45,7 @@ internal const val MCP_SESSION_ID_HEADER = "mcp-session-id" private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB +private const val MIN_PRIMING_EVENT_PROTOCOL_VERSION = "2025-11-25" /** * A holder for an active request call. @@ -388,7 +389,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat if (!configuration.enableJsonResponse) { call.appendSseHeaders() flushSse(session) // flush headers immediately - maybeSendPrimingEvent(streamId, session) + maybeSendPrimingEvent(streamId, session, call.request.header(MCP_PROTOCOL_VERSION_HEADER)) } streamMutex.withLock { @@ -451,7 +452,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat call.appendSseHeaders() flushSse(sseSession) // flush headers immediately streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(sseSession, call) - maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession) + maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession, call.request.header(MCP_PROTOCOL_VERSION_HEADER)) sseSession.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(STANDALONE_SSE_STREAM_ID) } @@ -702,12 +703,20 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } @Suppress("TooGenericExceptionCaught") - private suspend fun maybeSendPrimingEvent(streamId: String, session: ServerSSESession?) { - val store = configuration.eventStore ?: return - val sseSession = session ?: return + private suspend fun maybeSendPrimingEvent( + streamId: String, + session: ServerSSESession?, + clientProtocolVersion: String? = null, + ) { + val store = configuration.eventStore + if (store == null || session == null) return + // Priming events have empty data which older clients cannot handle. + // Only send priming events to clients with protocol version >= 2025-11-25 + // which includes the fix for handling empty SSE data. + if (clientProtocolVersion != null && clientProtocolVersion < MIN_PRIMING_EVENT_PROTOCOL_VERSION) return try { val primingEventId = store.storeEvent(streamId, JSONRPCEmptyMessage) - sseSession.send( + session.send( id = primingEventId, retry = configuration.retryInterval?.inWholeMilliseconds, data = "",