From 07a2984023957b1ef21615f8678385305abe5665 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Thu, 14 May 2026 20:07:54 -0700 Subject: [PATCH] fix: polish SEP-1577 sampling handling --- .../kotlin/sdk/conformance/ConformanceTools.kt | 8 +++++++- .../kotlin/sdk/types/sampling.kt | 8 ++++++++ .../kotlin/sdk/types/SamplingTest.kt | 18 ++++++++++++++++++ .../kotlin/sdk/server/SamplingValidation.kt | 7 +++++-- .../kotlin/sdk/server/SamplingTest.kt | 18 ++++++++++++++++++ 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt index e288272e..c81c98c3 100644 --- a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt @@ -179,7 +179,13 @@ fun Server.registerConformanceTools() { ), ), ) - CallToolResult(listOf(TextContent(result.content.joinToString("\n") { it.toString() }))) + val sampledText = result.content.joinToString("\n") { content -> + when (content) { + is TextContent -> content.text + else -> "Non-text sampling content: ${content::class.simpleName}" + } + } + CallToolResult(listOf(TextContent(sampledText))) } // 9. Elicitation diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt index 211fdc79..45264059 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt @@ -91,6 +91,10 @@ public data class SamplingMessage( @SerialName("_meta") val meta: JsonObject? = null, ) { + init { + require(content.isNotEmpty()) { "content must contain at least one block" } + } + /** * Convenience constructor for a single-block message. Wraps [content] in a * singleton list so call sites can write `SamplingMessage(Role.User, TextContent("hi"))` @@ -273,6 +277,10 @@ public data class CreateMessageResult( @SerialName("_meta") override val meta: JsonObject? = null, ) : ClientResult { + init { + require(content.isNotEmpty()) { "content must contain at least one block" } + } + /** * Convenience constructor for a single-block response. Wraps [content] in a * singleton list so call sites can write diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt index 62a82cd4..b2c8cf0c 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt @@ -302,6 +302,13 @@ class SamplingTest { (m.content[0] as TextContent).text shouldBe "hi" } + @Test + fun `SamplingMessage rejects empty content`() { + assertFailsWith { + SamplingMessage(role = Role.User, content = emptyList()) + } + } + @Test fun `SamplingMessage single-element content serialises as single object`() { val m = SamplingMessage(role = Role.User, content = listOf(TextContent("hi"))) @@ -408,6 +415,17 @@ class SamplingTest { (decoded.content[0] as TextContent).text shouldBe "hi" } + @Test + fun `CreateMessageResult rejects empty content`() { + assertFailsWith { + CreateMessageResult( + role = Role.Assistant, + content = emptyList(), + model = "test-model", + ) + } + } + // ============================================================================ // SamplingContentSerializer (single-or-array wire heuristic) // ============================================================================ diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt index ef408080..82431713 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt @@ -20,8 +20,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent * 3. If the previous message contains `tool_use` blocks, the last message's * `tool_result` ids MUST form exactly the same set. * - * On the first violation throws [IllegalArgumentException]. No-op when there are fewer - * than two messages or no tool_use / tool_result blocks are involved. + * On the first violation throws [IllegalArgumentException]. No-op when no + * tool_use / tool_result blocks are involved. */ internal fun validateSamplingMessages(messages: List) { if (messages.isEmpty()) return @@ -44,6 +44,9 @@ internal fun validateSamplingMessages(messages: List) { if (hasPreviousToolUse) { val toolUseIds = previous.filterIsInstance().map { it.id }.toSet() val toolResultIds = last.filterIsInstance().map { it.toolUseId }.toSet() + require(toolResultIds.isNotEmpty()) { + "tool_use blocks from previous message must be followed by matching tool_result blocks" + } require(toolUseIds == toolResultIds) { "ids of tool_result blocks and tool_use blocks from previous message do not match" } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt index 2b5446bf..cfcbb62c 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt @@ -8,6 +8,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent import kotlinx.serialization.json.JsonObject import org.junit.jupiter.api.assertDoesNotThrow import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertFailsWith /** @@ -83,4 +84,21 @@ class SamplingTest { ) } } + + @Test + fun `validate tool_use requires explicit tool_result in last message`() { + val error = assertFailsWith { + validateSamplingMessages( + listOf( + SamplingMessage(Role.Assistant, toolUse("c1")), + SamplingMessage(Role.User, TextContent("missing result")), + ), + ) + } + + assertEquals( + "tool_use blocks from previous message must be followed by matching tool_result blocks", + error.message, + ) + } }