diff --git a/.gitignore b/.gitignore index d9f0fdec..805579c5 100644 --- a/.gitignore +++ b/.gitignore @@ -56,7 +56,7 @@ node_modules dist ### SWE agents ### -.claude/settings.local.json +.claude/ .junie/ ### Conformance test results ### diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index f400e80d..3acaa648 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -219,11 +219,22 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V + public fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lkotlinx/io/Source;Lkotlinx/io/Sink;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport$Builder { + public final fun getHandlerDispatcher ()Lkotlinx/coroutines/CoroutineDispatcher; + public final fun getIoDispatcher ()Lkotlinx/coroutines/CoroutineDispatcher; + public final fun getScope ()Lkotlinx/coroutines/CoroutineScope; + public final fun setHandlerDispatcher (Lkotlinx/coroutines/CoroutineDispatcher;)V + public final fun setIoDispatcher (Lkotlinx/coroutines/CoroutineDispatcher;)V + public final fun setScope (Lkotlinx/coroutines/CoroutineScope;)V +} + public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String; public fun ()V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index a4d588cc..ea4e0e11 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -7,13 +7,22 @@ import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.McpDsl +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.RPCError import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedSendChannelException +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.isActive import kotlinx.coroutines.launch import kotlinx.coroutines.withContext @@ -23,191 +32,338 @@ import kotlinx.io.Source import kotlinx.io.buffered import kotlinx.io.readByteArray import kotlinx.io.writeString -import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.AtomicReference import kotlin.concurrent.atomics.ExperimentalAtomicApi -import kotlin.coroutines.CoroutineContext private const val READ_BUFFER_SIZE = 8192L /** - * A server transport that communicates with a client via standard I/O. + * Server-side MCP transport that exchanges JSON-RPC messages over arbitrary byte streams. * - * Reads from input [Source] and writes to output [Sink]. + * Reads framed messages from the supplied [Source] and writes framed messages to the supplied + * [Sink]. Three internal coroutines drive the pipeline: * - * @constructor Creates a new instance of [StdioServerTransport]. - * @param inputStream The input [Source] used to receive data. - * @param outputStream The output [Sink] used to send data. + * - **reader** — pulls bytes from the input source into the parsing buffer; runs on + * [Builder.ioDispatcher]. + * - **processor** — parses messages out of the buffer and invokes the registered message handler; + * runs on [Builder.handlerDispatcher] (defaults to [Dispatchers.Default]) so blocking handler + * code does not starve the I/O pool. + * - **writer** — serialises outbound messages and flushes them to the output sink; runs on + * [Builder.ioDispatcher]. + * + * Both internal channels are bounded, so a slow handler or slow output naturally back-pressures + * the upstream producer — [send] suspends when the outbound channel is full. + * + * Both explicit [close] and a natural EOF from the input perform a graceful drain: in-flight + * outbound messages are flushed before `onClose` fires, and the input source and output sink are + * released. + * + * Example: + * ```kotlin + * val transport = StdioServerTransport( + * input = System.`in`.asSource().buffered(), + * output = System.out.asSink().buffered(), + * ) { + * scope = myScope + * } + * transport.start() + * ``` */ @OptIn(ExperimentalAtomicApi::class) -public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() { +public class StdioServerTransport private constructor( + private val input: Source, + output: Sink, + private val scope: CoroutineScope?, + private val handlerDispatcher: CoroutineDispatcher, + private val ioDispatcher: CoroutineDispatcher, +) : AbstractTransport() { + + /** + * Creates a [StdioServerTransport] reading from [input] and writing to [output], with + * optional configuration applied through a [Builder] block. + * + * @param input source the transport reads JSON-RPC messages from + * @param output sink the transport writes JSON-RPC messages to + * @param block configuration applied to the underlying [Builder] + */ + public constructor( + input: Source, + output: Sink, + block: Builder.() -> Unit = {}, + ) : this(Builder(input, output).apply(block)) + + private constructor(builder: Builder) : this( + input = builder.input, + output = builder.output, + scope = builder.scope, + handlerDispatcher = builder.handlerDispatcher, + ioDispatcher = builder.ioDispatcher, + ) + + /** + * Creates a [StdioServerTransport] from the given input source and output sink with default + * configuration. Retained for binary compatibility; prefer the [Builder]-based constructor. + * + * @param inputStream source the transport reads JSON-RPC messages from + * @param outputStream sink the transport writes JSON-RPC messages to + */ + @Deprecated( + message = "Use StdioServerTransport(input, output) { ... } instead.", + replaceWith = ReplaceWith("StdioServerTransport(input = inputStream, output = outputStream)"), + level = DeprecationLevel.WARNING, + ) + public constructor(inputStream: Source, outputStream: Sink) : this( + input = inputStream, + output = outputStream, + scope = null, + handlerDispatcher = Dispatchers.Default, + ioDispatcher = IODispatcher.limitedParallelism(2), + ) private val logger = KotlinLogging.logger {} private val readBuffer = ReadBuffer() - private val initialized: AtomicBoolean = AtomicBoolean(false) - private var readingJob: Job? = null - private var sendingJob: Job? = null - private var processingJob: Job? = null + private val readChannel = Channel(Channel.BUFFERED) + private val writeChannel = Channel(Channel.BUFFERED) + private val outputSink = output.buffered() + + private enum class State { New, Operational, Stopped } - private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob() - private val scope = CoroutineScope(coroutineContext) - private val readChannel = Channel(Channel.UNLIMITED) - private val writeChannel = Channel(Channel.UNLIMITED) - private val outputSink = outputStream.buffered() + private val state: AtomicReference = AtomicReference(State.New) + private var readerJob: Job? = null + private var processorJob: Job? = null + private var writerJob: Job? = null + + private var effectiveScope: CoroutineScope? = null + private val ownsScope: Boolean = (scope == null) + + private val setupComplete = CompletableDeferred() + + /** + * Starts the reader, processor, and writer coroutines. Must be called exactly once before + * messages can be exchanged; subsequent calls throw. + */ override suspend fun start() { - if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { + if (!state.compareAndSet(State.New, State.Operational)) { error("StdioServerTransport already started!") } - // Launch a coroutine to read from stdin - readingJob = launchReadingJob() + try { + val resolvedScope = scope ?: CoroutineScope( + currentCoroutineContext() + IODispatcher + SupervisorJob(), + ) + effectiveScope = resolvedScope - // Launch a coroutine to process messages from readChannel - processingJob = launchProcessingJob() - - // Launch a coroutine to handle message sending - sendingJob = launchSendingJob() + readerJob = resolvedScope.launch( + ioDispatcher + CoroutineName("StdioServerTransport.reader"), + ) { readerPump() } + processorJob = resolvedScope.launch( + handlerDispatcher + CoroutineName("StdioServerTransport.processor"), + start = CoroutineStart.UNDISPATCHED, + ) { processorPump() } + writerJob = resolvedScope.launch( + ioDispatcher + CoroutineName("StdioServerTransport.writer"), + start = CoroutineStart.UNDISPATCHED, + ) { writerPump() } + } finally { + setupComplete.complete(Unit) + } } - private fun launchReadingJob(): Job { - val job = scope.launch { - val buf = Buffer() - try { - while (isActive) { - val bytesRead = inputStream.readAtMostTo(buf, READ_BUFFER_SIZE) - if (bytesRead == -1L) { - // EOF reached - break - } - if (bytesRead > 0) { - val chunk = buf.readByteArray() - readChannel.send(chunk) - } + private suspend fun readerPump() { + val buf = Buffer() + try { + while (currentCoroutineContext().isActive) { + val bytesRead = input.readAtMostTo(buf, READ_BUFFER_SIZE) + if (bytesRead == -1L) break + if (bytesRead > 0) { + val chunk = buf.readByteArray() + readChannel.send(chunk) } - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { + } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + if (state.load() == State.Stopped) { + logger.debug(e) { "Reader interrupted by close()" } + } else { logger.error(e) { "Error reading from stdin" } - _onError.invoke(e) - } finally { - // Reached EOF or error, close connection - close() + _onError(e) } + } finally { + readChannel.close() } - job.invokeOnCompletion { cause -> - logJobCompletion("Message reading", cause) - } - return job } - private fun launchProcessingJob(): Job { - val job = scope.launch { - try { - for (chunk in readChannel) { - readBuffer.append(chunk) - processReadBuffer() + private suspend fun processorPump() { + try { + for (chunk in readChannel) { + readBuffer.append(chunk) + while (true) { + val message = try { + readBuffer.readMessage() + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + _onError(e) + null + } ?: break + try { + _onMessage(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message" } + _onError(e) + } } - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - _onError.invoke(e) } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + _onError(e) + } finally { + writeChannel.close() } - job.invokeOnCompletion { cause -> - logJobCompletion("Processing", cause) - } - return job } - private fun launchSendingJob(): Job { - val job = scope.launch { - try { - for (message in writeChannel) { - val json = serializeMessage(message) - outputSink.writeString(json) - outputSink.flush() - } - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error writing to stdout" } - _onError.invoke(e) - } - } - job.invokeOnCompletion { cause -> - logJobCompletion("Message sending", cause) - if (cause is CancellationException) { - readingJob?.cancel(cause) + private suspend fun writerPump() { + try { + for (message in writeChannel) { + val json = serializeMessage(message) + outputSink.writeString(json) + outputSink.flush() } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error writing to stdout" } + _onError(e) + } finally { + transitionToStoppedNaturally() } - return job } - private suspend fun processReadBuffer() { - while (true) { - val message = try { - readBuffer.readMessage() - } catch (e: Throwable) { - _onError.invoke(e) - null - } + /** + * Closes the transport and waits for in-flight outbound messages to be flushed. Releases the + * input source and output sink, cancels the internal scope when the transport owns it, and + * invokes `onClose`. Safe to call multiple times and safe to race with [start]. + */ + override suspend fun close() { + var previous: State + do { + previous = state.load() + } while (previous != State.Stopped && !state.compareAndSet(previous, State.Stopped)) - if (message == null) break - // Async invocation broke delivery order - try { - _onMessage.invoke(message) - } catch (e: CancellationException) { - throw e - } catch (e: Throwable) { - logger.error(e) { "Error processing message" } - _onError.invoke(e) - } + if (previous == State.New) { + setupComplete.complete(Unit) + return } - } - private fun logJobCompletion(jobName: String, cause: Throwable?) { - when (cause) { - is CancellationException -> { - } + withContext(NonCancellable) { + setupComplete.await() - null -> { - logger.debug { "$jobName job completed" } + if (previous == State.Stopped) { + writerJob?.join() + return@withContext } - else -> { - logger.debug(cause) { "$jobName job completed exceptionally" } + runCatching { input.close() }.onFailure { logger.warn(it) { "Failed to close stdin" } } + readerJob?.cancel() + readChannel.close() + processorJob?.join() + writeChannel.close() + writerJob?.join() + runCatching { outputSink.close() } + .onFailure { logger.warn(it) { "Failed to close stdout" } } + readBuffer.clear() + if (ownsScope) { + effectiveScope?.coroutineContext?.get(Job)?.cancel() } + invokeOnCloseCallback() } } - override suspend fun close() { - if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return - - withContext(NonCancellable) { - writeChannel.close() - sendingJob?.cancelAndJoin() - - runCatching { - inputStream.close() - }.onFailure { logger.warn(it) { "Failed to close stdin" } } - - readingJob?.cancel() - readChannel.close() - - processingJob?.cancelAndJoin() + /** + * Queues [message] for the writer coroutine. Suspends when the outbound channel is full, + * applying back-pressure to the caller. Throws [McpException] with + * [RPCError.ErrorCode.CONNECTION_CLOSED] if the transport has not been started or has + * already closed. + * + * @param message JSON-RPC message to send + * @param options transport-specific send options; ignored by this transport + */ + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + when (state.load()) { + State.New -> throw McpException( + code = RPCError.ErrorCode.CONNECTION_CLOSED, + message = "Transport is not started", + ) - readBuffer.clear() + State.Stopped -> throw McpException( + code = RPCError.ErrorCode.CONNECTION_CLOSED, + message = "Transport is closed", + ) - runCatching { - outputSink.flush() - outputSink.close() - }.onFailure { logger.warn(it) { "Failed to close stdout" } } + State.Operational -> Unit + } + try { + writeChannel.send(message) + } catch (e: CancellationException) { + throw e + } catch (e: ClosedSendChannelException) { + throw McpException( + code = RPCError.ErrorCode.CONNECTION_CLOSED, + message = "Transport is closed", + cause = e, + ) + } + } - invokeOnCloseCallback() + private fun transitionToStoppedNaturally() { + if (!state.compareAndSet(State.Operational, State.Stopped)) return + runCatching { input.close() } + .onFailure { logger.warn(it) { "Failed to close stdin" } } + runCatching { outputSink.close() } + .onFailure { logger.warn(it) { "Failed to close stdout" } } + readBuffer.clear() + if (ownsScope) { + // Non-null in practice: writer launches after effectiveScope is published. + effectiveScope?.coroutineContext?.get(Job)?.cancel() } + invokeOnCloseCallback() } - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - writeChannel.send(message) + /** + * Configuration builder for [StdioServerTransport]. Used via the + * `StdioServerTransport(input, output) { ... }` factory; the I/O endpoints are supplied + * positionally, while [scope], [handlerDispatcher], and [ioDispatcher] are configurable + * inside the block. + * + * Example: + * ```kotlin + * StdioServerTransport(input, output) { + * scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + * handlerDispatcher = myHandlerDispatcher + * ioDispatcher = Dispatchers.IO + * } + * ``` + * + * @property scope Optional caller-supplied [CoroutineScope]. When non-null, the transport's + * pipeline coroutines run as children of this scope and the transport does not tear it + * down on [close]. When `null`, the transport creates and owns an internal scope. + * @property handlerDispatcher Dispatcher for invoking the registered message handler. + * Defaults to [Dispatchers.Default]. + * @property ioDispatcher Dispatcher for the reader and writer coroutines. Must allow at + * least two threads to run concurrently so the reader and writer don't block each other. + * The default is a two-thread view of the platform I/O dispatcher + * (`IODispatcher.limitedParallelism(2)`); pass a different value to share or isolate I/O + * threads with the rest of your application. + */ + @McpDsl + public class Builder internal constructor(internal val input: Source, internal val output: Sink) { + public var scope: CoroutineScope? = null + public var handlerDispatcher: CoroutineDispatcher = Dispatchers.Default + public var ioDispatcher: CoroutineDispatcher = IODispatcher.limitedParallelism(2) } } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index 32123a69..09e001b8 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -5,17 +5,28 @@ import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.withClue import io.kotest.matchers.collections.shouldContain import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.toJSON import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.cancel +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.withTimeoutOrNull import kotlinx.io.Buffer import kotlinx.io.RawSink import kotlinx.io.RawSource @@ -33,9 +44,12 @@ import java.io.ByteArrayOutputStream import java.io.IOException import java.io.PipedInputStream import java.io.PipedOutputStream +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors import kotlin.test.assertFalse import kotlin.test.assertTrue import kotlin.test.fail +import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -71,13 +85,13 @@ class StdioServerTransportTest { @Test fun `should be safe to close before start`() = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) server.close() // initialized guard makes this a no-op; must not throw } @Test fun `should start then close cleanly`() = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) server.onError { error -> throw error } @@ -96,7 +110,7 @@ class StdioServerTransportTest { @Test fun `should not read until started`() = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) server.onError { error -> throw error } @@ -125,7 +139,7 @@ class StdioServerTransportTest { @Test fun `should read multiple messages`() = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) server.onError { error -> throw error } @@ -162,7 +176,7 @@ class StdioServerTransportTest { @ParameterizedTest(name = "[{index}] input throws {0}") @MethodSource("inputErrors") fun `should invoke onError when input stream throws`(throwable: Throwable): Unit = runIntegrationTest { - val server = StdioServerTransport(FaultyRawSource(throwable).buffered(), printOutput) + val server = StdioServerTransport(input = FaultyRawSource(throwable).buffered(), output = printOutput) val capturedError = CompletableDeferred() server.onError { capturedError.complete(it) } server.onMessage {} @@ -176,7 +190,7 @@ class StdioServerTransportTest { @ParameterizedTest(name = "[{index}] output throws {0}") @MethodSource("outputErrors") fun `should invoke onError when output sink throws`(throwable: Throwable): Unit = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, FaultyRawSink(throwable).buffered()) + val server = StdioServerTransport(input = bufferedInput, output = FaultyRawSink(throwable).buffered()) val capturedError = CompletableDeferred() server.onError { capturedError.complete(it) } server.onMessage {} @@ -190,7 +204,7 @@ class StdioServerTransportTest { @Test fun `should call onClose when input EOF is reached`(): Unit = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) val didClose = CompletableDeferred() server.onError { throw it } server.onClose { didClose.complete(Unit) } @@ -206,7 +220,7 @@ class StdioServerTransportTest { @Test fun `should throw when starting twice`(): Unit = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) server.onMessage {} server.start() withClue("Server should not start twice") { @@ -220,7 +234,7 @@ class StdioServerTransportTest { @ParameterizedTest(name = "[{index}] handler throws {0}") @MethodSource("handlerErrors") fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) val capturedErrors = mutableListOf() val receivedMessages = mutableListOf() val secondMessageProcessed = CompletableDeferred() @@ -253,7 +267,7 @@ class StdioServerTransportTest { @Test fun `should not invoke onError for CancellationException in handler`() = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) val capturedError = CompletableDeferred() server.onError { capturedError.complete(it) } @@ -281,7 +295,7 @@ class StdioServerTransportTest { @Test fun `should continue receiving valid messages after malformed JSON is skipped`() = runIntegrationTest { - val server = StdioServerTransport(bufferedInput, printOutput) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) val received = CompletableDeferred() // ReadBuffer silently skips unparseable lines — no onError callback expected server.onError {} @@ -299,6 +313,264 @@ class StdioServerTransportTest { server.close() } + @Test + fun `should throw McpException when send is called before start`() = runIntegrationTest { + val server = StdioServerTransport(input = bufferedInput, output = printOutput) + val ex = shouldThrow { + server.send(PingRequest().toJSON()) + } + ex.code shouldBe RPCError.ErrorCode.CONNECTION_CLOSED + } + + @Test + fun `should throw McpException when send is called after close`() = runIntegrationTest { + val server = StdioServerTransport(input = bufferedInput, output = printOutput) + server.onError {} + server.onMessage {} + server.start() + server.close() + + val ex = shouldThrow { + server.send(PingRequest().toJSON()) + } + ex.code shouldBe RPCError.ErrorCode.CONNECTION_CLOSED + } + + @Test + fun `should drain in-flight outbound messages on graceful close`() = runIntegrationTest { + val server = StdioServerTransport(input = bufferedInput, output = printOutput) + server.onError {} + server.onMessage {} + server.start() + + val numMessages = 50 + repeat(numMessages) { + server.send(PingRequest().toJSON()) + } + server.close() + + val outputLines = String(output.toByteArray()).lines().count { it.isNotBlank() } + outputLines shouldBe numMessages + } + + @Test + fun `should fail-fast when reader throws non-EOF IOException`() = runIntegrationTest { + val ioError = IOException("transient stream failure") + val server = StdioServerTransport(input = FaultyRawSource(ioError).buffered(), output = printOutput) + val errorCaptured = CompletableDeferred() + val closeCaptured = CompletableDeferred() + server.onError { errorCaptured.complete(it) } + server.onClose { closeCaptured.complete(Unit) } + server.onMessage {} + + server.start() + + errorCaptured.await() shouldBe ioError + closeCaptured.await() + } + + @Test + fun `should close input source on natural EOF`() = runIntegrationTest { + val inputClosed = CompletableDeferred() + val eofSource = object : RawSource { + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long = -1L + override fun close() { + inputClosed.complete(Unit) + } + } + val server = StdioServerTransport(input = eofSource.buffered(), output = printOutput) + val onCloseFired = CompletableDeferred() + server.onError {} + server.onMessage {} + server.onClose { onCloseFired.complete(Unit) } + + server.start() + + inputClosed.await() + onCloseFired.await() + } + + // region: concurrency + + @Test + fun `should be safe under concurrent start and close`() = runIntegrationTest(timeout = 30.seconds) { + val iterations = 2000 + repeat(iterations) { iteration -> + val server = StdioServerTransport(input = Buffer(), output = Buffer()) + server.onError {} + server.onMessage {} + + coroutineScope { + launch(Dispatchers.Default) { + try { + server.start() + } catch (_: IllegalStateException) { + } + } + launch(Dispatchers.Default) { + server.close() + } + } + + withClue("iteration $iteration: send() after race should fail with CONNECTION_CLOSED") { + val ex = shouldThrow { + server.send(PingRequest().toJSON()) + } + ex.code shouldBe RPCError.ErrorCode.CONNECTION_CLOSED + } + } + } + + @Test + fun `should suspend send under back-pressure`() = runIntegrationTest(timeout = 30.seconds) { + val unblock = CountDownLatch(1) + val blockingSink = object : RawSink { + override fun write(source: Buffer, byteCount: Long) { + source.clear() + unblock.await() + } + + override fun flush() { + // noop + } + + override fun close() { + unblock.countDown() + } + } + val server = StdioServerTransport(input = bufferedInput, output = blockingSink.buffered()) + server.onError {} + server.onMessage {} + server.start() + + val maxAttempts = 256 + var suspendedAt = -1 + for (i in 0 until maxAttempts) { + val ok = withTimeoutOrNull(100.milliseconds) { server.send(PingRequest().toJSON()) } + if (ok == null) { + suspendedAt = i + break + } + } + assertTrue( + suspendedAt in 0 until maxAttempts, + "send() should have suspended within $maxAttempts attempts, got suspendedAt=$suspendedAt", + ) + // Unblock writer so close() can drain. + unblock.countDown() + server.close() + } + + // endregion + + // region: scope and dispatcher knobs + + @Test + fun `should honor externally provided CoroutineScope`() = runIntegrationTest { + val externalScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + val server = StdioServerTransport(input = bufferedInput, output = printOutput) { + scope = externalScope + } + val closeCaptured = CompletableDeferred() + server.onError {} + server.onClose { closeCaptured.complete(Unit) } + server.onMessage {} + + server.start() + externalScope.cancel() + + closeCaptured.await() + + val ex = shouldThrow { + server.send(PingRequest().toJSON()) + } + ex.code shouldBe RPCError.ErrorCode.CONNECTION_CLOSED + } + + @Test + fun `should use ioDispatcher for reader and writer`() = runIntegrationTest { + val threadName = "stdio-test-io-thread" + val executor = Executors.newFixedThreadPool(2) { r -> Thread(r, threadName) } + try { + val readerThread = CompletableDeferred() + val writerThread = CompletableDeferred() + val unblockReader = CountDownLatch(1) + val recordingInput = object : RawSource { + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + if (!readerThread.isCompleted) { + readerThread.complete(Thread.currentThread().name) + } + unblockReader.await() + return -1L + } + + override fun close() { + unblockReader.countDown() + } + } + val recordingOutput = object : RawSink { + override fun write(source: Buffer, byteCount: Long) { + if (!writerThread.isCompleted) { + writerThread.complete(Thread.currentThread().name) + } + source.clear() + } + + override fun flush() { + // noop + } + + override fun close() { + // noop + } + } + val server = StdioServerTransport( + input = recordingInput.buffered(), + output = recordingOutput.buffered(), + ) { + ioDispatcher = executor.asCoroutineDispatcher() + } + server.onError {} + server.onMessage {} + server.start() + + readerThread.await() shouldContain threadName + server.send(PingRequest().toJSON()) + writerThread.await() shouldContain threadName + + server.close() + } finally { + executor.shutdownNow() + } + } + + @Test + fun `should use handlerDispatcher for message handling`() = runIntegrationTest { + val threadName = "stdio-test-handler-thread" + val executor = Executors.newSingleThreadExecutor { r -> Thread(r, threadName) } + try { + val server = StdioServerTransport(input = bufferedInput, output = printOutput) { + handlerDispatcher = executor.asCoroutineDispatcher() + } + val observedThreadName = CompletableDeferred() + server.onError {} + server.onMessage { + observedThreadName.complete(Thread.currentThread().name) + } + server.start() + + inputWriter.write(serializeMessage(PingRequest().toJSON())) + inputWriter.flush() + + observedThreadName.await() shouldContain threadName + server.close() + } finally { + executor.shutdownNow() + } + } + + // endregion + private fun inputErrors() = listOf( IOException("simulated read failure"), RuntimeException("unexpected read exception"),