From 2f1e1fbee71873a0cc9f1b7892bf2a0e0a9b6918 Mon Sep 17 00:00:00 2001 From: Kenny Root Date: Wed, 29 Apr 2026 07:30:56 -0700 Subject: [PATCH] fix: RekeyTest stability Use Kotlin StateFlow for rekey count so we do not need to use delay. Provide a way to pass a CoroutineDispatcher to the state machine so it can be on the same clock as the tests. This lets us advance time virtually in a test. --- sshlib/api.txt | 2 +- .../connectbot/sshlib/client/SshConnection.kt | 19 ++-- .../connectbot/sshlib/client/FakeSshServer.kt | 20 ++-- .../org/connectbot/sshlib/client/RekeyTest.kt | 91 ++++++++----------- 4 files changed, 62 insertions(+), 70 deletions(-) diff --git a/sshlib/api.txt b/sshlib/api.txt index f9f5288..dc2b3ab 100644 --- a/sshlib/api.txt +++ b/sshlib/api.txt @@ -712,7 +712,7 @@ package org.connectbot.sshlib.client { } public final class SshConnection { - ctor public SshConnection(org.connectbot.sshlib.transport.Transport transport, optional java.lang.String clientVersion, org.connectbot.sshlib.HostKeyVerifier hostKeyVerifier, optional java.lang.String kexAlgorithms, optional java.lang.String hostKeyAlgorithms, optional java.lang.String encryptionAlgorithms, optional java.lang.String macAlgorithms, optional java.lang.String compressionAlgorithms, optional boolean preferPasswordAuth, optional long rekeyIntervalMs, optional long rekeyBytesLimit, optional kotlin.coroutines.CoroutineContext coroutineContext); + ctor public SshConnection(org.connectbot.sshlib.transport.Transport transport, optional java.lang.String clientVersion, org.connectbot.sshlib.HostKeyVerifier hostKeyVerifier, optional java.lang.String kexAlgorithms, optional java.lang.String hostKeyAlgorithms, optional java.lang.String encryptionAlgorithms, optional java.lang.String macAlgorithms, optional java.lang.String compressionAlgorithms, optional boolean preferPasswordAuth, optional long rekeyIntervalMs, optional long rekeyBytesLimit, optional kotlinx.coroutines.CoroutineDispatcher coroutineDispatcher); method public suspend java.lang.Object? authenticateKeyboardInteractive(java.lang.String username, org.connectbot.sshlib.KeyboardInteractiveCallback callback, kotlin.coroutines.Continuation); method public suspend java.lang.Object? authenticatePassword(java.lang.String username, java.lang.String password, kotlin.coroutines.Continuation); method public suspend java.lang.Object? close(kotlin.coroutines.Continuation); diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt index af69be9..bd36f96 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt @@ -19,7 +19,9 @@ package org.connectbot.sshlib.client import io.kaitai.struct.ByteBufferKaitaiStream import io.kaitai.struct.KaitaiStruct import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CloseableCoroutineDispatcher import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.Dispatchers @@ -36,7 +38,6 @@ import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.flow.asSharedFlow import kotlinx.coroutines.isActive import kotlinx.coroutines.launch -import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext @@ -140,7 +141,7 @@ import org.connectbot.sshlib.AuthResult as PublicAuthResult * @param transport Underlying transport (e.g., TCP socket) * @param clientVersion Client version string (default: SSH-2.0-CBSSH_1.0) */ -@OptIn(ExperimentalCoroutinesApi::class) +@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class) class SshConnection( private val transport: Transport, private val clientVersion: String = "SSH-2.0-CBSSH_1.0", @@ -153,20 +154,19 @@ class SshConnection( private val preferPasswordAuth: Boolean = false, private val rekeyIntervalMs: Long = 3_600_000L, private val rekeyBytesLimit: Long = 1_073_741_824L, - coroutineContext: CoroutineContext = Dispatchers.IO, + coroutineDispatcher: CoroutineDispatcher = Dispatchers.IO, ) { companion object { private val logger = LoggerFactory.getLogger(SshConnection::class.java) } + private val stateMachineDispatcher = coroutineDispatcher.limitedParallelism(1, "StateMachine") + private class HostKeyRejectedException(val key: PublicKey) : Exception("Host key rejected") private val packetIO = PacketIO(transport) - @OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class) - private val stateMachineDispatcher = newSingleThreadContext("ssh-state-machine") - private val callbacks = object : SshClientCallbacks { override fun sendVersion() = this@SshConnection.sendVersion() override fun receiveVersion(banner: IdBanner) = this@SshConnection.receiveVersion(banner) @@ -205,7 +205,7 @@ class SshConnection( } private val stateMachine = SshClientStateMachine(callbacks) - private val connectionScope = CoroutineScope(SupervisorJob() + coroutineContext) + private val connectionScope = CoroutineScope(SupervisorJob() + coroutineDispatcher) private val writeMutex = Mutex() private val _disconnectedFlow = MutableSharedFlow(extraBufferCapacity = 1) @@ -371,9 +371,7 @@ class SshConnection( // Start packet loop — handles all binary SSH packets from here startPacketLoop() - withTimeout(30_000L) { - deferred.await() - } + withTimeout(30_000L) { deferred.await() } } catch (e: TimeoutCancellationException) { ConnectResult.TransportError(Exception("Connection timed out")) } catch (e: HostKeyRejectedException) { @@ -1677,7 +1675,6 @@ class SshConnection( transport.close() packetLoopJob?.join() packetLoopJob = null - stateMachineDispatcher.close() sessionId?.fill(0) sessionId = null } diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt index c14d30f..5fea77a 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt @@ -19,6 +19,10 @@ package org.connectbot.sshlib.client import io.kaitai.struct.ByteBufferKaitaiStream import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import kotlinx.coroutines.selects.select import kotlinx.coroutines.sync.Mutex @@ -47,14 +51,16 @@ import java.security.KeyPair import java.security.KeyPairGenerator import java.security.MessageDigest import java.security.SecureRandom +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext class FakeSshServer( private val serverTransport: PipedTransport, private val scope: CoroutineScope, + private val coroutineContext: CoroutineContext = EmptyCoroutineContext, ) { - @Volatile - var rekeyCount = 0 - private set + private val _rekeyCount = MutableStateFlow(0) + val rekeyCount: StateFlow = _rekeyCount.asStateFlow() private val hostKeyPair: KeyPair = KeyPairGenerator.getInstance("Ed25519").generateKeyPair() private val hostKeyBlob: ByteArray = SshPublicKeyEncoder.encode(hostKeyPair.public, "ssh-ed25519") @@ -71,7 +77,7 @@ class FakeSshServer( private val rekeyRequestChannel = Channel(Channel.UNLIMITED) fun start() { - scope.launch { serve() } + scope.launch(coroutineContext) { serve() } } /** @@ -83,7 +89,7 @@ class FakeSshServer( } fun sendIgnore() { - scope.launch { + scope.launch(coroutineContext) { val msg = org.connectbot.sshlib.protocol.SshMsgIgnore().apply { setData(createByteString(byteArrayOf())) _check() @@ -185,7 +191,7 @@ class FakeSshServer( writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) } packets.receive() // client NEWKEYS activateEncryption(io) - rekeyCount++ + _rekeyCount.update { it + 1 } } private suspend fun doClientInitiatedKex( @@ -200,7 +206,7 @@ class FakeSshServer( writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) } packets.receive() // client NEWKEYS activateEncryption(io) - rekeyCount++ + _rekeyCount.update { it + 1 } } private suspend fun sendKexInit(io: PacketIO): ByteArray { diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/RekeyTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/RekeyTest.kt index b0d11d1..e35fa67 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/RekeyTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/RekeyTest.kt @@ -16,25 +16,23 @@ package org.connectbot.sshlib.client +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.cancel -import kotlinx.coroutines.delay import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch import kotlinx.coroutines.test.StandardTestDispatcher import kotlinx.coroutines.test.advanceTimeBy import kotlinx.coroutines.test.runTest -import kotlinx.coroutines.withContext -import kotlinx.coroutines.withTimeout import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.yield import org.connectbot.sshlib.ConnectResult import org.connectbot.sshlib.HostKeyVerifier import org.connectbot.sshlib.PublicKey import org.connectbot.sshlib.transport.PipedTransport import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertNull -import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test import kotlin.test.assertIs @@ -45,11 +43,22 @@ class RekeyTest { override suspend fun verify(key: PublicKey): Boolean = true } + private suspend fun connectInBackground( + connection: SshConnection, + backgroundScope: CoroutineScope, + dispatcher: CoroutineDispatcher, + ): ConnectResult { + val result = CompletableDeferred() + backgroundScope.launch(dispatcher) { result.complete(connection.connect()) } + yield() + return result.await() + } + @Test fun `rekey triggered when byte limit exceeded`() = runTest { - val scope = CoroutineScope(Dispatchers.Default) + val dispatcher = StandardTestDispatcher(testScheduler) val (clientTransport, serverTransport) = PipedTransport.create() - val server = FakeSshServer(serverTransport, scope) + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) server.start() val connection = SshConnection( @@ -57,36 +66,30 @@ class RekeyTest { hostKeyVerifier = acceptAllVerifier, rekeyBytesLimit = 512L, rekeyIntervalMs = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, ) try { - val result = withContext(Dispatchers.Default) { connection.connect() } + val result = connectInBackground(connection, backgroundScope, dispatcher) assertIs(result) server.sendIgnore() - withContext(Dispatchers.Default) { - withTimeout(5_000L) { - while (server.rekeyCount < 1) delay(50) - } - } - - assertTrue(server.rekeyCount >= 1) + server.rekeyCount.first { it >= 1 } // Verify connection is still alive after re-key server.sendIgnore() - withContext(Dispatchers.Default) { delay(200) } + val disconnectCause = withTimeoutOrNull(250) { connection.disconnectedFlow.first() } + assertNull(disconnectCause) } finally { connection.close() - scope.cancel() } } @Test fun `rekey triggered after interval elapses`() = runTest { - val scope = CoroutineScope(Dispatchers.Default) val dispatcher = StandardTestDispatcher(testScheduler) val (clientTransport, serverTransport) = PipedTransport.create() - val server = FakeSshServer(serverTransport, scope) + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) server.start() val connection = SshConnection( @@ -94,36 +97,31 @@ class RekeyTest { hostKeyVerifier = acceptAllVerifier, rekeyIntervalMs = 3_600_000L, rekeyBytesLimit = Long.MAX_VALUE, - coroutineContext = dispatcher, + coroutineDispatcher = dispatcher, ) try { - val result = withContext(Dispatchers.Default) { connection.connect() } + val result = connectInBackground(connection, backgroundScope, dispatcher) assertIs(result) - assertEquals(0, server.rekeyCount) + assertEquals(0, server.rekeyCount.value) advanceTimeBy(3_600_001L) server.sendIgnore() - withContext(Dispatchers.Default) { - withTimeout(5_000L) { - while (server.rekeyCount < 1) delay(50) - } - } + server.rekeyCount.first { it >= 1 } - assertEquals(1, server.rekeyCount) + assertEquals(1, server.rekeyCount.value) } finally { connection.close() - scope.cancel() } } @Test fun `server-initiated rekey is handled correctly`() = runTest { - val scope = CoroutineScope(Dispatchers.Default) + val dispatcher = StandardTestDispatcher(testScheduler) val (clientTransport, serverTransport) = PipedTransport.create() - val server = FakeSshServer(serverTransport, scope) + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) server.start() val connection = SshConnection( @@ -131,31 +129,27 @@ class RekeyTest { hostKeyVerifier = acceptAllVerifier, rekeyIntervalMs = Long.MAX_VALUE, rekeyBytesLimit = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, ) try { - val result = withContext(Dispatchers.Default) { connection.connect() } + val result = connectInBackground(connection, backgroundScope, dispatcher) assertIs(result) server.initiateRekey() - withContext(Dispatchers.Default) { - withTimeout(5_000L) { - while (server.rekeyCount < 1) delay(50) - } - } + server.rekeyCount.first { it >= 1 } - assertEquals(1, server.rekeyCount) + assertEquals(1, server.rekeyCount.value) } finally { connection.close() - scope.cancel() } } @Test fun `packets received during rekey are not dropped`() = runTest { - val scope = CoroutineScope(Dispatchers.Default) + val dispatcher = StandardTestDispatcher(testScheduler) val (clientTransport, serverTransport) = PipedTransport.create() - val server = FakeSshServer(serverTransport, scope) + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) server.start() val connection = SshConnection( @@ -163,31 +157,26 @@ class RekeyTest { hostKeyVerifier = acceptAllVerifier, rekeyIntervalMs = Long.MAX_VALUE, rekeyBytesLimit = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, ) try { - val result = withContext(Dispatchers.Default) { connection.connect() } + val result = connectInBackground(connection, backgroundScope, dispatcher) assertIs(result) server.sendIgnore() server.initiateRekey() server.sendIgnore() - withContext(Dispatchers.Default) { - withTimeout(5_000L) { - while (server.rekeyCount < 1) delay(50) - } - } + server.rekeyCount.first { it >= 1 } - assertEquals(1, server.rekeyCount) + assertEquals(1, server.rekeyCount.value) // Packet processing should continue after re-key with no disconnect. server.sendIgnore() - withContext(Dispatchers.Default) { delay(200) } val disconnectCause = withTimeoutOrNull(250) { connection.disconnectedFlow.first() } assertNull(disconnectCause) } finally { connection.close() - scope.cancel() } } }