Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,13 @@ fun Server.registerConformanceTools() {
),
),
)
CallToolResult(listOf(TextContent(result.content.joinToString("\n") { it.toString() })))
val sampledText = result.content.joinToString("\n") { content ->
when (content) {
is TextContent -> content.text
else -> "Non-text sampling content: ${content::class.simpleName}"
}
}
CallToolResult(listOf(TextContent(sampledText)))
}

// 9. Elicitation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public data class SamplingMessage(
@SerialName("_meta")
val meta: JsonObject? = null,
) {
init {
require(content.isNotEmpty()) { "content must contain at least one block" }
}

/**
* Convenience constructor for a single-block message. Wraps [content] in a
* singleton list so call sites can write `SamplingMessage(Role.User, TextContent("hi"))`
Expand Down Expand Up @@ -273,6 +277,10 @@ public data class CreateMessageResult(
@SerialName("_meta")
override val meta: JsonObject? = null,
) : ClientResult {
init {
require(content.isNotEmpty()) { "content must contain at least one block" }
}

/**
* Convenience constructor for a single-block response. Wraps [content] in a
* singleton list so call sites can write
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,13 @@ class SamplingTest {
(m.content[0] as TextContent).text shouldBe "hi"
}

@Test
fun `SamplingMessage rejects empty content`() {
assertFailsWith<IllegalArgumentException> {
SamplingMessage(role = Role.User, content = emptyList())
}
}

@Test
fun `SamplingMessage single-element content serialises as single object`() {
val m = SamplingMessage(role = Role.User, content = listOf(TextContent("hi")))
Expand Down Expand Up @@ -408,6 +415,17 @@ class SamplingTest {
(decoded.content[0] as TextContent).text shouldBe "hi"
}

@Test
fun `CreateMessageResult rejects empty content`() {
assertFailsWith<IllegalArgumentException> {
CreateMessageResult(
role = Role.Assistant,
content = emptyList(),
model = "test-model",
)
}
}

// ============================================================================
// SamplingContentSerializer (single-or-array wire heuristic)
// ============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent
* 3. If the previous message contains `tool_use` blocks, the last message's
* `tool_result` ids MUST form exactly the same set.
*
* On the first violation throws [IllegalArgumentException]. No-op when there are fewer
* than two messages or no tool_use / tool_result blocks are involved.
* On the first violation throws [IllegalArgumentException]. No-op when no
* tool_use / tool_result blocks are involved.
*/
internal fun validateSamplingMessages(messages: List<SamplingMessage>) {
if (messages.isEmpty()) return
Expand All @@ -44,6 +44,9 @@ internal fun validateSamplingMessages(messages: List<SamplingMessage>) {
if (hasPreviousToolUse) {
val toolUseIds = previous.filterIsInstance<ToolUseContent>().map { it.id }.toSet()
val toolResultIds = last.filterIsInstance<ToolResultContent>().map { it.toolUseId }.toSet()
require(toolResultIds.isNotEmpty()) {
"tool_use blocks from previous message must be followed by matching tool_result blocks"
}
require(toolUseIds == toolResultIds) {
"ids of tool_result blocks and tool_use blocks from previous message do not match"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent
import kotlinx.serialization.json.JsonObject
import org.junit.jupiter.api.assertDoesNotThrow
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

/**
Expand Down Expand Up @@ -83,4 +84,21 @@ class SamplingTest {
)
}
}

@Test
fun `validate tool_use requires explicit tool_result in last message`() {
val error = assertFailsWith<IllegalArgumentException> {
validateSamplingMessages(
listOf(
SamplingMessage(Role.Assistant, toolUse("c1")),
SamplingMessage(Role.User, TextContent("missing result")),
),
)
}

assertEquals(
"tool_use blocks from previous message must be followed by matching tool_result blocks",
error.message,
)
}
}