Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sshlib/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ package org.connectbot.sshlib.client {
}

public final class SshConnection {
ctor public SshConnection(org.connectbot.sshlib.transport.Transport transport, optional java.lang.String clientVersion, org.connectbot.sshlib.HostKeyVerifier hostKeyVerifier, optional java.lang.String kexAlgorithms, optional java.lang.String hostKeyAlgorithms, optional java.lang.String encryptionAlgorithms, optional java.lang.String macAlgorithms, optional java.lang.String compressionAlgorithms, optional boolean preferPasswordAuth, optional long rekeyIntervalMs, optional long rekeyBytesLimit, optional kotlin.coroutines.CoroutineContext coroutineContext);
ctor public SshConnection(org.connectbot.sshlib.transport.Transport transport, optional java.lang.String clientVersion, org.connectbot.sshlib.HostKeyVerifier hostKeyVerifier, optional java.lang.String kexAlgorithms, optional java.lang.String hostKeyAlgorithms, optional java.lang.String encryptionAlgorithms, optional java.lang.String macAlgorithms, optional java.lang.String compressionAlgorithms, optional boolean preferPasswordAuth, optional long rekeyIntervalMs, optional long rekeyBytesLimit, optional kotlinx.coroutines.CoroutineDispatcher coroutineDispatcher);
method public suspend java.lang.Object? authenticateKeyboardInteractive(java.lang.String username, org.connectbot.sshlib.KeyboardInteractiveCallback callback, kotlin.coroutines.Continuation<? super org.connectbot.sshlib.AuthResult>);
method public suspend java.lang.Object? authenticatePassword(java.lang.String username, java.lang.String password, kotlin.coroutines.Continuation<? super org.connectbot.sshlib.AuthResult>);
method public suspend java.lang.Object? close(kotlin.coroutines.Continuation<? super kotlin.Unit>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package org.connectbot.sshlib.client
import io.kaitai.struct.ByteBufferKaitaiStream
import io.kaitai.struct.KaitaiStruct
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CloseableCoroutineDispatcher
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
Expand All @@ -36,7 +38,6 @@ import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
Expand Down Expand Up @@ -140,7 +141,7 @@ import org.connectbot.sshlib.AuthResult as PublicAuthResult
* @param transport Underlying transport (e.g., TCP socket)
* @param clientVersion Client version string (default: SSH-2.0-CBSSH_1.0)
*/
@OptIn(ExperimentalCoroutinesApi::class)
@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
class SshConnection(
private val transport: Transport,
private val clientVersion: String = "SSH-2.0-CBSSH_1.0",
Expand All @@ -153,20 +154,19 @@ class SshConnection(
private val preferPasswordAuth: Boolean = false,
private val rekeyIntervalMs: Long = 3_600_000L,
private val rekeyBytesLimit: Long = 1_073_741_824L,
coroutineContext: CoroutineContext = Dispatchers.IO,
coroutineDispatcher: CoroutineDispatcher = Dispatchers.IO,
) {

companion object {
private val logger = LoggerFactory.getLogger(SshConnection::class.java)
}

private val stateMachineDispatcher = coroutineDispatcher.limitedParallelism(1, "StateMachine")

private class HostKeyRejectedException(val key: PublicKey) : Exception("Host key rejected")

private val packetIO = PacketIO(transport)

@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
private val stateMachineDispatcher = newSingleThreadContext("ssh-state-machine")

private val callbacks = object : SshClientCallbacks {
override fun sendVersion() = this@SshConnection.sendVersion()
override fun receiveVersion(banner: IdBanner) = this@SshConnection.receiveVersion(banner)
Expand Down Expand Up @@ -205,7 +205,7 @@ class SshConnection(
}

private val stateMachine = SshClientStateMachine(callbacks)
private val connectionScope = CoroutineScope(SupervisorJob() + coroutineContext)
private val connectionScope = CoroutineScope(SupervisorJob() + coroutineDispatcher)
private val writeMutex = Mutex()

private val _disconnectedFlow = MutableSharedFlow<Throwable?>(extraBufferCapacity = 1)
Expand Down Expand Up @@ -371,9 +371,7 @@ class SshConnection(
// Start packet loop — handles all binary SSH packets from here
startPacketLoop()

withTimeout(30_000L) {
deferred.await()
}
withTimeout(30_000L) { deferred.await() }
} catch (e: TimeoutCancellationException) {
ConnectResult.TransportError(Exception("Connection timed out"))
} catch (e: HostKeyRejectedException) {
Expand Down Expand Up @@ -1677,7 +1675,6 @@ class SshConnection(
transport.close()
packetLoopJob?.join()
packetLoopJob = null
stateMachineDispatcher.close()
sessionId?.fill(0)
sessionId = null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ package org.connectbot.sshlib.client
import io.kaitai.struct.ByteBufferKaitaiStream
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.selects.select
import kotlinx.coroutines.sync.Mutex
Expand Down Expand Up @@ -47,14 +51,16 @@ import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.MessageDigest
import java.security.SecureRandom
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext

class FakeSshServer(
private val serverTransport: PipedTransport,
private val scope: CoroutineScope,
private val coroutineContext: CoroutineContext = EmptyCoroutineContext,
) {
@Volatile
var rekeyCount = 0
private set
private val _rekeyCount = MutableStateFlow(0)
val rekeyCount: StateFlow<Int> = _rekeyCount.asStateFlow()

private val hostKeyPair: KeyPair = KeyPairGenerator.getInstance("Ed25519").generateKeyPair()
private val hostKeyBlob: ByteArray = SshPublicKeyEncoder.encode(hostKeyPair.public, "ssh-ed25519")
Expand All @@ -71,7 +77,7 @@ class FakeSshServer(
private val rekeyRequestChannel = Channel<Unit>(Channel.UNLIMITED)

fun start() {
scope.launch { serve() }
scope.launch(coroutineContext) { serve() }
}

/**
Expand All @@ -83,7 +89,7 @@ class FakeSshServer(
}

fun sendIgnore() {
scope.launch {
scope.launch(coroutineContext) {
val msg = org.connectbot.sshlib.protocol.SshMsgIgnore().apply {
setData(createByteString(byteArrayOf()))
_check()
Expand Down Expand Up @@ -185,7 +191,7 @@ class FakeSshServer(
writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) }
packets.receive() // client NEWKEYS
activateEncryption(io)
rekeyCount++
_rekeyCount.update { it + 1 }
}

private suspend fun doClientInitiatedKex(
Expand All @@ -200,7 +206,7 @@ class FakeSshServer(
writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) }
packets.receive() // client NEWKEYS
activateEncryption(io)
rekeyCount++
_rekeyCount.update { it + 1 }
}

private suspend fun sendKexInit(io: PacketIO): ByteArray {
Expand Down
91 changes: 40 additions & 51 deletions sshlib/src/test/kotlin/org/connectbot/sshlib/client/RekeyTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,23 @@

package org.connectbot.sshlib.client

import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.coroutines.yield
import org.connectbot.sshlib.ConnectResult
import org.connectbot.sshlib.HostKeyVerifier
import org.connectbot.sshlib.PublicKey
import org.connectbot.sshlib.transport.PipedTransport
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import kotlin.test.assertIs

Expand All @@ -45,149 +43,140 @@ class RekeyTest {
override suspend fun verify(key: PublicKey): Boolean = true
}

private suspend fun connectInBackground(
connection: SshConnection,
backgroundScope: CoroutineScope,
dispatcher: CoroutineDispatcher,
): ConnectResult {
val result = CompletableDeferred<ConnectResult>()
backgroundScope.launch(dispatcher) { result.complete(connection.connect()) }
yield()
return result.await()
}

@Test
fun `rekey triggered when byte limit exceeded`() = runTest {
val scope = CoroutineScope(Dispatchers.Default)
val dispatcher = StandardTestDispatcher(testScheduler)
val (clientTransport, serverTransport) = PipedTransport.create()
val server = FakeSshServer(serverTransport, scope)
val server = FakeSshServer(serverTransport, backgroundScope, dispatcher)
server.start()

val connection = SshConnection(
transport = clientTransport,
hostKeyVerifier = acceptAllVerifier,
rekeyBytesLimit = 512L,
rekeyIntervalMs = Long.MAX_VALUE,
coroutineDispatcher = dispatcher,
)

try {
val result = withContext(Dispatchers.Default) { connection.connect() }
val result = connectInBackground(connection, backgroundScope, dispatcher)
assertIs<ConnectResult.Success>(result)

server.sendIgnore()
withContext(Dispatchers.Default) {
withTimeout(5_000L) {
while (server.rekeyCount < 1) delay(50)
}
}

assertTrue(server.rekeyCount >= 1)
server.rekeyCount.first { it >= 1 }
Comment thread
kruton marked this conversation as resolved.

// Verify connection is still alive after re-key
server.sendIgnore()
Comment thread
kruton marked this conversation as resolved.
withContext(Dispatchers.Default) { delay(200) }
val disconnectCause = withTimeoutOrNull(250) { connection.disconnectedFlow.first() }
assertNull(disconnectCause)
} finally {
connection.close()
scope.cancel()
}
}

@Test
fun `rekey triggered after interval elapses`() = runTest {
val scope = CoroutineScope(Dispatchers.Default)
val dispatcher = StandardTestDispatcher(testScheduler)
val (clientTransport, serverTransport) = PipedTransport.create()
val server = FakeSshServer(serverTransport, scope)
val server = FakeSshServer(serverTransport, backgroundScope, dispatcher)
server.start()

val connection = SshConnection(
transport = clientTransport,
hostKeyVerifier = acceptAllVerifier,
rekeyIntervalMs = 3_600_000L,
rekeyBytesLimit = Long.MAX_VALUE,
coroutineContext = dispatcher,
coroutineDispatcher = dispatcher,
)

try {
val result = withContext(Dispatchers.Default) { connection.connect() }
val result = connectInBackground(connection, backgroundScope, dispatcher)
assertIs<ConnectResult.Success>(result)

assertEquals(0, server.rekeyCount)
assertEquals(0, server.rekeyCount.value)

advanceTimeBy(3_600_001L)

server.sendIgnore()
withContext(Dispatchers.Default) {
withTimeout(5_000L) {
while (server.rekeyCount < 1) delay(50)
}
}
server.rekeyCount.first { it >= 1 }
Comment thread
kruton marked this conversation as resolved.

assertEquals(1, server.rekeyCount)
assertEquals(1, server.rekeyCount.value)
} finally {
connection.close()
scope.cancel()
}
}

@Test
fun `server-initiated rekey is handled correctly`() = runTest {
val scope = CoroutineScope(Dispatchers.Default)
val dispatcher = StandardTestDispatcher(testScheduler)
val (clientTransport, serverTransport) = PipedTransport.create()
val server = FakeSshServer(serverTransport, scope)
val server = FakeSshServer(serverTransport, backgroundScope, dispatcher)
server.start()

val connection = SshConnection(
transport = clientTransport,
hostKeyVerifier = acceptAllVerifier,
rekeyIntervalMs = Long.MAX_VALUE,
rekeyBytesLimit = Long.MAX_VALUE,
coroutineDispatcher = dispatcher,
)

try {
val result = withContext(Dispatchers.Default) { connection.connect() }
val result = connectInBackground(connection, backgroundScope, dispatcher)
assertIs<ConnectResult.Success>(result)

server.initiateRekey()
withContext(Dispatchers.Default) {
withTimeout(5_000L) {
while (server.rekeyCount < 1) delay(50)
}
}
server.rekeyCount.first { it >= 1 }
Comment thread
kruton marked this conversation as resolved.

assertEquals(1, server.rekeyCount)
assertEquals(1, server.rekeyCount.value)
} finally {
connection.close()
scope.cancel()
}
}

@Test
fun `packets received during rekey are not dropped`() = runTest {
val scope = CoroutineScope(Dispatchers.Default)
val dispatcher = StandardTestDispatcher(testScheduler)
val (clientTransport, serverTransport) = PipedTransport.create()
val server = FakeSshServer(serverTransport, scope)
val server = FakeSshServer(serverTransport, backgroundScope, dispatcher)
server.start()

val connection = SshConnection(
transport = clientTransport,
hostKeyVerifier = acceptAllVerifier,
rekeyIntervalMs = Long.MAX_VALUE,
rekeyBytesLimit = Long.MAX_VALUE,
coroutineDispatcher = dispatcher,
)

try {
val result = withContext(Dispatchers.Default) { connection.connect() }
val result = connectInBackground(connection, backgroundScope, dispatcher)
assertIs<ConnectResult.Success>(result)

server.sendIgnore()
server.initiateRekey()
server.sendIgnore()
withContext(Dispatchers.Default) {
withTimeout(5_000L) {
while (server.rekeyCount < 1) delay(50)
}
}
server.rekeyCount.first { it >= 1 }
Comment thread
kruton marked this conversation as resolved.

assertEquals(1, server.rekeyCount)
assertEquals(1, server.rekeyCount.value)

// Packet processing should continue after re-key with no disconnect.
server.sendIgnore()
withContext(Dispatchers.Default) { delay(200) }
val disconnectCause = withTimeoutOrNull(250) { connection.disconnectedFlow.first() }
assertNull(disconnectCause)
} finally {
connection.close()
scope.cancel()
}
}
}
Loading