diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt index d14ae88e..798a6d9b 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -6,7 +6,13 @@ import io.kotest.matchers.string.shouldContain import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsResult import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.Prompt import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage import io.modelcontextprotocol.kotlin.sdk.types.Role @@ -19,6 +25,7 @@ import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -697,4 +704,56 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { exception.message shouldBe expectedMessage } } + + @Test + fun testListPromptsPagination() = runBlocking(Dispatchers.IO) { + val pagePrefix = "paginated-prompt-" + (0 until 5).forEach { i -> + val name = "$pagePrefix$i" + server.addPrompt(name = name, description = "desc", arguments = listOf()) { _ -> + GetPromptResult(description = "desc", messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name)))) + } + } + + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.PromptsList) { request, _ -> + val all = server.prompts.values.map { it.prompt } + val cursor = request.cursor?.toIntOrNull() ?: 0 + val pageSize = 2 + val page = all.drop(cursor).take(pageSize) + val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null + ListPromptsResult(prompts = page, nextCursor = next) + } + } + + val allPrompts = mutableListOf() + var currentCursor: String? = null + do { + val request = if (currentCursor == null) ListPromptsRequest() else ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor)) + val response = client.listPrompts(request) + allPrompts.addAll(response.prompts) + currentCursor = response.nextCursor + } while (currentCursor != null) + + val paginatedPrompts = allPrompts.filter { it.name.startsWith(pagePrefix) } + assertEquals(5, paginatedPrompts.size, "Should have collected all 5 paginated prompts") + } + + @Test + fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) { + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.PromptsList) { request, _ -> + val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor") + val all = server.prompts.values.map { it.prompt } + val page = all.drop(cursor).take(2) + ListPromptsResult(prompts = page, nextCursor = null) + } + } + + val exception = assertFailsWith { + client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = "not-a-number"))) + } + + assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code) + } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index 4fc53a52..df195649 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -1,7 +1,11 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams @@ -20,6 +24,7 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import java.util.concurrent.atomic.AtomicBoolean import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -309,4 +314,62 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") } } + + @Test + fun testListResourcesPagination() = runBlocking(Dispatchers.IO) { + val prefix = "paginated-resource-" + (0 until 6).forEach { i -> + val uri = "test://$prefix$i.txt" + server.addResource(uri = uri, name = "Name-$i", description = "desc", mimeType = "text/plain") { request -> + ReadResourceResult(contents = listOf(TextResourceContents(text = uri, uri = request.params.uri, mimeType = "text/plain"))) + } + } + + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ResourcesList) { request, _ -> + val all = server.resources.values.map { it.resource } + val cursor = request.cursor?.toIntOrNull() ?: 0 + val pageSize = 3 + val page = all.drop(cursor).take(pageSize) + val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null + ListResourcesResult(resources = page, nextCursor = next) + } + } + + val combinedUris = mutableListOf() + var currentCursor: String? = null + + do { + val request = if (currentCursor == null) { + ListResourcesRequest() + } else { + ListResourcesRequest(PaginatedRequestParams(cursor = currentCursor)) + } + + val response = client.listResources(request) + combinedUris += response.resources.map { it.uri } + currentCursor = response.nextCursor + } while (currentCursor != null) + + val paginatedResources = combinedUris.filter { it.contains(prefix) } + assertEquals(6, paginatedResources.size, "Should have collected all 6 paginated resources") + } + + @Test + fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) { + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ResourcesList) { request, _ -> + val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor") + val all = server.resources.values.map { it.resource } + val page = all.drop(cursor).take(2) + ListResourcesResult(resources = page, nextCursor = null) + } + } + + val exception = assertFailsWith { + client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = "bad"))) + } + + assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code) + } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt index 7da82cc3..be3a8822 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt @@ -6,6 +6,12 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock import io.modelcontextprotocol.kotlin.sdk.types.ImageContent +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.TextContent import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema @@ -25,6 +31,7 @@ import java.text.DecimalFormat import java.text.DecimalFormatSymbols import java.util.Locale import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -791,4 +798,62 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() { "Error message should indicate the tool was not found", ) } + + @Test + fun testListToolsPagination() = runBlocking(Dispatchers.IO) { + val prefix = "paginated-tool-" + (0 until 5).forEach { i -> + val name = "$prefix$i" + server.addTool(name = name, description = "desc") { request -> + CallToolResult(content = listOf(TextContent(text = name)), structuredContent = buildJsonObject { put("name", name) }) + } + } + + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + val all = server.tools.values.map { it.tool } + val cursor = request.cursor?.toIntOrNull() ?: 0 + val pageSize = 2 + val page = all.drop(cursor).take(pageSize) + val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null + ListToolsResult(tools = page, nextCursor = next) + } + } + + val combinedNames = mutableListOf() + var currentCursor: String? = null + + do { + val request = if (currentCursor == null) { + ListToolsRequest() + } else { + ListToolsRequest(PaginatedRequestParams(cursor = currentCursor)) + } + + val response = client.listTools(request) + combinedNames += response.tools.map { it.name } + currentCursor = response.nextCursor + } while (currentCursor != null) + + val paginatedTools = combinedNames.filter { it.startsWith(prefix) } + assertEquals(5, paginatedTools.size, "Should have collected all 5 paginated tools") + } + + @Test + fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) { + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor") + val all = server.tools.values.map { it.tool } + val page = all.drop(cursor).take(2) + ListToolsResult(tools = page) + } + } + + val exception = assertFailsWith { + client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = "bad"))) + } + + assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code) + } }