diff --git a/protocol/src/main/resources/kaitai/ssh_enums.ksy b/protocol/src/main/resources/kaitai/ssh_enums.ksy index cf86e3cf..c3ef51de 100644 --- a/protocol/src/main/resources/kaitai/ssh_enums.ksy +++ b/protocol/src/main/resources/kaitai/ssh_enums.ksy @@ -72,6 +72,8 @@ enums: 98: ssh_msg_channel_request 99: ssh_msg_channel_success 100: ssh_msg_channel_failure + 192: ssh_msg_ping + 193: ssh_msg_pong global_request_type: 0: empty_response 1: tcpip_forward diff --git a/protocol/src/main/resources/kaitai/ssh_msg_ping.ksy b/protocol/src/main/resources/kaitai/ssh_msg_ping.ksy new file mode 100644 index 00000000..8d06b960 --- /dev/null +++ b/protocol/src/main/resources/kaitai/ssh_msg_ping.ksy @@ -0,0 +1,12 @@ +meta: + id: ssh_msg_ping + endian: be + imports: + - byte_string +doc-ref: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL +doc: > + A ping message sent by either the client or server. The recipient must + reply with SSH_MSG_PONG containing the same data. +seq: +- id: data + type: byte_string diff --git a/protocol/src/main/resources/kaitai/ssh_msg_pong.ksy b/protocol/src/main/resources/kaitai/ssh_msg_pong.ksy new file mode 100644 index 00000000..7e11bc48 --- /dev/null +++ b/protocol/src/main/resources/kaitai/ssh_msg_pong.ksy @@ -0,0 +1,12 @@ +meta: + id: ssh_msg_pong + endian: be + imports: + - byte_string +doc-ref: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL +doc: > + A pong reply to SSH_MSG_PING. The data field must be an exact copy of + the data from the corresponding ping message. +seq: +- id: data + type: byte_string diff --git a/sshlib/api.txt b/sshlib/api.txt index dc2b3abf..b02e7cce 100644 --- a/sshlib/api.txt +++ b/sshlib/api.txt @@ -273,6 +273,45 @@ package org.connectbot.sshlib { method public suspend java.lang.Object? verify(org.connectbot.sshlib.PublicKey key, kotlin.coroutines.Continuation); } + public abstract sealed exhaustive class PingResult { + } + + public static final class PingResult.Failure extends org.connectbot.sshlib.PingResult { + ctor public PingResult.Failure(java.lang.Throwable cause); + method public java.lang.Throwable component1(); + method public org.connectbot.sshlib.PingResult.Failure copy(optional java.lang.Throwable cause); + method public boolean equals(java.lang.Object? other); + method @InaccessibleFromKotlin public java.lang.Throwable getCause(); + method public int hashCode(); + method public java.lang.String toString(); + property public Throwable cause; + } + + public static final class PingResult.NotAuthenticated extends org.connectbot.sshlib.PingResult { + method public boolean equals(java.lang.Object? other); + method public int hashCode(); + method public java.lang.String toString(); + field public static final org.connectbot.sshlib.PingResult.NotAuthenticated INSTANCE; + } + + public static final class PingResult.NotSupported extends org.connectbot.sshlib.PingResult { + method public boolean equals(java.lang.Object? other); + method public int hashCode(); + method public java.lang.String toString(); + field public static final org.connectbot.sshlib.PingResult.NotSupported INSTANCE; + } + + public static final class PingResult.Success extends org.connectbot.sshlib.PingResult { + ctor public PingResult.Success(long elapsedNs); + method public long component1(); + method public org.connectbot.sshlib.PingResult.Success copy(optional long elapsedNs); + method public boolean equals(java.lang.Object? other); + method @InaccessibleFromKotlin public long getElapsedNs(); + method public int hashCode(); + method public java.lang.String toString(); + property public long elapsedNs; + } + public interface PortForwarder { method public default void close(); method @InaccessibleFromKotlin public java.lang.String getBoundHost(); @@ -498,6 +537,7 @@ package org.connectbot.sshlib { method public org.connectbot.sshlib.transport.TransportFactory? openDirectTcpipTransport(java.lang.String remoteHost, int remotePort, optional java.lang.String originAddr, optional int originPort); method public suspend java.lang.Object? openSession(kotlin.coroutines.Continuation); method public suspend java.lang.Object? openSftp(kotlin.coroutines.Continuation>); + method public suspend java.lang.Object? ping(kotlin.coroutines.Continuation); method public suspend java.lang.Object? remotePortForward(java.lang.String remoteBindAddress, int remoteBindPort, java.lang.String localHost, int localPort, kotlin.coroutines.Continuation); property public org.connectbot.sshlib.ConnectionInfo? connectionInfo; property public kotlinx.coroutines.flow.SharedFlow disconnectedFlow; @@ -675,6 +715,7 @@ package org.connectbot.sshlib.blocking { method public org.connectbot.sshlib.transport.TransportFactory? openDirectTcpipTransport(java.lang.String remoteHost, int remotePort, optional java.lang.String originAddr, optional int originPort); method public org.connectbot.sshlib.SshSession? openSession(); method @kotlin.jvm.Throws(exceptionClasses=SftpException::class) public org.connectbot.sshlib.SftpClient openSftp() throws org.connectbot.sshlib.SftpException; + method public org.connectbot.sshlib.PingResult ping(); method public org.connectbot.sshlib.PortForwarder? remotePortForward(java.lang.String remoteBindAddress, int remoteBindPort, java.lang.String localHost, int localPort); property public kotlinx.coroutines.flow.SharedFlow disconnectedFlow; property public boolean isAuthenticated; diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/PingResult.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/PingResult.kt new file mode 100644 index 00000000..6ab9ff2b --- /dev/null +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/PingResult.kt @@ -0,0 +1,18 @@ +package org.connectbot.sshlib + +/** + * Result of a [SshClient.ping] call. + */ +sealed class PingResult { + /** The server replied; [elapsedNs] is the round-trip time in nanoseconds. */ + data class Success(val elapsedNs: Long) : PingResult() + + /** The server did not advertise ping support via SSH2_MSG_EXT_INFO. */ + data object NotSupported : PingResult() + + /** There is no active authenticated connection. */ + data object NotAuthenticated : PingResult() + + /** An error occurred while sending the ping or waiting for the reply. */ + data class Failure(val cause: Throwable) : PingResult() +} diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/SshClient.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/SshClient.kt index eb08fa7c..2097e87c 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/SshClient.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/SshClient.kt @@ -26,6 +26,7 @@ import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.flow.asSharedFlow import kotlinx.coroutines.launch +import org.connectbot.sshlib.PingResult import org.connectbot.sshlib.client.DynamicPortForwarder import org.connectbot.sshlib.client.LocalPortForwarder import org.connectbot.sshlib.client.RemotePortForwarder @@ -652,6 +653,25 @@ class SshClient private constructor( } } + /** + * Send an SSH ping to the server and return the round-trip time. + * + * Requires a prior successful [connect] and authentication. Returns + * [PingResult.NotAuthenticated] if there is no active connection or + * authentication has not completed, [PingResult.NotSupported] if the server + * did not advertise `ping@openssh.com` support via SSH2_MSG_EXT_INFO, or + * [PingResult.Failure] if the ping cannot be sent or the connection closes + * before the server replies. + * + * @return [PingResult.Success] with round-trip nanoseconds, [PingResult.NotSupported], + * [PingResult.NotAuthenticated], or [PingResult.Failure] + */ + suspend fun ping(): PingResult { + val conn = connection ?: return PingResult.NotAuthenticated + if (!authenticated) return PingResult.NotAuthenticated + return conn.ping() + } + /** * Disconnect from the SSH server. */ diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/blocking/BlockingSshClient.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/blocking/BlockingSshClient.kt index 40b181c6..e60bb063 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/blocking/BlockingSshClient.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/blocking/BlockingSshClient.kt @@ -26,6 +26,7 @@ import org.connectbot.sshlib.AuthResult import org.connectbot.sshlib.ConnectResult import org.connectbot.sshlib.HostKeyVerifier import org.connectbot.sshlib.KeyboardInteractiveCallback +import org.connectbot.sshlib.PingResult import org.connectbot.sshlib.PortForwarder import org.connectbot.sshlib.SftpClient import org.connectbot.sshlib.SftpException @@ -247,6 +248,11 @@ class BlockingSshClient internal constructor( @Throws(SftpException::class) fun openSftp(): SftpClient = runBlocking { client.openSftp().getOrThrow() } + /** + * Send an SSH ping to the server and return the result. + */ + fun ping(): PingResult = runBlocking { client.ping() } + /** * Start local port forwarding. */ diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt index bd36f963..c81d2960 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt @@ -27,6 +27,7 @@ import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.cancel @@ -49,6 +50,7 @@ import org.connectbot.sshlib.ConnectResult import org.connectbot.sshlib.ConnectionInfo import org.connectbot.sshlib.HostKeyVerifier import org.connectbot.sshlib.KeyboardInteractiveCallback +import org.connectbot.sshlib.PingResult import org.connectbot.sshlib.PublicKey import org.connectbot.sshlib.SshException import org.connectbot.sshlib.crypto.CipherEntry @@ -100,6 +102,8 @@ import org.connectbot.sshlib.protocol.SshMsgKexEcdhReply import org.connectbot.sshlib.protocol.SshMsgKexdhInit import org.connectbot.sshlib.protocol.SshMsgKexdhReply import org.connectbot.sshlib.protocol.SshMsgKexinit +import org.connectbot.sshlib.protocol.SshMsgPing +import org.connectbot.sshlib.protocol.SshMsgPong import org.connectbot.sshlib.protocol.SshMsgServiceAccept import org.connectbot.sshlib.protocol.SshMsgServiceRequest import org.connectbot.sshlib.protocol.SshMsgUserauthBanner @@ -126,9 +130,11 @@ import org.connectbot.sshlib.transport.PacketIO import org.connectbot.sshlib.transport.Transport import org.slf4j.LoggerFactory import java.math.BigInteger +import java.nio.ByteBuffer import java.security.SecureRandom import java.util.Collections import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong import kotlin.coroutines.CoroutineContext import org.connectbot.sshlib.AuthResult as PublicAuthResult @@ -146,7 +152,7 @@ class SshConnection( private val transport: Transport, private val clientVersion: String = "SSH-2.0-CBSSH_1.0", private val hostKeyVerifier: HostKeyVerifier, - private val kexAlgorithms: String = KexEntry.defaultString, + kexAlgorithms: String = KexEntry.defaultString, private val hostKeyAlgorithms: String = SignatureEntry.defaultString, private val encryptionAlgorithms: String = CipherEntry.defaultString, private val macAlgorithms: String = MacEntry.defaultString, @@ -159,8 +165,26 @@ class SshConnection( companion object { private val logger = LoggerFactory.getLogger(SshConnection::class.java) + + private fun stripExtInfoC(kexAlgorithms: String): String = kexAlgorithms.split(",") + .filter { it.isNotEmpty() && it != "ext-info-c" } + .joinToString(",") + + private fun appendExtInfoC(kexAlgorithms: String): String { + val algorithms = kexAlgorithms.split(",").filter { it.isNotEmpty() } + return if ("ext-info-c" in algorithms) { + kexAlgorithms + } else { + (algorithms + "ext-info-c").joinToString(",") + } + } + + private fun parseNameList(nameList: String): List = nameList.split(",").filter { it.isNotEmpty() } } + private val kexAlgorithms: String = stripExtInfoC(kexAlgorithms) + private val initialKexAlgorithms: String = appendExtInfoC(this.kexAlgorithms) + private val stateMachineDispatcher = coroutineDispatcher.limitedParallelism(1, "StateMachine") private class HostKeyRejectedException(val key: PublicKey) : Exception("Host key rejected") @@ -183,6 +207,7 @@ class SshConnection( override suspend fun sendNewKeys() = this@SshConnection.sendNewKeys() override fun receiveNewKeys() = this@SshConnection.receiveNewKeys() override suspend fun activateEncryption() = this@SshConnection.activateEncryption() + override suspend fun sendClientExtInfo() = this@SshConnection.sendClientExtInfo() override suspend fun sendServiceRequest(service: String) = this@SshConnection.sendServiceRequest(service) override fun receiveServiceAccept(service: String) = this@SshConnection.receiveServiceAccept(service) override fun startAuthentication() = this@SshConnection.startAuthentication() @@ -205,7 +230,7 @@ class SshConnection( } private val stateMachine = SshClientStateMachine(callbacks) - private val connectionScope = CoroutineScope(SupervisorJob() + coroutineDispatcher) + internal val connectionScope = CoroutineScope(SupervisorJob() + coroutineDispatcher) private val writeMutex = Mutex() private val _disconnectedFlow = MutableSharedFlow(extraBufferCapacity = 1) @@ -230,6 +255,9 @@ class SshConnection( private var negotiatedCompressionC2S: String? = null private var negotiatedCompressionS2C: String? = null private var strictKexEnabled: Boolean = false + private var serverAdvertisesExtInfo: Boolean = false + private var serverExtInfoReceivedCount: Int = 0 + private var clientExtInfoSent: Boolean = false private var nextLocalChannelNumber = 0 private val channelNumberLock = Mutex() @@ -242,8 +270,20 @@ class SshConnection( private var agentProvider: AgentProvider? = null private var serverHostKeyBlob: ByteArray? = null - private var serverAdvertisesHostBound: Boolean = false - private var serverSigAlgs: Set? = null + + @Volatile private var serverAdvertisesHostBound: Boolean = false + + @Volatile private var serverSigAlgs: Set? = null + + @Volatile internal var serverSupportsPing: Boolean = false + private val pingSequence = AtomicLong(0) + private data class PendingPing( + val deferred: CompletableDeferred, + val payload: ByteArray, + val sentTimeNs: Long? = null, + ) + private val pendingPings = HashMap() + private val pendingPingQueue = ArrayDeque Unit>() /** * Helper to manage a pending asynchronous operation that waits for a server response. @@ -323,7 +363,7 @@ class SshConnection( private var packetLoopJob: Job? = null - @Volatile private var isRekeying = false + @Volatile internal var isRekeying = false private var rekeyTimerJob: Job? = null @@ -962,6 +1002,7 @@ class SshConnection( private suspend fun sendKexInit() { logger.debug("Sending KEX_INIT") + val localKexAlgorithms = if (isRekeying) kexAlgorithms else initialKexAlgorithms val kexInit = SshMsgKexinit().apply { // Cookie (16 random bytes) @@ -970,7 +1011,7 @@ class SshConnection( } setCookie(cookie) - setKexAlgorithms(createNameList(kexAlgorithms)) + setKexAlgorithms(createNameList(localKexAlgorithms)) setServerHostKeyAlgorithms(createNameList(hostKeyAlgorithms)) setEncryptionAlgorithmsClientToServer(createNameList(encryptionAlgorithms)) setEncryptionAlgorithmsServerToClient(createNameList(encryptionAlgorithms)) @@ -1013,16 +1054,25 @@ class SshConnection( logger.debug(" Server compression c->s: $serverCompC2S") logger.debug(" Server compression s->c: $serverCompS2C") - val clientKexStrict = kexAlgorithms.contains("kex-strict-c-v00@openssh.com") - val serverKexStrict = serverKexAlgs.contains("kex-strict-s-v00@openssh.com") + val localKexAlgorithms = if (isRekeying) kexAlgorithms else initialKexAlgorithms + val clientKexList = parseNameList(localKexAlgorithms) + val serverKexList = serverKexAlgs.filter { it.isNotEmpty() } + val clientKexStrict = "kex-strict-c-v00@openssh.com" in clientKexList + val serverKexStrict = "kex-strict-s-v00@openssh.com" in serverKexList strictKexEnabled = clientKexStrict && serverKexStrict if (strictKexEnabled) { logger.info(" Strict KEX enabled") } - val clientKexList = kexAlgorithms.split(",") - negotiatedKex = clientKexList.firstOrNull { it in serverKexAlgs } - ?: throw SshException("No matching KEX algorithm. Client: $kexAlgorithms, Server: $serverKexAlgs") + if (!isRekeying) { + serverAdvertisesExtInfo = "ext-info-s" in serverKexList + if (serverAdvertisesExtInfo) { + logger.info(" Server advertises EXT_INFO support") + } + } + + negotiatedKex = clientKexList.firstOrNull { it in serverKexList } + ?: throw SshException("No matching KEX algorithm. Client: $localKexAlgorithms, Server: $serverKexAlgs") logger.info(" Negotiated KEX: $negotiatedKex") val clientHostKeyList = hostKeyAlgorithms.split(",") @@ -1140,9 +1190,17 @@ class SshConnection( private fun rekeyComplete() { logger.info("Re-key complete") - isRekeying = false packetIO.resetByteCounters() - startRekeyTimer() + connectionScope.launch { + withContext(stateMachineDispatcher) { + while (pendingPingQueue.isNotEmpty()) { + val action = pendingPingQueue.removeFirst() + action() + } + isRekeying = false + startRekeyTimer() + } + } } private fun startRekeyTimer() { @@ -1285,6 +1343,84 @@ class SshConnection( } } + private suspend fun sendClientExtInfo() { + if (clientExtInfoSent) { + logger.info("Skipping client SSH_MSG_EXT_INFO; already sent during initial key exchange") + return + } + if (isRekeying) { + logger.info("Skipping client SSH_MSG_EXT_INFO during re-key") + return + } + if (!serverAdvertisesExtInfo) { + logger.info("Skipping client SSH_MSG_EXT_INFO; server did not advertise ext-info-s") + return + } + logger.info("Sending client SSH_MSG_EXT_INFO") + val msg = SshMsgExtInfo() + val extensions = listOf( + "ext-info-in-auth@openssh.com", + ) + msg.setNumExtensions(extensions.size.toLong()) + msg.setExtensions( + extensions.mapTo(ArrayList()) { name -> + SshMsgExtInfo.Extension().apply { + set_root(msg) + set_parent(msg) + setExtensionName(createAsciiString(name)) + setExtensionValue(createByteString("0".toByteArray(Charsets.US_ASCII))) + _check() + } + }, + ) + msg._check() + writePacket(SshEnums.MessageType.SSH_MSG_EXT_INFO.id().toInt(), msg.toByteArray()) + clientExtInfoSent = true + } + + private fun processServerExtInfo(extInfo: SshMsgExtInfo) { + if (!serverAdvertisesExtInfo) { + logger.warn("Ignoring SSH_MSG_EXT_INFO because server did not advertise ext-info-s") + return + } + if (serverExtInfoReceivedCount >= 2) { + logger.warn("Ignoring unexpected extra SSH_MSG_EXT_INFO from server") + return + } + val initialExtInfo = serverExtInfoReceivedCount == 0 + serverExtInfoReceivedCount++ + + if (initialExtInfo) { + serverAdvertisesHostBound = false + serverSigAlgs = null + serverSupportsPing = false + } + + for (ext in extInfo.extensions()) { + when (ext.extensionName().value()) { + "publickey-hostbound@openssh.com" -> { + if (initialExtInfo) { + serverAdvertisesHostBound = true + logger.info("Server advertises publickey-hostbound@openssh.com") + } + } + + "server-sig-algs" -> { + val algs = String(ext.extensionValue().data(), Charsets.UTF_8) + serverSigAlgs = algs.split(",").filter { it.isNotEmpty() }.toSet() + logger.info("Server advertises server-sig-algs: $algs") + } + + "ping@openssh.com" -> { + if (initialExtInfo) { + serverSupportsPing = true + logger.info("Server advertises ping@openssh.com") + } + } + } + } + } + private fun receiveNewKeys() { logger.info("Received NEW_KEYS from server") if (strictKexEnabled) { @@ -1671,10 +1807,20 @@ class SshConnection( } suspend fun close() { - connectionScope.cancel() transport.close() + connectionScope.cancel() packetLoopJob?.join() packetLoopJob = null + + withContext(stateMachineDispatcher) { + val error = Exception("Connection closed") + for ((_, pending) in pendingPings) { + pending.deferred.complete(PingResult.Failure(error)) + } + pendingPings.clear() + pendingPingQueue.clear() + } + sessionId?.fill(0) sessionId = null } @@ -2155,20 +2301,51 @@ class SshConnection( } SshEnums.MessageType.SSH_MSG_EXT_INFO -> { - val extInfo = parseBody(packet) - for (ext in extInfo.extensions()) { - when (ext.extensionName().value()) { - "publickey-hostbound@openssh.com" -> { - serverAdvertisesHostBound = true - logger.info("Server advertises publickey-hostbound@openssh.com") - } + processServerExtInfo(parseBody(packet)) + } + + SshEnums.MessageType.SSH_MSG_PING -> { + val msg = parseBody(packet) + val pongSend: suspend () -> Unit = { + val pong = SshMsgPong() + pong.setData(createByteString(msg.data().data())) + pong._check() + writePacket(SshEnums.MessageType.SSH_MSG_PONG.id().toInt(), pong.toByteArray()) + } + val sendNow = withContext(stateMachineDispatcher) { + if (isRekeying) { + pendingPingQueue.addLast(pongSend) + false + } else { + true + } + } + if (sendNow) { + pongSend() + } + } - "server-sig-algs" -> { - val algs = String(ext.extensionValue().data(), Charsets.UTF_8) - serverSigAlgs = algs.split(",").filter { it.isNotEmpty() }.toSet() - logger.info("Server advertises server-sig-algs: $algs") + SshEnums.MessageType.SSH_MSG_PONG -> { + val msg = parseBody(packet) + val seqBytes = msg.data().data() + if (seqBytes.size == 8) { + val seq = ByteBuffer.wrap(seqBytes).getLong() + withContext(stateMachineDispatcher) { + val pending = pendingPings.remove(seq) + if (pending != null) { + val sentTimeNs = pending.sentTimeNs + if (sentTimeNs != null) { + pending.deferred.complete(PingResult.Success(System.nanoTime() - sentTimeNs)) + } else { + pendingPings[seq] = pending + logger.warn("Received SSH_MSG_PONG before ping send timestamp was recorded: $seq") + } + } else { + logger.warn("Received SSH_MSG_PONG with unknown sequence: $seq") } } + } else { + logger.warn("Received SSH_MSG_PONG with unexpected data length: ${seqBytes.size}") } } @@ -2350,6 +2527,12 @@ class SshConnection( pending.deferred.completeExceptionally(loopError) } pendingChannelOpens.clear() + + for ((_, pending) in pendingPings) { + pending.deferred.complete(PingResult.Failure(loopError)) + } + pendingPings.clear() + pendingPingQueue.clear() } if (loopException != null) { _disconnectedFlow.tryEmit(loopException) @@ -2482,6 +2665,56 @@ class SshConnection( ) } + internal suspend fun ping(): PingResult { + if (!serverSupportsPing) return PingResult.NotSupported + + val seq = pingSequence.getAndIncrement() + val data = ByteBuffer.allocate(8).putLong(seq).array() + val deferred = CompletableDeferred() + + val send: suspend () -> Unit = { + try { + writeMutex.withLock { + val current = pendingPings[seq] ?: return@withLock + val ping = SshMsgPing() + ping.setData(createByteString(current.payload)) + ping._check() + + val sentTimeNs = System.nanoTime() + pendingPings[seq] = current.copy(sentTimeNs = sentTimeNs) + packetIO.writePacket(SshEnums.MessageType.SSH_MSG_PING.id().toInt(), ping.toByteArray()) + } + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + withContext(stateMachineDispatcher) { + if (pendingPings.remove(seq) != null) { + deferred.complete(PingResult.Failure(e)) + } + } + } + } + + withContext(stateMachineDispatcher) { + pendingPings[seq] = PendingPing(deferred, data) + if (isRekeying) { + pendingPingQueue.addLast(send) + } else { + send() + } + } + + return try { + deferred.await() + } finally { + withContext(NonCancellable) { + withContext(stateMachineDispatcher) { + pendingPings.remove(seq) + } + } + } + } + internal val connectionInfo: ConnectionInfo? get() { val kex = negotiatedKex ?: return null diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/protocol/SshClientStateMachine.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/protocol/SshClientStateMachine.kt index 02991d48..099b2042 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/protocol/SshClientStateMachine.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/protocol/SshClientStateMachine.kt @@ -271,6 +271,7 @@ internal class SshClientStateMachine( onTriggered { callbacks.receiveNewKeys() callbacks.activateEncryption() + callbacks.sendClientExtInfo() callbacks.sendServiceRequest("ssh-userauth") } } @@ -351,6 +352,7 @@ internal interface SshClientCallbacks { suspend fun sendNewKeys() fun receiveNewKeys() suspend fun activateEncryption() + suspend fun sendClientExtInfo() suspend fun sendServiceRequest(service: String) fun receiveServiceAccept(service: String) fun startAuthentication() diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/PingPongMessageTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/PingPongMessageTest.kt new file mode 100644 index 00000000..5b0ebad6 --- /dev/null +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/PingPongMessageTest.kt @@ -0,0 +1,68 @@ +package org.connectbot.sshlib + +import io.kaitai.struct.ByteBufferKaitaiStream +import org.connectbot.sshlib.protocol.SshMsgPing +import org.connectbot.sshlib.protocol.SshMsgPong +import org.connectbot.sshlib.protocol.createByteString +import org.connectbot.sshlib.protocol.toByteArray +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Test + +class PingPongMessageTest { + + private fun buildPing(data: ByteArray): ByteArray { + val msg = SshMsgPing() + msg.setData(createByteString(data)) + msg._check() + return msg.toByteArray() + } + + private fun parsePing(bytes: ByteArray): SshMsgPing { + val msg = SshMsgPing(ByteBufferKaitaiStream(bytes)) + msg._read() + return msg + } + + private fun buildPong(data: ByteArray): ByteArray { + val msg = SshMsgPong() + msg.setData(createByteString(data)) + msg._check() + return msg.toByteArray() + } + + private fun parsePong(bytes: ByteArray): SshMsgPong { + val msg = SshMsgPong(ByteBufferKaitaiStream(bytes)) + msg._read() + return msg + } + + @Test + fun `ping round-trip with empty data`() { + val bytes = buildPing(byteArrayOf()) + val parsed = parsePing(bytes) + assertArrayEquals(byteArrayOf(), parsed.data().data()) + } + + @Test + fun `ping round-trip with non-empty data`() { + val data = byteArrayOf(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01) + val bytes = buildPing(data) + val parsed = parsePing(bytes) + assertArrayEquals(data, parsed.data().data()) + } + + @Test + fun `pong round-trip preserves data exactly`() { + val data = byteArrayOf(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2A) + val bytes = buildPong(data) + val parsed = parsePong(bytes) + assertArrayEquals(data, parsed.data().data()) + } + + @Test + fun `pong with empty data round-trips`() { + val bytes = buildPong(byteArrayOf()) + val parsed = parsePong(bytes) + assertArrayEquals(byteArrayOf(), parsed.data().data()) + } +} diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/ExtInfoNegotiationTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/ExtInfoNegotiationTest.kt new file mode 100644 index 00000000..c560da59 --- /dev/null +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/ExtInfoNegotiationTest.kt @@ -0,0 +1,205 @@ +package org.connectbot.sshlib.client + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.yield +import org.connectbot.sshlib.ConnectResult +import org.connectbot.sshlib.HostKeyVerifier +import org.connectbot.sshlib.PublicKey +import org.connectbot.sshlib.protocol.SshMsgExtInfo +import org.connectbot.sshlib.transport.PipedTransport +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@OptIn(ExperimentalCoroutinesApi::class) +class ExtInfoNegotiationTest { + + private val acceptAllVerifier = object : HostKeyVerifier { + override suspend fun verify(key: PublicKey): Boolean = true + } + + private suspend fun connectInBackground( + connection: SshConnection, + backgroundScope: CoroutineScope, + dispatcher: CoroutineDispatcher, + ): ConnectResult { + val result = CompletableDeferred() + backgroundScope.launch(dispatcher) { result.complete(connection.connect()) } + yield() + return result.await() + } + + private fun extensionNames(extInfo: SshMsgExtInfo): Set = extInfo.extensions().map { it.extensionName().value() }.toSet() + + private fun extensionValues(extInfo: SshMsgExtInfo): Map = extInfo.extensions().associate { it.extensionName().value() to it.extensionValue().data() } + + @Suppress("UNCHECKED_CAST") + private fun serverSigAlgs(connection: SshConnection): Set? { + val field = SshConnection::class.java.getDeclaredField("serverSigAlgs") + field.isAccessible = true + return field.get(connection) as Set? + } + + private suspend fun awaitServerSigAlgs(connection: SshConnection, expected: Set) { + withTimeout(1000) { + while (serverSigAlgs(connection) != expected) { + yield() + } + } + } + + @Test + fun `client sends EXT_INFO when server advertises ext-info-s`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val extInfo = withTimeoutOrNull(1000) { server.awaitExtInfo() } + assertNotNull(extInfo, "Expected EXT_INFO from client") + val names = extensionNames(extInfo!!) + assertTrue("ext-info-in-auth@openssh.com" in names) + assertEquals(setOf("ext-info-in-auth@openssh.com"), names) + val values = extensionValues(extInfo) + assertContentEquals("0".toByteArray(Charsets.US_ASCII), values["ext-info-in-auth@openssh.com"]) + } finally { + connection.close() + } + } + + @Test + fun `client does NOT send EXT_INFO when server does not advertise ext-info-s`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = false + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val extInfo = withTimeoutOrNull(1000) { server.awaitExtInfo() } + assertNull(extInfo, "Did NOT expect EXT_INFO from client when server doesn't advertise it") + } finally { + connection.close() + } + } + + @Test + fun `client does NOT send EXT_INFO when server kex only contains ext-info-s as substring`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.kexAlgorithms = "curve25519-sha256,not-ext-info-s" + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val extInfo = withTimeoutOrNull(1000) { server.awaitExtInfo() } + assertNull(extInfo, "Did NOT expect EXT_INFO unless server advertises ext-info-s as an exact kex name") + } finally { + connection.close() + } + } + + @Test + fun `client appends ext-info-c to custom kex algorithms`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + kexAlgorithms = "curve25519-sha256", + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val extInfo = withTimeoutOrNull(1000) { server.awaitExtInfo() } + assertNotNull(extInfo, "Expected EXT_INFO from client after appending ext-info-c") + } finally { + connection.close() + } + } + + @Test + fun `server may update server-sig-algs during user authentication`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.advertisePing = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + coroutineDispatcher = dispatcher, + ) + + var authJob: kotlinx.coroutines.Job? = null + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + assertTrue(connection.serverSupportsPing, "Initial EXT_INFO should advertise ping") + + authJob = backgroundScope.launch(dispatcher) { + connection.authenticatePassword("user", "pass") + } + withTimeout(1000) { server.awaitUserauthRequest() } + server.sendCustomExtInfo(mapOf("server-sig-algs" to "rsa-sha2-256".toByteArray(Charsets.US_ASCII))) + + awaitServerSigAlgs(connection, setOf("rsa-sha2-256")) + assertTrue(connection.serverSupportsPing, "In-auth EXT_INFO must not clear initial-only ping support") + assertEquals(setOf("rsa-sha2-256"), serverSigAlgs(connection)) + } finally { + authJob?.cancel() + connection.close() + } + } +} diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt index 5fea77a2..f5666d07 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt @@ -35,10 +35,14 @@ import org.connectbot.sshlib.crypto.SshPublicKeyEncoder import org.connectbot.sshlib.crypto.X25519ProviderFactory import org.connectbot.sshlib.crypto.encodeMpint import org.connectbot.sshlib.protocol.SshEnums +import org.connectbot.sshlib.protocol.SshMsgExtInfo import org.connectbot.sshlib.protocol.SshMsgKexEcdhInit import org.connectbot.sshlib.protocol.SshMsgKexEcdhReply import org.connectbot.sshlib.protocol.SshMsgKexinit +import org.connectbot.sshlib.protocol.SshMsgPing +import org.connectbot.sshlib.protocol.SshMsgPong import org.connectbot.sshlib.protocol.SshMsgServiceAccept +import org.connectbot.sshlib.protocol.SshMsgUserauthRequest import org.connectbot.sshlib.protocol.createAsciiString import org.connectbot.sshlib.protocol.createByteString import org.connectbot.sshlib.protocol.createNameList @@ -76,6 +80,13 @@ class FakeSshServer( private val rekeyRequestChannel = Channel(Channel.UNLIMITED) + var advertisePing: Boolean = false + var advertiseExtInfo: Boolean = false + var kexAlgorithms: String? = null + private val receivedPongs = Channel(Channel.UNLIMITED) + private val receivedExtInfo = Channel(Channel.UNLIMITED) + private val receivedUserauthRequests = Channel(Channel.UNLIMITED) + fun start() { scope.launch(coroutineContext) { serve() } } @@ -109,9 +120,32 @@ class FakeSshServer( // Initial KEX: reads directly from serverIo (no reader coroutine yet) doFullKex(serverIo) + sendExtInfo(serverIo) + + // After initial KEX, the client MAY send SSH_MSG_EXT_INFO, + // followed by a SERVICE_REQUEST (ssh-userauth). + var serviceRequest: ByteArray? = null + while (serviceRequest == null) { + val (msgType, rawBytes) = readPacketWithType(serverIo) + when (msgType) { + SshEnums.MessageType.SSH_MSG_EXT_INFO -> { + val bodyBytes = rawBytes.copyOfRange(1, rawBytes.size) + val extMsg = SshMsgExtInfo(ByteBufferKaitaiStream(bodyBytes)) + extMsg._read() + receivedExtInfo.trySend(extMsg) + } + + SshEnums.MessageType.SSH_MSG_SERVICE_REQUEST -> { + serviceRequest = rawBytes + } + + SshEnums.MessageType.SSH_MSG_DEBUG, + SshEnums.MessageType.SSH_MSG_IGNORE, + -> { /* skip */ } - // Service request - readPacketRaw(serverIo) + else -> throw IllegalStateException("Unexpected packet during handshake: $msgType") + } + } sendServiceAccept(serverIo) // After authentication, all packets are routed through this channel by the reader @@ -121,7 +155,11 @@ class FakeSshServer( val readerJob = scope.launch { try { while (true) { - incomingPackets.send(readPacketWithType(serverIo)) + val packet = readPacketWithType(serverIo) + if (packet.first == SshEnums.MessageType.SSH_MSG_NEWKEYS) { + activateEncryption(serverIo) + } + incomingPackets.send(packet) } } catch (_: Exception) { incomingPackets.close() @@ -144,6 +182,39 @@ class FakeSshServer( SshEnums.MessageType.SSH_MSG_DISCONNECT -> return@onReceiveCatching true + SshEnums.MessageType.SSH_MSG_EXT_INFO -> { + val bodyBytes = rawBytes.copyOfRange(1, rawBytes.size) + val extMsg = SshMsgExtInfo(ByteBufferKaitaiStream(bodyBytes)) + extMsg._read() + receivedExtInfo.trySend(extMsg) + } + + SshEnums.MessageType.SSH_MSG_USERAUTH_REQUEST -> { + val bodyBytes = rawBytes.copyOfRange(1, rawBytes.size) + val request = SshMsgUserauthRequest(ByteBufferKaitaiStream(bodyBytes)) + request._read() + receivedUserauthRequests.trySend(request) + } + + SshEnums.MessageType.SSH_MSG_PING -> { + val bodyBytes = rawBytes.copyOfRange(1, rawBytes.size) + val pingMsg = SshMsgPing(ByteBufferKaitaiStream(bodyBytes)) + pingMsg._read() + val pong = SshMsgPong() + pong.setData(createByteString(pingMsg.data().data())) + pong._check() + writeMutex.withLock { + serverIo.writePacket(SshEnums.MessageType.SSH_MSG_PONG.id().toInt(), pong.toByteArray()) + } + } + + SshEnums.MessageType.SSH_MSG_PONG -> { + val bodyBytes = rawBytes.copyOfRange(1, rawBytes.size) + val pongMsg = SshMsgPong(ByteBufferKaitaiStream(bodyBytes)) + pongMsg._read() + receivedPongs.trySend(pongMsg.data().data()) + } + else -> { /* ignore */ } } false @@ -161,7 +232,7 @@ class FakeSshServer( val clientPublic = parseEcdhInit(ecdhInitRaw) sendEcdhReply(io, clientKexInitRaw, serverKexInitBytes, clientPublic) writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) } - readPacketRaw(io) // client NEWKEYS + readPacketFiltering(io) // client NEWKEYS activateEncryption(io) } @@ -189,8 +260,7 @@ class FakeSshServer( val clientPublic = parseEcdhInit(ecdhInitRaw) sendEcdhReply(io, clientKexInitRaw, serverKexInitBytes, clientPublic) writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) } - packets.receive() // client NEWKEYS - activateEncryption(io) + packets.receive() // client NEWKEYS (encryption already activated by reader) _rekeyCount.update { it + 1 } } @@ -204,17 +274,22 @@ class FakeSshServer( val clientPublic = parseEcdhInit(ecdhInitRaw) sendEcdhReply(io, clientKexInitRaw, serverKexInitBytes, clientPublic) writeMutex.withLock { io.writePacket(SshEnums.MessageType.SSH_MSG_NEWKEYS.id().toInt()) } - packets.receive() // client NEWKEYS - activateEncryption(io) + packets.receive() // client NEWKEYS (encryption already activated by reader) _rekeyCount.update { it + 1 } } private suspend fun sendKexInit(io: PacketIO): ByteArray { val cookie = ByteArray(16).also { SecureRandom().nextBytes(it) } + val kexAlgs = kexAlgorithms ?: if (advertiseExtInfo) { + "curve25519-sha256,ext-info-s" + } else { + "curve25519-sha256" + } + val kexInit = SshMsgKexinit().apply { setCookie(cookie) - setKexAlgorithms(createNameList("curve25519-sha256")) + setKexAlgorithms(createNameList(kexAlgs)) setServerHostKeyAlgorithms(createNameList("ssh-ed25519")) setEncryptionAlgorithmsClientToServer(createNameList("aes128-ctr")) setEncryptionAlgorithmsServerToClient(createNameList("aes128-ctr")) @@ -366,6 +441,42 @@ class FakeSshServer( return out.toByteArray() } + suspend fun sendCustomExtInfo(extensions: Map) { + val msg = SshMsgExtInfo() + msg.setNumExtensions(extensions.size.toLong()) + val extList = extensions.map { (name, value) -> + SshMsgExtInfo.Extension().apply { + set_root(msg) + set_parent(msg) + setExtensionName(createAsciiString(name)) + setExtensionValue(createByteString(value)) + _check() + } + } + msg.setExtensions(ArrayList(extList)) + msg._check() + writeMutex.withLock { + serverIo.writePacket(SshEnums.MessageType.SSH_MSG_EXT_INFO.id().toInt(), msg.toByteArray()) + } + } + + private suspend fun sendExtInfo(io: PacketIO) { + if (!advertiseExtInfo || !advertisePing) return + val msg = SshMsgExtInfo() + msg.setNumExtensions(1L) + val ext = SshMsgExtInfo.Extension() + ext.set_root(msg) + ext.set_parent(msg) + ext.setExtensionName(createAsciiString("ping@openssh.com")) + ext.setExtensionValue(createByteString("0".toByteArray(Charsets.US_ASCII))) + ext._check() + msg.setExtensions(arrayListOf(ext)) + msg._check() + writeMutex.withLock { + io.writePacket(SshEnums.MessageType.SSH_MSG_EXT_INFO.id().toInt(), msg.toByteArray()) + } + } + private suspend fun sendServiceAccept(io: PacketIO) { val msg = SshMsgServiceAccept().apply { setServiceName(createAsciiString("ssh-userauth")) @@ -379,6 +490,28 @@ class FakeSshServer( return byteArrayOf(packet.messageType().id().toByte()) + packet._raw_body() } + private suspend fun readPacketFiltering(io: PacketIO): Pair { + while (true) { + val packet = io.readPacket() + val msgType = packet.messageType() + val rawBytes = byteArrayOf(msgType.id().toByte()) + packet._raw_body() + when (msgType) { + SshEnums.MessageType.SSH_MSG_EXT_INFO -> { + val bodyBytes = rawBytes.copyOfRange(1, rawBytes.size) + val extMsg = SshMsgExtInfo(ByteBufferKaitaiStream(bodyBytes)) + extMsg._read() + receivedExtInfo.trySend(extMsg) + } + + SshEnums.MessageType.SSH_MSG_DEBUG, + SshEnums.MessageType.SSH_MSG_IGNORE, + -> { /* skip */ } + + else -> return msgType to rawBytes + } + } + } + private suspend fun readPacketWithType(io: PacketIO): Pair { val packet = io.readPacket() val msgType = packet.messageType() @@ -393,4 +526,21 @@ class FakeSshServer( msg._read() return msg.qC().data() } + + fun sendServerPing(data: ByteArray) { + scope.launch(coroutineContext) { + val ping = SshMsgPing() + ping.setData(createByteString(data)) + ping._check() + writeMutex.withLock { + serverIo.writePacket(SshEnums.MessageType.SSH_MSG_PING.id().toInt(), ping.toByteArray()) + } + } + } + + suspend fun awaitPong(): ByteArray = receivedPongs.receive() + + suspend fun awaitExtInfo(): SshMsgExtInfo = receivedExtInfo.receive() + + suspend fun awaitUserauthRequest(): SshMsgUserauthRequest = receivedUserauthRequests.receive() } diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PingConnectionTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PingConnectionTest.kt new file mode 100644 index 00000000..4f568877 --- /dev/null +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PingConnectionTest.kt @@ -0,0 +1,209 @@ +package org.connectbot.sshlib.client + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.yield +import org.connectbot.sshlib.ConnectResult +import org.connectbot.sshlib.HostKeyVerifier +import org.connectbot.sshlib.PingResult +import org.connectbot.sshlib.PublicKey +import org.connectbot.sshlib.transport.PipedTransport +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import kotlin.test.assertIs + +@OptIn(ExperimentalCoroutinesApi::class) +class PingConnectionTest { + + private val acceptAllVerifier = object : HostKeyVerifier { + override suspend fun verify(key: PublicKey): Boolean = true + } + + private suspend fun connectInBackground( + connection: SshConnection, + backgroundScope: CoroutineScope, + dispatcher: CoroutineDispatcher, + ): ConnectResult { + val result = CompletableDeferred() + backgroundScope.launch(dispatcher) { result.complete(connection.connect()) } + yield() + return result.await() + } + + @Test + fun `ping returns NotSupported when server does not advertise ping`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertisePing = false + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + rekeyIntervalMs = Long.MAX_VALUE, + rekeyBytesLimit = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val pingResult = connection.ping() + assertIs(pingResult) + } finally { + connection.close() + } + } + + @Test + fun `ping returns Success when server advertises ping`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.advertisePing = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + rekeyIntervalMs = Long.MAX_VALUE, + rekeyBytesLimit = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val pingResult = connection.ping() + assertIs(pingResult) + assertTrue(pingResult.elapsedNs >= 0) + } finally { + connection.close() + } + } + + @Test + fun `client responds to server ping with pong containing same data`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.advertisePing = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + rekeyIntervalMs = Long.MAX_VALUE, + rekeyBytesLimit = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val pingData = byteArrayOf(0x01, 0x02, 0x03, 0x04) + server.sendServerPing(pingData) + + val pongData = withTimeoutOrNull(5_000) { server.awaitPong() } + assertNotNull(pongData, "Expected pong within timeout") + assertArrayEquals(pingData, pongData) + } finally { + connection.close() + } + } + + @Test + fun `ping queued during rekey completes after rekey finishes`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.advertisePing = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + rekeyIntervalMs = Long.MAX_VALUE, + rekeyBytesLimit = Long.MAX_VALUE, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + server.initiateRekey() + while (!connection.isRekeying) { + yield() + } + + val pingDeferred = CompletableDeferred() + backgroundScope.launch(dispatcher) { + pingDeferred.complete(connection.ping()) + } + + server.rekeyCount.first { it >= 1 } + + val pingResult = withTimeoutOrNull(5_000) { pingDeferred.await() } + assertNotNull(pingResult, "Ping should resolve after rekey") + assertIs(pingResult) + } finally { + connection.close() + } + } + + @Test + fun `ping fails when connection is closed while pending`() = runTest { + val dispatcher = StandardTestDispatcher(testScheduler) + val (clientTransport, serverTransport) = PipedTransport.create() + val server = FakeSshServer(serverTransport, backgroundScope, dispatcher) + server.advertiseExtInfo = true + server.advertisePing = true + server.start() + + val connection = SshConnection( + transport = clientTransport, + hostKeyVerifier = acceptAllVerifier, + coroutineDispatcher = dispatcher, + ) + + try { + val result = connectInBackground(connection, backgroundScope, dispatcher) + assertIs(result) + + val pingDeferred = CompletableDeferred() + backgroundScope.launch(dispatcher) { + pingDeferred.complete(connection.ping()) + } + yield() + + connection.close() + + val pingResult = withTimeoutOrNull(5_000) { + pingDeferred.await() + } + assertNotNull(pingResult, "Ping should resolve after connection close") + assertIs(pingResult) + } finally { + connection.close() + } + } +} diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt index 9fa220d2..f11b544e 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt @@ -27,6 +27,7 @@ import org.connectbot.sshlib.AuthResult import org.connectbot.sshlib.ConnectResult import org.connectbot.sshlib.HostKeyVerifier import org.connectbot.sshlib.KeyboardInteractiveCallback +import org.connectbot.sshlib.PingResult import org.connectbot.sshlib.PublicKey import org.connectbot.sshlib.SshClient import org.connectbot.sshlib.SshClientConfig @@ -50,6 +51,7 @@ import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.images.builder.ImageFromDockerfile import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers +import kotlin.test.assertIs /** * Integration tests for SSH client using testcontainers with real SSH servers. @@ -167,6 +169,60 @@ class SshClientIntegrationTest { } } + @Test + fun `ping returns success against real OpenSSH server`() { + val host = opensshContainer.host + val port = opensshContainer.getMappedPort(22) + + runBlocking { + val client = SshClient( + SshClientConfig { + this.host = host + this.port = port + this.hostKeyVerifier = acceptAllVerifier + }, + ) + try { + val connectResult = client.connect() + assertIs(connectResult) + + val authResult = client.authenticatePassword(USERNAME, PASSWORD) + assertIs(authResult) + + val pingResult = client.ping() + assertIs(pingResult) + assertTrue(pingResult.elapsedNs > 0, "Elapsed time should be positive") + } finally { + client.disconnect() + } + } + } + + @Test + fun `ping returns NotAuthenticated before authentication against real OpenSSH server`() { + val host = opensshContainer.host + val port = opensshContainer.getMappedPort(22) + + runBlocking { + val client = SshClient( + SshClientConfig { + this.host = host + this.port = port + this.hostKeyVerifier = acceptAllVerifier + }, + ) + try { + val connectResult = client.connect() + assertIs(connectResult) + + val pingResult = client.ping() + assertIs(pingResult) + } finally { + client.disconnect() + } + } + } + @Test fun `should authenticate with password`() { val host = opensshContainer.host @@ -444,7 +500,7 @@ class SshClientIntegrationTest { this.hostKeyVerifier = acceptAllVerifier // Force ssh-rsa only by excluding rsa-sha2 variants this.hostKeyAlgorithms = "ssh-rsa" - // Use a KEX algorithm that doesn't include ext-info-c by default + // Use a single KEX algorithm; SshConnection appends ext-info-c automatically. this.kexAlgorithms = "diffie-hellman-group14-sha256" }, )