From 3fc154e3ed1ee5ca8b852643fb67506905acc82d Mon Sep 17 00:00:00 2001 From: jiwon Date: Fri, 1 May 2026 04:12:52 +0900 Subject: [PATCH] feat(server): add onSessionClose hook to ClientConnection (#553) --- .../kotlin/sdk/server/ClientConnectionTest.kt | 35 +++++++++++++++++++ kotlin-sdk-server/api/kotlin-sdk-server.api | 1 + .../kotlin/sdk/server/ClientConnection.kt | 13 +++++++ 3 files changed, 49 insertions(+) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnectionTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnectionTest.kt index ee8953a7..69c4eaac 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnectionTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnectionTest.kt @@ -31,8 +31,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.ToolListChangedNotification import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonPrimitive import org.junit.jupiter.api.Test +import kotlin.time.Duration.Companion.seconds class ClientConnectionTest : AbstractServerFeaturesTest() { @@ -173,4 +175,37 @@ class ClientConnectionTest : AbstractServerFeaturesTest() { cap.assertAll() } + + @Test + fun `onSessionClose callback runs when the session closes`() = runTest { + val cleanupRan = CompletableDeferred() + addTool("test") { onSessionClose { cleanupRan.complete(Unit) } } + + client.callTool(CallToolRequest(CallToolRequestParams("test"))) + client.close() + + withClue("onSessionClose callback should fire when the session closes") { + withTimeout(1.seconds) { cleanupRan.await() } + } + } + + @Test + fun `multiple onSessionClose callbacks run in registration order`() = runTest { + val invocations = mutableListOf() + val allRan = CompletableDeferred() + addTool("test") { + onSessionClose { invocations += 1 } + onSessionClose { invocations += 2 } + onSessionClose { + invocations += 3 + allRan.complete(Unit) + } + } + + client.callTool(CallToolRequest(CallToolRequestParams("test"))) + client.close() + + withTimeout(1.seconds) { allRan.await() } + invocations shouldBe listOf(1, 2, 3) + } } diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 80da5b53..a2a2a5f7 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -12,6 +12,7 @@ public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/Client public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/ClientConnection;Lio/modelcontextprotocol/kotlin/sdk/types/ListRootsRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public abstract fun notification (Lio/modelcontextprotocol/kotlin/sdk/types/ServerNotification;Lio/modelcontextprotocol/kotlin/sdk/types/RequestId;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun notification$default (Lio/modelcontextprotocol/kotlin/sdk/server/ClientConnection;Lio/modelcontextprotocol/kotlin/sdk/types/ServerNotification;Lio/modelcontextprotocol/kotlin/sdk/types/RequestId;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public abstract fun onSessionClose (Lkotlin/jvm/functions/Function0;)V public abstract fun ping (Lio/modelcontextprotocol/kotlin/sdk/types/PingRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/server/ClientConnection;Lio/modelcontextprotocol/kotlin/sdk/types/PingRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public abstract fun sendElicitationComplete (Lio/modelcontextprotocol/kotlin/sdk/types/ElicitationCompleteNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt index 9b4b2f35..eff53ffd 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt @@ -173,6 +173,15 @@ public interface ClientConnection { * @param notification Details of the completed elicitation. */ public suspend fun sendElicitationComplete(notification: ElicitationCompleteNotification) + + /** + * Registers a callback to be invoked when the underlying server session is closing. + * + * Use this to release session-scoped resources (e.g. database connections, temporary files, + * background jobs) created by tool, prompt, or resource handlers. Multiple callbacks may be + * registered and are invoked in registration order. + */ + public fun onSessionClose(block: () -> Unit) } internal class ClientConnectionImpl(private val session: ServerSession) : ClientConnection { @@ -303,6 +312,10 @@ internal class ClientConnectionImpl(private val session: ServerSession) : Client notification(notification) } + override fun onSessionClose(block: () -> Unit) { + session.onClose(block) + } + /** * Determines whether a message with the specified logging level is accepted * based on the current logging level of the session.