From b60ac26b3e5d2f52a9c365de72bbd02b5e1fbb9d Mon Sep 17 00:00:00 2001 From: Kenny Root Date: Tue, 28 Apr 2026 21:47:40 -0700 Subject: [PATCH 1/2] feat: implement ssh-agent restriction Implements OpenSSH 8.9+ agent restriction protocol per https://www.openssh.org/agent-restrict.html Add support for publickey-hostbound-v00@openssh.com as well as agent key use constraints. Start advertising "ext-info-c" in the kex. --- .../restrict_destination_constraint.ksy | 44 +++ .../kaitai/ssh_msg_userauth_request.ksy | 2 + ...uth_publickey_hostbound_signature_data.ksy | 40 ++ .../userauth_publickey_signature_data_any.ksy | 40 ++ .../userauth_request_publickey_hostbound.ksy | 31 ++ sshlib/api.txt | 44 ++- .../org/connectbot/sshlib/AgentProvider.kt | 51 +++ .../sshlib/client/AgentProtocolHandler.kt | 204 +++++++++- .../connectbot/sshlib/client/SshConnection.kt | 130 +++++-- .../connectbot/sshlib/crypto/Algorithms.kt | 2 +- .../sshlib/AgentDestinationConstraintTest.kt | 363 ++++++++++++++++++ .../connectbot/sshlib/AgentProtocolTest.kt | 116 +++--- .../sshlib/ExtInfoProcessingTest.kt | 102 +++++ .../sshlib/HostBoundSignatureDataTest.kt | 111 ++++++ 14 files changed, 1178 insertions(+), 102 deletions(-) create mode 100644 protocol/src/main/resources/kaitai/restrict_destination_constraint.ksy create mode 100644 protocol/src/main/resources/kaitai/userauth_publickey_hostbound_signature_data.ksy create mode 100644 protocol/src/main/resources/kaitai/userauth_publickey_signature_data_any.ksy create mode 100644 protocol/src/main/resources/kaitai/userauth_request_publickey_hostbound.ksy create mode 100644 sshlib/src/test/kotlin/org/connectbot/sshlib/AgentDestinationConstraintTest.kt create mode 100644 sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt create mode 100644 sshlib/src/test/kotlin/org/connectbot/sshlib/HostBoundSignatureDataTest.kt diff --git a/protocol/src/main/resources/kaitai/restrict_destination_constraint.ksy b/protocol/src/main/resources/kaitai/restrict_destination_constraint.ksy new file mode 100644 index 0000000..ee98526 --- /dev/null +++ b/protocol/src/main/resources/kaitai/restrict_destination_constraint.ksy @@ -0,0 +1,44 @@ +meta: + id: restrict_destination_constraint + title: restrict-destination-v00@openssh.com Constraint + endian: be + imports: + - byte_string +doc: | + One hop entry for the restrict-destination-v00@openssh.com agent key constraint. + from_hostname and from_keyspecs are empty for origin-side constraints. +seq: + - id: from_hostname + type: byte_string + doc: Hostname of the previous hop (empty for origin) + - id: num_from_keyspecs + type: u4 + doc: Number of from host key specs + - id: from_keyspecs + type: keyspec + repeat: expr + repeat-expr: num_from_keyspecs + doc: Host key specs for the previous hop (empty for origin) + - id: to_username + type: byte_string + doc: Destination username (empty = any user) + - id: to_hostname + type: byte_string + doc: Destination hostname + - id: num_to_hostspecs + type: u4 + doc: Number of destination host key specs + - id: to_hostspecs + type: keyspec + repeat: expr + repeat-expr: num_to_hostspecs + doc: Destination host key specs +types: + keyspec: + seq: + - id: keyblob + type: byte_string + doc: Host key blob + - id: is_ca + type: u1 + doc: 1 if this is a CA key, 0 if it is a direct host key diff --git a/protocol/src/main/resources/kaitai/ssh_msg_userauth_request.ksy b/protocol/src/main/resources/kaitai/ssh_msg_userauth_request.ksy index 435dbdb..b053a48 100644 --- a/protocol/src/main/resources/kaitai/ssh_msg_userauth_request.ksy +++ b/protocol/src/main/resources/kaitai/ssh_msg_userauth_request.ksy @@ -13,6 +13,7 @@ meta: - userauth_request_none - userauth_request_password - userauth_request_publickey + - userauth_request_publickey_hostbound doc-ref: RFC 4252 section 5 seq: - id: user_name @@ -26,6 +27,7 @@ seq: switch-on: method_name.value cases: '"publickey"': userauth_request_publickey + '"publickey-hostbound-v00@openssh.com"': userauth_request_publickey_hostbound '"password"': userauth_request_password '"hostbased"': userauth_request_hostbased '"none"': userauth_request_none diff --git a/protocol/src/main/resources/kaitai/userauth_publickey_hostbound_signature_data.ksy b/protocol/src/main/resources/kaitai/userauth_publickey_hostbound_signature_data.ksy new file mode 100644 index 0000000..5aae4eb --- /dev/null +++ b/protocol/src/main/resources/kaitai/userauth_publickey_hostbound_signature_data.ksy @@ -0,0 +1,40 @@ +meta: + id: userauth_publickey_hostbound_signature_data + endian: be + imports: + - byte_string +doc-ref: OpenSSH publickey-hostbound-v00@openssh.com extension +doc: > + The data over which the signature is computed for publickey-hostbound authentication. + Identical to userauth_publickey_signature_data but method_name is + "publickey-hostbound-v00@openssh.com" and server_host_key is appended. +seq: +- id: session_identifier + type: byte_string + doc: Session identifier from key exchange +- id: message_type + contents: + - 50 + doc: SSH_MSG_USERAUTH_REQUEST (50) +- id: user_name + type: byte_string + doc: User name +- id: service_name + type: byte_string + doc: Service name +- id: method_name + type: byte_string + doc: Authentication method name ("publickey-hostbound-v00@openssh.com") +- id: has_signature + contents: + - 1 + doc: TRUE (1) +- id: public_key_algorithm_name + type: byte_string + doc: Public key algorithm name +- id: public_key_blob + type: byte_string + doc: Public key to be used for authentication +- id: server_host_key + type: byte_string + doc: Server's host key, binding the signature to the intended destination diff --git a/protocol/src/main/resources/kaitai/userauth_publickey_signature_data_any.ksy b/protocol/src/main/resources/kaitai/userauth_publickey_signature_data_any.ksy new file mode 100644 index 0000000..891affa --- /dev/null +++ b/protocol/src/main/resources/kaitai/userauth_publickey_signature_data_any.ksy @@ -0,0 +1,40 @@ +meta: + id: userauth_publickey_signature_data_any + endian: be + imports: + - byte_string + - ascii_string + - utf8_string +doc: > + The data over which the signature is computed for publickey authentication. + This handles both the standard RFC 4252 version and the OpenSSH + publickey-hostbound-v00@openssh.com extension. +seq: + - id: session_identifier + type: byte_string + doc: Session identifier from key exchange + - id: message_type + contents: [50] + doc: SSH_MSG_USERAUTH_REQUEST (50) + - id: user_name + type: utf8_string + doc: User name + - id: service_name + type: ascii_string + doc: Service name + - id: method_name + type: ascii_string + doc: Authentication method name + - id: has_signature + contents: [1] + doc: TRUE (1) + - id: public_key_algorithm_name + type: ascii_string + doc: Public key algorithm name + - id: public_key_blob + type: byte_string + doc: Public key to be used for authentication + - id: server_host_key + type: byte_string + doc: Server's host key, binding the signature to the intended destination (only for hostbound method) + if: method_name.value == "publickey-hostbound-v00@openssh.com" diff --git a/protocol/src/main/resources/kaitai/userauth_request_publickey_hostbound.ksy b/protocol/src/main/resources/kaitai/userauth_request_publickey_hostbound.ksy new file mode 100644 index 0000000..ac9d39c --- /dev/null +++ b/protocol/src/main/resources/kaitai/userauth_request_publickey_hostbound.ksy @@ -0,0 +1,31 @@ +meta: + id: userauth_request_publickey_hostbound + endian: be + imports: + - ascii_string + - byte_string +doc-ref: OpenSSH publickey-hostbound-v00@openssh.com extension +doc: > + Publickey authentication method fields for publickey-hostbound-v00@openssh.com. + Identical to userauth_request_publickey but adds server_host_key before signature. +seq: +- id: has_signature + type: u1 + doc: > + FALSE (0) to query if the public key is acceptable for authentication. + TRUE (1) to perform actual authentication with signature. +- id: public_key_algorithm_name + type: ascii_string + doc: Public key algorithm name +- id: public_key_blob + type: byte_string + doc: Public key blob (may contain certificates) +- id: server_host_key + type: byte_string + doc: Server's host key blob, binding the authentication to a specific destination +- id: signature + type: byte_string + doc: > + Signature over session identifier and authentication request including server host key. + Only present when has_signature is TRUE. + if: has_signature != 0 diff --git a/sshlib/api.txt b/sshlib/api.txt index b99f493..f9f5288 100644 --- a/sshlib/api.txt +++ b/sshlib/api.txt @@ -2,19 +2,36 @@ package org.connectbot.sshlib { public final class AgentIdentity { - ctor public AgentIdentity(byte[] publicKeyBlob, java.lang.String comment); + ctor public AgentIdentity(byte[] publicKeyBlob, java.lang.String comment, optional java.util.List? destinationConstraints); method public byte[] component1(); method public java.lang.String component2(); - method public org.connectbot.sshlib.AgentIdentity copy(optional byte[] publicKeyBlob, optional java.lang.String comment); + method public java.util.List? component3(); + method public org.connectbot.sshlib.AgentIdentity copy(optional byte[] publicKeyBlob, optional java.lang.String comment, optional java.util.List? destinationConstraints); method public boolean equals(java.lang.Object? other); method @InaccessibleFromKotlin public java.lang.String getComment(); + method @InaccessibleFromKotlin public java.util.List? getDestinationConstraints(); method @InaccessibleFromKotlin public byte[] getPublicKeyBlob(); method public int hashCode(); method public java.lang.String toString(); property public String comment; + property public java.util.List? destinationConstraints; property public byte[] publicKeyBlob; } + public final class AgentKeySpec { + ctor public AgentKeySpec(byte[] keyBlob, boolean isCa); + method public byte[] component1(); + method public boolean component2(); + method public org.connectbot.sshlib.AgentKeySpec copy(optional byte[] keyBlob, optional boolean isCa); + method public boolean equals(java.lang.Object? other); + method @InaccessibleFromKotlin public byte[] getKeyBlob(); + method public int hashCode(); + method @InaccessibleFromKotlin public boolean isCa(); + method public java.lang.String toString(); + property public boolean isCa; + property public byte[] keyBlob; + } + public interface AgentProvider { method public suspend java.lang.Object? getIdentities(kotlin.coroutines.Continuation>); method public suspend java.lang.Object? signData(org.connectbot.sshlib.AgentSigningContext context, kotlin.coroutines.Continuation); @@ -196,6 +213,29 @@ package org.connectbot.sshlib { property public String serverHostKeyAlgorithm; } + public final class DestinationConstraint { + ctor public DestinationConstraint(java.lang.String fromHostname, java.util.List fromKeyspecs, java.lang.String toUsername, java.lang.String toHostname, java.util.List toHostspecs); + method public java.lang.String component1(); + method public java.util.List component2(); + method public java.lang.String component3(); + method public java.lang.String component4(); + method public java.util.List component5(); + method public org.connectbot.sshlib.DestinationConstraint copy(optional java.lang.String fromHostname, optional java.util.List fromKeyspecs, optional java.lang.String toUsername, optional java.lang.String toHostname, optional java.util.List toHostspecs); + method public boolean equals(java.lang.Object? other); + method @InaccessibleFromKotlin public java.lang.String getFromHostname(); + method @InaccessibleFromKotlin public java.util.List getFromKeyspecs(); + method @InaccessibleFromKotlin public java.lang.String getToHostname(); + method @InaccessibleFromKotlin public java.util.List getToHostspecs(); + method @InaccessibleFromKotlin public java.lang.String getToUsername(); + method public int hashCode(); + method public java.lang.String toString(); + property public String fromHostname; + property public java.util.List fromKeyspecs; + property public String toHostname; + property public java.util.List toHostspecs; + property public String toUsername; + } + public interface HostKeyVerifier { method public default suspend java.lang.Object? addKeys(java.util.List keys, kotlin.coroutines.Continuation); method public default suspend java.lang.Object? removeKeys(java.util.List keys, kotlin.coroutines.Continuation); diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/AgentProvider.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/AgentProvider.kt index 1176f8c..2d29bee 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/AgentProvider.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/AgentProvider.kt @@ -16,6 +16,53 @@ package org.connectbot.sshlib +/** + * A host key specification used in destination constraints. + * + * @param keyBlob Wire-format public key blob + * @param isCa True if this is a CA key that signed the destination's host certificate + */ +data class AgentKeySpec( + val keyBlob: ByteArray, + val isCa: Boolean, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as AgentKeySpec + if (!keyBlob.contentEquals(other.keyBlob)) return false + if (isCa != other.isCa) return false + return true + } + + override fun hashCode(): Int { + var result = keyBlob.contentHashCode() + result = 31 * result + isCa.hashCode() + return result + } +} + +/** + * One hop entry in a destination constraint for a key in the agent. + * + * Describes a single permitted (from → to) hop in a forwarding chain. + * When [fromHostname] and [fromKeyspecs] are empty this is an origin-direct constraint + * (key may be used directly from the machine running the agent). + * + * @param fromHostname Hostname of the previous hop, empty string for the origin machine + * @param fromKeyspecs Host key specs of the previous hop, empty for the origin machine + * @param toUsername Permitted destination username; empty string means any user is allowed + * @param toHostname Destination hostname + * @param toHostspecs Destination host key specs (must be non-empty) + */ +data class DestinationConstraint( + val fromHostname: String, + val fromKeyspecs: List, + val toUsername: String, + val toHostname: String, + val toHostspecs: List, +) + /** * Provider interface for SSH agent forwarding. * @@ -50,10 +97,12 @@ interface AgentProvider { * * @param publicKeyBlob Wire-format public key blob * @param comment Human-readable comment describing the key + * @param destinationConstraints Optional per-hop destination constraints; null means unconstrained */ data class AgentIdentity( val publicKeyBlob: ByteArray, val comment: String, + val destinationConstraints: List? = null, ) { override fun equals(other: Any?): Boolean { if (this === other) return true @@ -61,12 +110,14 @@ data class AgentIdentity( other as AgentIdentity if (!publicKeyBlob.contentEquals(other.publicKeyBlob)) return false if (comment != other.comment) return false + if (destinationConstraints != other.destinationConstraints) return false return true } override fun hashCode(): Int { var result = publicKeyBlob.contentHashCode() result = 31 * result + comment.hashCode() + result = 31 * result + destinationConstraints.hashCode() return result } } diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/AgentProtocolHandler.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/AgentProtocolHandler.kt index 9c12ff9..fed9549 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/AgentProtocolHandler.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/AgentProtocolHandler.kt @@ -18,21 +18,33 @@ package org.connectbot.sshlib.client import io.kaitai.struct.ByteBufferKaitaiStream import io.kaitai.struct.KaitaiStruct +import org.connectbot.sshlib.AgentIdentity +import org.connectbot.sshlib.AgentKeySpec import org.connectbot.sshlib.AgentProvider import org.connectbot.sshlib.AgentSigningContext +import org.connectbot.sshlib.DestinationConstraint +import org.connectbot.sshlib.crypto.SignatureVerifier import org.connectbot.sshlib.protocol.SshAgentIdentitiesAnswer import org.connectbot.sshlib.protocol.SshAgentMessage import org.connectbot.sshlib.protocol.SshAgentSignResponse import org.connectbot.sshlib.protocol.SshAgentcExtension import org.connectbot.sshlib.protocol.SshAgentcSessionBind import org.connectbot.sshlib.protocol.SshAgentcSignRequest +import org.connectbot.sshlib.protocol.UserauthPublickeySignatureDataAny import org.connectbot.sshlib.protocol.createByteString import org.connectbot.sshlib.protocol.toByteArray import org.slf4j.LoggerFactory +internal fun interface SessionBindVerifier { + fun verify(hostKeyBlob: ByteArray, signature: ByteArray, data: ByteArray): Boolean +} + internal class AgentProtocolHandler( private val provider: AgentProvider, private val sessionInfo: AgentSessionInfo, + private val bindVerifier: SessionBindVerifier = SessionBindVerifier { hk, sig, data -> + SignatureVerifier.verify(hk, sig, data) + }, ) { companion object { private val logger = LoggerFactory.getLogger(AgentProtocolHandler::class.java) @@ -44,10 +56,29 @@ internal class AgentProtocolHandler( const val SSH_AGENT_FAILURE: Int = 5 const val SSH_AGENTC_EXTENSION: Int = 27 const val SSH_AGENT_SUCCESS: Int = 6 + + private const val METHOD_PUBLICKEY = "publickey" + private const val METHOD_PUBLICKEY_HOSTBOUND = "publickey-hostbound-v00@openssh.com" + } + + private data class BindingEntry(val hostKeyBlob: ByteArray, val sessionId: ByteArray) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as BindingEntry + if (!hostKeyBlob.contentEquals(other.hostKeyBlob)) return false + if (!sessionId.contentEquals(other.sessionId)) return false + return true + } + + override fun hashCode(): Int { + var result = hostKeyBlob.contentHashCode() + result = 31 * result + sessionId.contentHashCode() + return result + } } - private var sessionBound = false - private var boundSessionId: ByteArray? = null + private val bindingList: MutableList = mutableListOf() suspend fun handleRequest(requestBytes: ByteArray): ByteArray { logger.debug("Handling agent request (${requestBytes.size} bytes)") @@ -76,14 +107,15 @@ internal class AgentProtocolHandler( private suspend fun handleRequestIdentities(): ByteArray { logger.debug("Handling REQUEST_IDENTITIES") - val identities = provider.getIdentities() - logger.debug("Provider returned ${identities.size} identities") + val allIdentities = provider.getIdentities() + val visibleIdentities = filterVisibleIdentities(allIdentities) + logger.debug("Provider returned ${allIdentities.size} identities, ${visibleIdentities.size} visible for current path") val response = SshAgentIdentitiesAnswer() - response.setNkeys(identities.size.toLong()) + response.setNkeys(visibleIdentities.size.toLong()) val identityList = ArrayList() - for (identity in identities) { + for (identity in visibleIdentities) { val id = SshAgentIdentitiesAnswer.Identity() id.set_root(response) id.set_parent(response) @@ -98,21 +130,67 @@ internal class AgentProtocolHandler( return buildAgentMessage(SSH_AGENT_IDENTITIES_ANSWER, response.toByteArray()) } + private fun filterVisibleIdentities(identities: List): List { + if (bindingList.isEmpty()) return identities + val lastHopKey = bindingList.last().hostKeyBlob + return identities.filter { identity -> + val constraints = identity.destinationConstraints + if (constraints == null) return@filter true + constraints.any { c -> + c.fromKeyspecs.isNotEmpty() && c.fromKeyspecs.any { spec -> + spec.keyBlob.contentEquals(lastHopKey) + } + } + } + } + private suspend fun handleSignRequest(message: SshAgentMessage): ByteArray { logger.debug("Handling SIGN_REQUEST") val payload = parsePayload(message) + val keyBlob = payload.keyBlob().data() + val dataToSign = payload.data().data() + + val identity = provider.getIdentities().find { it.publicKeyBlob.contentEquals(keyBlob) } + val constraints = identity?.destinationConstraints + + if (constraints != null) { + var components = parseSignedDataComponents(dataToSign) + if (components == null) { + logger.warn("Failed to parse signed data for constraint check") + return createFailureResponse() + } + + if (bindingList.isNotEmpty() && components.methodName != METHOD_PUBLICKEY_HOSTBOUND) { + logger.warn("Forwarded connection requires publickey-hostbound method, got: ${components.methodName}") + return createFailureResponse() + } + + // For direct connections with standard publickey auth (no embedded server host key), + // use the session's server host key as the implicit destination. + if (bindingList.isEmpty() && components.serverHostKeyBlob == null) { + components = components.copy(serverHostKeyBlob = sessionInfo.serverHostKey) + } + + if (!isConstraintSatisfied(constraints, components)) { + logger.warn("Destination constraint not satisfied for key") + return createFailureResponse() + } + } + + val isBound = bindingList.isNotEmpty() + val effectiveSessionId = bindingList.lastOrNull()?.sessionId ?: sessionInfo.sessionId val context = AgentSigningContext( - publicKeyBlob = payload.keyBlob().data(), - dataToSign = payload.data().data(), + publicKeyBlob = keyBlob, + dataToSign = dataToSign, flags = payload.flags().toInt(), - sessionId = boundSessionId ?: sessionInfo.sessionId, + sessionId = effectiveSessionId, serverHostKey = sessionInfo.serverHostKey, - isBound = sessionBound, + isBound = isBound, ) - logger.debug("Requesting signature from provider (bound=$sessionBound, flags=${context.flags})") + logger.debug("Requesting signature from provider (bound=$isBound, flags=${context.flags})") val signature = provider.signData(context) @@ -128,6 +206,86 @@ internal class AgentProtocolHandler( } } + private data class SignedDataComponents( + val methodName: String, + val destUsername: String, + val serverHostKeyBlob: ByteArray?, + ) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as SignedDataComponents + if (methodName != other.methodName) return false + if (destUsername != other.destUsername) return false + if (!serverHostKeyBlob.contentEquals(other.serverHostKeyBlob)) return false + return true + } + + override fun hashCode(): Int { + var result = methodName.hashCode() + result = 31 * result + destUsername.hashCode() + result = 31 * result + serverHostKeyBlob.contentHashCode() + return result + } + } + + private fun ByteArray?.contentEquals(other: ByteArray?): Boolean { + if (this == null && other == null) return true + if (this == null || other == null) return false + return java.util.Arrays.equals(this, other) + } + + private fun ByteArray?.contentHashCode(): Int = this?.contentHashCode() ?: 0 + + private fun parseSignedDataComponents(data: ByteArray): SignedDataComponents? = try { + val stream = ByteBufferKaitaiStream(data) + val sigData = UserauthPublickeySignatureDataAny(stream) + sigData._read() + + SignedDataComponents( + methodName = sigData.methodName().value(), + destUsername = sigData.userName().value(), + serverHostKeyBlob = sigData.serverHostKey()?.data(), + ) + } catch (e: Exception) { + logger.debug("Failed to parse signed data components: ${e.message}") + null + } + + private fun isConstraintSatisfied( + constraints: List, + components: SignedDataComponents, + ): Boolean { + val isForwarding = bindingList.isNotEmpty() + // The forwarding hop key is the second-to-last binding (the host that relayed to us). + // The last binding is the destination's connection key, also represented in components.serverHostKeyBlob. + val forwardingHopKey = if (bindingList.size >= 2) { + bindingList[bindingList.size - 2].hostKeyBlob + } else if (bindingList.size == 1) { + bindingList[0].hostKeyBlob + } else { + null + } + + return constraints.any { c -> + val fromMatches = if (!isForwarding) { + c.fromHostname.isEmpty() && c.fromKeyspecs.isEmpty() + } else { + forwardingHopKey != null && c.fromKeyspecs.any { spec -> + spec.keyBlob.contentEquals(forwardingHopKey) + } + } + if (!fromMatches) return@any false + + val usernameMatches = c.toUsername.isEmpty() || c.toUsername == components.destUsername + if (!usernameMatches) return@any false + + val hostKeyMatches = components.serverHostKeyBlob != null && + c.toHostspecs.any { spec -> spec.keyBlob.contentEquals(components.serverHostKeyBlob) } + hostKeyMatches + } + } + private suspend fun handleExtension(message: SshAgentMessage): ByteArray { logger.debug("Handling EXTENSION") @@ -151,20 +309,30 @@ internal class AgentProtocolHandler( val bind = SshAgentcSessionBind(stream) bind._read() - if (sessionBound) { - logger.warn("Session already bound, rejecting duplicate binding") + val hostKeyBlob = bind.hostkey().data() + val sessionId = bind.sessionIdentifier().data() + val isForwarding = bind.isForwarding().toInt() != 0 + + // Replay protection: reject duplicate session IDs + if (bindingList.any { it.sessionId.contentEquals(sessionId) }) { + logger.warn("Session bind replay: session ID already recorded") return createFailureResponse() } - if (!bind.hostkey().data().contentEquals(sessionInfo.serverHostKey)) { - logger.error("Session bind hostkey mismatch") + // For non-forwarding (origin) binds, verify the hostkey matches the connection's server key + if (!isForwarding && !hostKeyBlob.contentEquals(sessionInfo.serverHostKey)) { + logger.error("Session bind hostkey mismatch for non-forwarding bind") return createFailureResponse() } - sessionBound = true - boundSessionId = bind.sessionIdentifier().data() + // Cryptographically verify the session bind signature + if (!bindVerifier.verify(hostKeyBlob, bind.signature().data(), sessionId)) { + logger.error("Session bind signature verification failed") + return createFailureResponse() + } - logger.info("Session binding successful") + bindingList.add(BindingEntry(hostKeyBlob, sessionId)) + logger.info("Session binding successful (isForwarding=$isForwarding, total bindings=${bindingList.size})") return createSuccessResponse() } diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt index 3fde5a7..07b2c9a 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt @@ -88,6 +88,7 @@ import org.connectbot.sshlib.protocol.SshMsgChannelRequest import org.connectbot.sshlib.protocol.SshMsgChannelWindowAdjust import org.connectbot.sshlib.protocol.SshMsgDebug import org.connectbot.sshlib.protocol.SshMsgDisconnect +import org.connectbot.sshlib.protocol.SshMsgExtInfo import org.connectbot.sshlib.protocol.SshMsgGlobalRequest import org.connectbot.sshlib.protocol.SshMsgKexDhGexGroup import org.connectbot.sshlib.protocol.SshMsgKexDhGexInit @@ -107,11 +108,13 @@ import org.connectbot.sshlib.protocol.SshMsgUserauthInfoResponse import org.connectbot.sshlib.protocol.SshMsgUserauthPkOk import org.connectbot.sshlib.protocol.SshMsgUserauthRequest import org.connectbot.sshlib.protocol.UnencryptedPacket +import org.connectbot.sshlib.protocol.UserauthPublickeyHostboundSignatureData import org.connectbot.sshlib.protocol.UserauthPublickeySignatureData import org.connectbot.sshlib.protocol.UserauthRequestKeyboardInteractive import org.connectbot.sshlib.protocol.UserauthRequestNone import org.connectbot.sshlib.protocol.UserauthRequestPassword import org.connectbot.sshlib.protocol.UserauthRequestPublickey +import org.connectbot.sshlib.protocol.UserauthRequestPublickeyHostbound import org.connectbot.sshlib.protocol.createAsciiString import org.connectbot.sshlib.protocol.createByteString import org.connectbot.sshlib.protocol.createMpint @@ -239,6 +242,7 @@ class SshConnection( private var agentProvider: AgentProvider? = null private var serverHostKeyBlob: ByteArray? = null + private var serverAdvertisesHostBound: Boolean = false /** * Helper to manage a pending asynchronous operation that waits for a server response. @@ -546,37 +550,47 @@ class SshConnection( val sigEntry = SignatureEntry.fromSshName(sigAlgorithmName) ?: throw SshException("Unknown signature algorithm: $sigAlgorithmName") - // Build the data to sign per RFC 4252 §7 - val signatureData = buildSignatureData( - sid, - username, - "ssh-connection", - sigAlgorithmName, - publicKeyBlob, - ) + val hostKeyBlob = serverHostKeyBlob + val useHostBound = serverAdvertisesHostBound && hostKeyBlob != null + + val signatureData = if (useHostBound && hostKeyBlob != null) { + buildHostBoundSignatureData(sid, username, "ssh-connection", sigAlgorithmName, publicKeyBlob, hostKeyBlob) + } else { + buildSignatureData(sid, username, "ssh-connection", sigAlgorithmName, publicKeyBlob) + } - // Sign the data val signature = sigEntry.algorithm.sign( sigAlgorithmName, privateKey.jcaKeyPair.private, signatureData, ) - // Build the SSH_MSG_USERAUTH_REQUEST val req = SshMsgUserauthRequest().apply { setUserName(createAsciiString(username)) setServiceName(createAsciiString("ssh-connection")) - setMethodName(createAsciiString("publickey")) - val pubkeyAuth = UserauthRequestPublickey().apply { - setHasSignature(1) - setPublicKeyAlgorithmName(createAsciiString(sigAlgorithmName)) - setPublicKeyBlob(createByteString(publicKeyBlob)) - setSignature(createByteString(signature)) - _check() + if (useHostBound && hostKeyBlob != null) { + setMethodName(createAsciiString("publickey-hostbound-v00@openssh.com")) + val pubkeyAuth = UserauthRequestPublickeyHostbound().apply { + setHasSignature(1) + setPublicKeyAlgorithmName(createAsciiString(sigAlgorithmName)) + setPublicKeyBlob(createByteString(publicKeyBlob)) + setServerHostKey(createByteString(hostKeyBlob)) + setSignature(createByteString(signature)) + _check() + } + setMethodSpecificFields(pubkeyAuth) + } else { + setMethodName(createAsciiString("publickey")) + val pubkeyAuth = UserauthRequestPublickey().apply { + setHasSignature(1) + setPublicKeyAlgorithmName(createAsciiString(sigAlgorithmName)) + setPublicKeyBlob(createByteString(publicKeyBlob)) + setSignature(createByteString(signature)) + _check() + } + setMethodSpecificFields(pubkeyAuth) } - - setMethodSpecificFields(pubkeyAuth) _check() } @@ -625,6 +639,29 @@ class SshConnection( return data.toByteArray() } + private fun buildHostBoundSignatureData( + sessionId: ByteArray, + username: String, + serviceName: String, + algorithmName: String, + publicKeyBlob: ByteArray, + serverHostKeyBlob: ByteArray, + ): ByteArray { + val data = UserauthPublickeyHostboundSignatureData().apply { + setSessionIdentifier(createByteString(sessionId)) + setMessageType(byteArrayOf(50)) + setUserName(createByteString(username.toByteArray(Charsets.UTF_8))) + setServiceName(createByteString(serviceName.toByteArray(Charsets.US_ASCII))) + setMethodName(createByteString("publickey-hostbound-v00@openssh.com".toByteArray(Charsets.US_ASCII))) + setHasSignature(byteArrayOf(1)) + setPublicKeyAlgorithmName(createByteString(algorithmName.toByteArray(Charsets.US_ASCII))) + setPublicKeyBlob(createByteString(publicKeyBlob)) + setServerHostKey(createByteString(serverHostKeyBlob)) + _check() + } + return data.toByteArray() + } + /** * Authenticate using the strategy-based [AuthHandler] flow. * @@ -738,25 +775,40 @@ class SshConnection( channel: Channel, ): Boolean { val sid = sessionId ?: throw SshException("Session ID not established") - val signatureData = buildSignatureData( - sid, - username, - "ssh-connection", - key.algorithmName, - key.publicKeyBlob, - ) + val hostKeyBlob = serverHostKeyBlob + val useHostBound = serverAdvertisesHostBound && hostKeyBlob != null + + val signatureData = if (useHostBound && hostKeyBlob != null) { + buildHostBoundSignatureData(sid, username, "ssh-connection", key.algorithmName, key.publicKeyBlob, hostKeyBlob) + } else { + buildSignatureData(sid, username, "ssh-connection", key.algorithmName, key.publicKeyBlob) + } val signature = handler.onSignatureRequest(key, signatureData) ?: return false - sendAuthRequest(username, "publickey") { - val pubkeyAuth = UserauthRequestPublickey().apply { - setHasSignature(1) - setPublicKeyAlgorithmName(createAsciiString(key.algorithmName)) - setPublicKeyBlob(createByteString(key.publicKeyBlob)) - setSignature(createByteString(signature)) - _check() + if (useHostBound && hostKeyBlob != null) { + sendAuthRequest(username, "publickey-hostbound-v00@openssh.com") { + val pubkeyAuth = UserauthRequestPublickeyHostbound().apply { + setHasSignature(1) + setPublicKeyAlgorithmName(createAsciiString(key.algorithmName)) + setPublicKeyBlob(createByteString(key.publicKeyBlob)) + setServerHostKey(createByteString(hostKeyBlob)) + setSignature(createByteString(signature)) + _check() + } + setMethodSpecificFields(pubkeyAuth) + } + } else { + sendAuthRequest(username, "publickey") { + val pubkeyAuth = UserauthRequestPublickey().apply { + setHasSignature(1) + setPublicKeyAlgorithmName(createAsciiString(key.algorithmName)) + setPublicKeyBlob(createByteString(key.publicKeyBlob)) + setSignature(createByteString(signature)) + _check() + } + setMethodSpecificFields(pubkeyAuth) } - setMethodSpecificFields(pubkeyAuth) } return when (channel.receive()) { @@ -2076,6 +2128,16 @@ class SshConnection( stateMachine.processEvent(SshClientStateMachine.SshEvent.Disconnect) } + SshEnums.MessageType.SSH_MSG_EXT_INFO -> { + val extInfo = parseBody(packet) + for (ext in extInfo.extensions()) { + if (ext.extensionName().value() == "publickey-hostbound@openssh.com") { + serverAdvertisesHostBound = true + logger.info("Server advertises publickey-hostbound@openssh.com") + } + } + } + else -> { // KEX-specific messages 30-49 are not in MessageType enum. // Disambiguate by negotiated KEX type since ECDH reply, DH reply, diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt index 4f7273b..83f290f 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt @@ -269,7 +269,7 @@ internal enum class KexEntry( val defaults: List = entries.filter { it != DH_GROUP1_SHA1 } val defaultString: String = - defaults.joinToString(",") { it.sshName } + ",kex-strict-c-v00@openssh.com" + defaults.joinToString(",") { it.sshName } + ",kex-strict-c-v00@openssh.com,ext-info-c" fun fromSshName(name: String): KexEntry? = entries.firstOrNull { it.sshName == name } } diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentDestinationConstraintTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentDestinationConstraintTest.kt new file mode 100644 index 0000000..63f1ce5 --- /dev/null +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentDestinationConstraintTest.kt @@ -0,0 +1,363 @@ +/* + * Copyright 2025 Kenny Root + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.connectbot.sshlib + +import kotlinx.coroutines.test.runTest +import org.connectbot.sshlib.client.AgentProtocolHandler +import org.connectbot.sshlib.client.AgentSessionInfo +import org.connectbot.sshlib.client.SessionBindVerifier +import org.connectbot.sshlib.protocol.SshAgentcSessionBind +import org.connectbot.sshlib.protocol.SshAgentcSignRequest +import org.connectbot.sshlib.protocol.createByteString +import org.connectbot.sshlib.protocol.toByteArray +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import java.nio.ByteBuffer + +class AgentDestinationConstraintTest { + + private val noopVerifier: SessionBindVerifier = SessionBindVerifier { _, _, _ -> true } + + // Builds a minimal SSH-string encoded byte array (uint32 length + data) + private fun sshString(data: ByteArray): ByteArray { + val result = ByteArray(4 + data.size) + result[0] = (data.size shr 24).toByte() + result[1] = (data.size shr 16).toByte() + result[2] = (data.size shr 8).toByte() + result[3] = data.size.toByte() + System.arraycopy(data, 0, result, 4, data.size) + return result + } + + private fun sshString(s: String, charset: java.nio.charset.Charset = Charsets.UTF_8): ByteArray = sshString(s.toByteArray(charset)) + + /** + * Builds the signed data blob for a publickey or publickey-hostbound auth request. + * Mirrors the format used by SshConnection.buildSignatureData / buildHostBoundSignatureData. + */ + private fun buildSignedData( + sessionId: ByteArray = byteArrayOf(0xAA.toByte()), + username: String = "user", + methodName: String = "publickey", + algorithmName: String = "ssh-ed25519", + publicKeyBlob: ByteArray = byteArrayOf(1, 2, 3), + serverHostKey: ByteArray? = null, + ): ByteArray { + val parts = mutableListOf() + parts += sshString(sessionId) + parts += byteArrayOf(50) // SSH_MSG_USERAUTH_REQUEST + parts += sshString(username) + parts += sshString("ssh-connection", Charsets.US_ASCII) + parts += sshString(methodName, Charsets.US_ASCII) + parts += byteArrayOf(1) // has_signature = TRUE + parts += sshString(algorithmName, Charsets.US_ASCII) + parts += sshString(publicKeyBlob) + if (serverHostKey != null) parts += sshString(serverHostKey) + return parts.fold(ByteArray(0)) { acc, b -> acc + b } + } + + private fun buildAgentMessage(messageType: Int, payload: ByteArray): ByteArray { + val totalLength = 1 + payload.size + val buffer = ByteArray(4 + totalLength) + buffer[0] = ((totalLength shr 24) and 0xFF).toByte() + buffer[1] = ((totalLength shr 16) and 0xFF).toByte() + buffer[2] = ((totalLength shr 8) and 0xFF).toByte() + buffer[3] = (totalLength and 0xFF).toByte() + buffer[4] = messageType.toByte() + System.arraycopy(payload, 0, buffer, 5, payload.size) + return buffer + } + + private fun parseAgentMessage(response: ByteArray): Pair { + val buffer = ByteBuffer.wrap(response) + val length = buffer.int + val messageType = buffer.get().toInt() and 0xFF + val payload = ByteArray(length - 1) + buffer.get(payload) + return Pair(messageType, payload) + } + + private fun buildSessionBindRequest( + hostKey: ByteArray, + sessionId: ByteArray, + isForwarding: Int, + ): ByteArray { + val bind = SshAgentcSessionBind() + bind.setHostkey(createByteString(hostKey)) + bind.setSessionIdentifier(createByteString(sessionId)) + bind.setSignature(createByteString(byteArrayOf(0x01))) + bind.setIsForwarding(isForwarding) + bind._check() + + val nameBytes = createByteString("session-bind@openssh.com".toByteArray()).toByteArray() + val bindBytes = bind.toByteArray() + val extBytes = ByteArray(nameBytes.size + bindBytes.size) + System.arraycopy(nameBytes, 0, extBytes, 0, nameBytes.size) + System.arraycopy(bindBytes, 0, extBytes, nameBytes.size, bindBytes.size) + return buildAgentMessage(27, extBytes) + } + + private fun buildSignRequest(keyBlob: ByteArray, dataToSign: ByteArray): ByteArray { + val signRequest = SshAgentcSignRequest() + signRequest.setKeyBlob(createByteString(keyBlob)) + signRequest.setData(createByteString(dataToSign)) + signRequest.setFlags(0) + signRequest._check() + return buildAgentMessage(13, signRequest.toByteArray()) + } + + private val hostKeyA = byteArrayOf(0x10, 0x11, 0x12) + private val hostKeyB = byteArrayOf(0x20, 0x21, 0x22) + private val keyBlob = byteArrayOf(0x01, 0x02, 0x03) + + @Test + fun `unconstrained key is always allowed`() = runTest { + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf(AgentIdentity(keyBlob, "test")) + override suspend fun signData(context: AgentSigningContext) = byteArrayOf(0xFF.toByte()) + } + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyA), noopVerifier) + + val signedData = buildSignedData(serverHostKey = hostKeyA) + val response = handler.handleRequest(buildSignRequest(keyBlob, signedData)) + + val (msgType, _) = parseAgentMessage(response) + assertEquals(14, msgType) // SSH_AGENT_SIGN_RESPONSE + } + + @Test + fun `constrained key direct connection matching destination is allowed`() = runTest { + val constraints = listOf( + DestinationConstraint( + fromHostname = "", + fromKeyspecs = emptyList(), + toUsername = "user", + toHostname = "host-a", + toHostspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf(AgentIdentity(keyBlob, "test", constraints)) + override suspend fun signData(context: AgentSigningContext) = byteArrayOf(0xFF.toByte()) + } + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyA), noopVerifier) + + val signedData = buildSignedData(username = "user", serverHostKey = hostKeyA) + val response = handler.handleRequest(buildSignRequest(keyBlob, signedData)) + + val (msgType, _) = parseAgentMessage(response) + assertEquals(14, msgType) // SSH_AGENT_SIGN_RESPONSE + } + + @Test + fun `constrained key direct connection wrong destination is rejected`() = runTest { + val constraints = listOf( + DestinationConstraint( + fromHostname = "", + fromKeyspecs = emptyList(), + toUsername = "user", + toHostname = "host-a", + toHostspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf(AgentIdentity(keyBlob, "test", constraints)) + override suspend fun signData(context: AgentSigningContext) = byteArrayOf(0xFF.toByte()) + } + // sessionInfo has hostKeyB as the server key — not permitted by the constraint (which requires hostKeyA) + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyB), noopVerifier) + + val signedData = buildSignedData(username = "user") + val response = handler.handleRequest(buildSignRequest(keyBlob, signedData)) + + val (msgType, _) = parseAgentMessage(response) + assertEquals(5, msgType) // SSH_AGENT_FAILURE + } + + @Test + fun `constrained key forwarding without hostbound method is rejected`() = runTest { + val constraints = listOf( + DestinationConstraint( + fromHostname = "hop-a", + fromKeyspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + toUsername = "user", + toHostname = "host-b", + toHostspecs = listOf(AgentKeySpec(hostKeyB, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf(AgentIdentity(keyBlob, "test", constraints)) + override suspend fun signData(context: AgentSigningContext) = byteArrayOf(0xFF.toByte()) + } + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyA), noopVerifier) + + // Simulate forwarding: bind via hostKeyA, then attempt to sign without hostbound method + handler.handleRequest(buildSessionBindRequest(hostKeyA, byteArrayOf(1), isForwarding = 0)) + handler.handleRequest(buildSessionBindRequest(hostKeyB, byteArrayOf(2), isForwarding = 1)) + + // Standard "publickey" method (not hostbound) should be rejected when forwarding + val signedData = buildSignedData(methodName = "publickey", username = "user", serverHostKey = hostKeyB) + val response = handler.handleRequest(buildSignRequest(keyBlob, signedData)) + + val (msgType, _) = parseAgentMessage(response) + assertEquals(5, msgType) // SSH_AGENT_FAILURE + } + + @Test + fun `constrained key forwarding with correct path is allowed`() = runTest { + val constraints = listOf( + DestinationConstraint( + fromHostname = "hop-a", + fromKeyspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + toUsername = "user", + toHostname = "host-b", + toHostspecs = listOf(AgentKeySpec(hostKeyB, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf(AgentIdentity(keyBlob, "test", constraints)) + override suspend fun signData(context: AgentSigningContext) = byteArrayOf(0xFF.toByte()) + } + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyA), noopVerifier) + + // Origin bind (non-forwarding) through hostKeyA, then forwarding bind through hostKeyA → hostKeyB + handler.handleRequest(buildSessionBindRequest(hostKeyA, byteArrayOf(1), isForwarding = 0)) + handler.handleRequest(buildSessionBindRequest(hostKeyB, byteArrayOf(2), isForwarding = 1)) + + val signedData = buildSignedData( + methodName = "publickey-hostbound-v00@openssh.com", + username = "user", + serverHostKey = hostKeyB, + ) + val response = handler.handleRequest(buildSignRequest(keyBlob, signedData)) + + val (msgType, _) = parseAgentMessage(response) + assertEquals(14, msgType) // SSH_AGENT_SIGN_RESPONSE + } + + @Test + fun `constrained key forwarding through wrong hop is rejected`() = runTest { + val constraints = listOf( + DestinationConstraint( + fromHostname = "hop-a", + fromKeyspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + toUsername = "user", + toHostname = "host-b", + toHostspecs = listOf(AgentKeySpec(hostKeyB, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf(AgentIdentity(keyBlob, "test", constraints)) + override suspend fun signData(context: AgentSigningContext) = byteArrayOf(0xFF.toByte()) + } + val hostKeyC = byteArrayOf(0x30, 0x31, 0x32) + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyC), noopVerifier) + + // Forwarding through hostKeyC (not in constraints) + handler.handleRequest(buildSessionBindRequest(hostKeyC, byteArrayOf(1), isForwarding = 0)) + handler.handleRequest(buildSessionBindRequest(hostKeyB, byteArrayOf(2), isForwarding = 1)) + + val signedData = buildSignedData( + methodName = "publickey-hostbound-v00@openssh.com", + username = "user", + serverHostKey = hostKeyB, + ) + val response = handler.handleRequest(buildSignRequest(keyBlob, signedData)) + + val (msgType, _) = parseAgentMessage(response) + assertEquals(5, msgType) // SSH_AGENT_FAILURE + } + + @Test + fun `REQUEST_IDENTITIES on forwarded connection only returns reachable keys`() = runTest { + val constrainedKey = byteArrayOf(0xAA.toByte(), 0xBB.toByte()) + val unconstrainedKey = byteArrayOf(0xCC.toByte(), 0xDD.toByte()) + + val constraints = listOf( + DestinationConstraint( + fromHostname = "hop-a", + fromKeyspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + toUsername = "", + toHostname = "host-b", + toHostspecs = listOf(AgentKeySpec(hostKeyB, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf( + AgentIdentity(constrainedKey, "constrained", constraints), + AgentIdentity(unconstrainedKey, "unconstrained"), + ) + override suspend fun signData(context: AgentSigningContext) = null + } + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyA), noopVerifier) + + // Add a forwarding binding through hostKeyA + handler.handleRequest(buildSessionBindRequest(hostKeyA, byteArrayOf(1), isForwarding = 0)) + + val response = handler.handleRequest(buildAgentMessage(11, ByteArray(0))) + val (msgType, payload) = parseAgentMessage(response) + assertEquals(12, msgType) // SSH_AGENT_IDENTITIES_ANSWER + + val stream = io.kaitai.struct.ByteBufferKaitaiStream(payload) + val answer = org.connectbot.sshlib.protocol.SshAgentIdentitiesAnswer(stream) + answer._read() + + // Only unconstrained key should be visible (constrained key needs fromKeyspecs=hostKeyA + // but it's found in the last hop — wait, hostKeyA IS the last hop here, so constrained + // key IS reachable). Both should appear. + assertEquals(2, answer.nkeys().toInt()) + } + + @Test + fun `constrained key not reachable via current hop is hidden`() = runTest { + val constrainedKey = byteArrayOf(0xAA.toByte(), 0xBB.toByte()) + val unconstrainedKey = byteArrayOf(0xCC.toByte(), 0xDD.toByte()) + + val hostKeyC = byteArrayOf(0x30, 0x31, 0x32) + val constraints = listOf( + DestinationConstraint( + fromHostname = "hop-a", + fromKeyspecs = listOf(AgentKeySpec(hostKeyA, isCa = false)), + toUsername = "", + toHostname = "host-b", + toHostspecs = listOf(AgentKeySpec(hostKeyB, isCa = false)), + ), + ) + val provider = object : AgentProvider { + override suspend fun getIdentities() = listOf( + AgentIdentity(constrainedKey, "constrained", constraints), + AgentIdentity(unconstrainedKey, "unconstrained"), + ) + override suspend fun signData(context: AgentSigningContext) = null + } + // Forwarding through hostKeyC (not hostKeyA, so constrained key not reachable) + val handler = AgentProtocolHandler(provider, AgentSessionInfo(byteArrayOf(1), hostKeyC), noopVerifier) + handler.handleRequest(buildSessionBindRequest(hostKeyC, byteArrayOf(1), isForwarding = 0)) + + val response = handler.handleRequest(buildAgentMessage(11, ByteArray(0))) + val (msgType, payload) = parseAgentMessage(response) + assertEquals(12, msgType) // SSH_AGENT_IDENTITIES_ANSWER + + val stream = io.kaitai.struct.ByteBufferKaitaiStream(payload) + val answer = org.connectbot.sshlib.protocol.SshAgentIdentitiesAnswer(stream) + answer._read() + + // Only unconstrained key should be visible + assertEquals(1, answer.nkeys().toInt()) + assert(answer.identities()[0].keyBlob().data().contentEquals(unconstrainedKey)) + } +} diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentProtocolTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentProtocolTest.kt index 96acaa2..a12bfd8 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentProtocolTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/AgentProtocolTest.kt @@ -20,6 +20,7 @@ import io.kaitai.struct.ByteBufferKaitaiStream import kotlinx.coroutines.test.runTest import org.connectbot.sshlib.client.AgentProtocolHandler import org.connectbot.sshlib.client.AgentSessionInfo +import org.connectbot.sshlib.client.SessionBindVerifier import org.connectbot.sshlib.protocol.SshAgentIdentitiesAnswer import org.connectbot.sshlib.protocol.SshAgentMessage import org.connectbot.sshlib.protocol.SshAgentSignResponse @@ -241,24 +242,16 @@ class AgentProtocolTest { assertFalse(capturedContext.isBound) } - @Test - fun `handler handles session bind extension`() = runTest { - val testProvider = object : AgentProvider { - override suspend fun getIdentities(): List = emptyList() - override suspend fun signData(context: AgentSigningContext): ByteArray? = null - } - - val sessionId = byteArrayOf(1, 2, 3) - val hostKey = byteArrayOf(4, 5, 6) - val sessionInfo = AgentSessionInfo(sessionId, hostKey) - - val handler = AgentProtocolHandler(testProvider, sessionInfo) - + private fun buildSessionBindRequest( + hostKey: ByteArray, + sessionId: ByteArray, + isForwarding: Int, + ): ByteArray { val bind = SshAgentcSessionBind() bind.setHostkey(createByteString(hostKey)) bind.setSessionIdentifier(createByteString(sessionId)) bind.setSignature(createByteString(byteArrayOf(1, 2, 3))) - bind.setIsForwarding(1) + bind.setIsForwarding(isForwarding) bind._check() val nameBytes = createByteString("session-bind@openssh.com".toByteArray()).toByteArray() @@ -266,16 +259,32 @@ class AgentProtocolTest { val extBytes = ByteArray(nameBytes.size + bindBytes.size) System.arraycopy(nameBytes, 0, extBytes, 0, nameBytes.size) System.arraycopy(bindBytes, 0, extBytes, nameBytes.size, bindBytes.size) + return buildAgentMessage(27, extBytes) + } - val requestMessage = buildAgentMessage(27, extBytes) // SSH_AGENTC_EXTENSION - val response = handler.handleRequest(requestMessage) + private val noopVerifier: SessionBindVerifier = SessionBindVerifier { _, _, _ -> true } + private val rejectingVerifier: SessionBindVerifier = SessionBindVerifier { _, _, _ -> false } + + @Test + fun `handler handles session bind extension`() = runTest { + val testProvider = object : AgentProvider { + override suspend fun getIdentities(): List = emptyList() + override suspend fun signData(context: AgentSigningContext): ByteArray? = null + } + + val sessionId = byteArrayOf(1, 2, 3) + val hostKey = byteArrayOf(4, 5, 6) + val sessionInfo = AgentSessionInfo(sessionId, hostKey) + val handler = AgentProtocolHandler(testProvider, sessionInfo, noopVerifier) + + val response = handler.handleRequest(buildSessionBindRequest(hostKey, sessionId, isForwarding = 1)) val (messageType, _) = parseAgentMessage(response) assertEquals(6, messageType) // SSH_AGENT_SUCCESS } @Test - fun `handler rejects duplicate session bind`() = runTest { + fun `handler rejects session bind when signature verification fails`() = runTest { val testProvider = object : AgentProvider { override suspend fun getIdentities(): List = emptyList() override suspend fun signData(context: AgentSigningContext): ByteArray? = null @@ -284,23 +293,27 @@ class AgentProtocolTest { val sessionId = byteArrayOf(1, 2, 3) val hostKey = byteArrayOf(4, 5, 6) val sessionInfo = AgentSessionInfo(sessionId, hostKey) + val handler = AgentProtocolHandler(testProvider, sessionInfo, rejectingVerifier) - val handler = AgentProtocolHandler(testProvider, sessionInfo) + val response = handler.handleRequest(buildSessionBindRequest(hostKey, sessionId, isForwarding = 1)) - val bind = SshAgentcSessionBind() - bind.setHostkey(createByteString(hostKey)) - bind.setSessionIdentifier(createByteString(sessionId)) - bind.setSignature(createByteString(byteArrayOf(1, 2, 3))) - bind.setIsForwarding(1) - bind._check() + val (messageType, _) = parseAgentMessage(response) + assertEquals(5, messageType) // SSH_AGENT_FAILURE + } - val nameBytes = createByteString("session-bind@openssh.com".toByteArray()).toByteArray() - val bindBytes = bind.toByteArray() - val extBytes = ByteArray(nameBytes.size + bindBytes.size) - System.arraycopy(nameBytes, 0, extBytes, 0, nameBytes.size) - System.arraycopy(bindBytes, 0, extBytes, nameBytes.size, bindBytes.size) + @Test + fun `handler rejects duplicate session bind`() = runTest { + val testProvider = object : AgentProvider { + override suspend fun getIdentities(): List = emptyList() + override suspend fun signData(context: AgentSigningContext): ByteArray? = null + } - val requestMessage = buildAgentMessage(27, extBytes) + val sessionId = byteArrayOf(1, 2, 3) + val hostKey = byteArrayOf(4, 5, 6) + val sessionInfo = AgentSessionInfo(sessionId, hostKey) + val handler = AgentProtocolHandler(testProvider, sessionInfo, noopVerifier) + + val requestMessage = buildSessionBindRequest(hostKey, sessionId, isForwarding = 1) handler.handleRequest(requestMessage) val response = handler.handleRequest(requestMessage) @@ -309,35 +322,44 @@ class AgentProtocolTest { } @Test - fun `handler rejects session bind with mismatched hostkey`() = runTest { + fun `handler rejects non-forwarding session bind with mismatched hostkey`() = runTest { val testProvider = object : AgentProvider { override suspend fun getIdentities(): List = emptyList() override suspend fun signData(context: AgentSigningContext): ByteArray? = null } val sessionInfo = AgentSessionInfo(byteArrayOf(1, 2, 3), byteArrayOf(4, 5, 6)) - val handler = AgentProtocolHandler(testProvider, sessionInfo) + val handler = AgentProtocolHandler(testProvider, sessionInfo, noopVerifier) - val bind = SshAgentcSessionBind() - bind.setHostkey(createByteString(byteArrayOf(9, 9, 9))) // Mismatched - bind.setSessionIdentifier(createByteString(byteArrayOf(1, 2, 3))) - bind.setSignature(createByteString(byteArrayOf(1, 2, 3))) - bind.setIsForwarding(1) - bind._check() - - val nameBytes = createByteString("session-bind@openssh.com".toByteArray()).toByteArray() - val bindBytes = bind.toByteArray() - val extBytes = ByteArray(nameBytes.size + bindBytes.size) - System.arraycopy(nameBytes, 0, extBytes, 0, nameBytes.size) - System.arraycopy(bindBytes, 0, extBytes, nameBytes.size, bindBytes.size) - - val requestMessage = buildAgentMessage(27, extBytes) - val response = handler.handleRequest(requestMessage) + // is_forwarding = 0 means origin bind — hostkey must match sessionInfo.serverHostKey + val response = handler.handleRequest( + buildSessionBindRequest(byteArrayOf(9, 9, 9), byteArrayOf(1, 2, 3), isForwarding = 0), + ) val (messageType, _) = parseAgentMessage(response) assertEquals(5, messageType) // SSH_AGENT_FAILURE } + @Test + fun `handler accumulates multiple forwarding binds`() = runTest { + val testProvider = object : AgentProvider { + override suspend fun getIdentities(): List = emptyList() + override suspend fun signData(context: AgentSigningContext): ByteArray? = null + } + + val hostKey = byteArrayOf(4, 5, 6) + val sessionInfo = AgentSessionInfo(byteArrayOf(1, 2, 3), hostKey) + val handler = AgentProtocolHandler(testProvider, sessionInfo, noopVerifier) + + val response1 = handler.handleRequest(buildSessionBindRequest(hostKey, byteArrayOf(1, 2, 3), isForwarding = 0)) + val (type1, _) = parseAgentMessage(response1) + assertEquals(6, type1) + + val response2 = handler.handleRequest(buildSessionBindRequest(byteArrayOf(7, 8, 9), byteArrayOf(4, 5, 6), isForwarding = 1)) + val (type2, _) = parseAgentMessage(response2) + assertEquals(6, type2) + } + @Test fun `handler returns failure for unknown extension`() = runTest { val testProvider = object : AgentProvider { diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt new file mode 100644 index 0000000..18e8262 --- /dev/null +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt @@ -0,0 +1,102 @@ +/* + * Copyright 2025 Kenny Root + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.connectbot.sshlib + +import io.kaitai.struct.ByteBufferKaitaiStream +import org.connectbot.sshlib.protocol.SshMsgExtInfo +import org.connectbot.sshlib.protocol.createAsciiString +import org.connectbot.sshlib.protocol.createByteString +import org.connectbot.sshlib.protocol.toByteArray +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +class ExtInfoProcessingTest { + + private fun buildExtInfo(extensions: Map): ByteArray { + val msg = SshMsgExtInfo() + msg.setNumExtensions(extensions.size.toLong()) + val extList = ArrayList() + for ((name, value) in extensions) { + val ext = SshMsgExtInfo.Extension() + ext.set_root(msg) + ext.set_parent(msg) + ext.setExtensionName(createAsciiString(name)) + ext.setExtensionValue(createByteString(value)) + ext._check() + extList.add(ext) + } + msg.setExtensions(extList) + msg._check() + return msg.toByteArray() + } + + private fun detectsHostBound(bytes: ByteArray): Boolean { + val parsed = SshMsgExtInfo(ByteBufferKaitaiStream(bytes)) + parsed._read() + return parsed.extensions().any { it.extensionName().value() == "publickey-hostbound@openssh.com" } + } + + @Test + fun `detects publickey-hostbound extension`() { + val bytes = buildExtInfo( + mapOf("publickey-hostbound@openssh.com" to "0".toByteArray()), + ) + assertTrue(detectsHostBound(bytes)) + } + + @Test + fun `does not detect hostbound when extension absent`() { + val bytes = buildExtInfo( + mapOf("server-sig-algs" to "ssh-ed25519,ssh-rsa".toByteArray()), + ) + assertFalse(detectsHostBound(bytes)) + } + + @Test + fun `handles multiple extensions including hostbound`() { + val bytes = buildExtInfo( + mapOf( + "server-sig-algs" to "ssh-ed25519".toByteArray(), + "publickey-hostbound@openssh.com" to "0".toByteArray(), + "no-flow-control" to byteArrayOf(), + ), + ) + assertTrue(detectsHostBound(bytes)) + } + + @Test + fun `handles empty extension list`() { + val bytes = buildExtInfo(emptyMap()) + assertFalse(detectsHostBound(bytes)) + } + + @Test + fun `extension count is correct`() { + val bytes = buildExtInfo( + mapOf( + "ext1" to "val1".toByteArray(), + "ext2" to "val2".toByteArray(), + ), + ) + val parsed = SshMsgExtInfo(ByteBufferKaitaiStream(bytes)) + parsed._read() + assertEquals(2, parsed.numExtensions().toInt()) + assertEquals(2, parsed.extensions().size) + } +} diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/HostBoundSignatureDataTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/HostBoundSignatureDataTest.kt new file mode 100644 index 0000000..17eeaf4 --- /dev/null +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/HostBoundSignatureDataTest.kt @@ -0,0 +1,111 @@ +/* + * Copyright 2025 Kenny Root + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.connectbot.sshlib + +import io.kaitai.struct.ByteBufferKaitaiStream +import org.connectbot.sshlib.protocol.UserauthPublickeyHostboundSignatureData +import org.connectbot.sshlib.protocol.createAsciiString +import org.connectbot.sshlib.protocol.createByteString +import org.connectbot.sshlib.protocol.toByteArray +import org.junit.jupiter.api.Assertions.assertArrayEquals +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class HostBoundSignatureDataTest { + + private fun buildHostBoundSignatureData( + sessionId: ByteArray, + username: String, + serviceName: String, + algorithmName: String, + publicKeyBlob: ByteArray, + serverHostKeyBlob: ByteArray, + ): ByteArray { + val data = UserauthPublickeyHostboundSignatureData().apply { + setSessionIdentifier(createByteString(sessionId)) + setMessageType(byteArrayOf(50)) + setUserName(createByteString(username.toByteArray(Charsets.UTF_8))) + setServiceName(createByteString(serviceName.toByteArray(Charsets.US_ASCII))) + setMethodName(createByteString("publickey-hostbound-v00@openssh.com".toByteArray(Charsets.US_ASCII))) + setHasSignature(byteArrayOf(1)) + setPublicKeyAlgorithmName(createByteString(algorithmName.toByteArray(Charsets.US_ASCII))) + setPublicKeyBlob(createByteString(publicKeyBlob)) + setServerHostKey(createByteString(serverHostKeyBlob)) + _check() + } + return data.toByteArray() + } + + @Test + fun `host-bound signature data round-trips correctly`() { + val sessionId = byteArrayOf(1, 2, 3, 4) + val username = "testuser" + val algorithmName = "ssh-ed25519" + val publicKeyBlob = byteArrayOf(10, 20, 30, 40) + val serverHostKeyBlob = byteArrayOf(50, 60, 70, 80) + + val bytes = buildHostBoundSignatureData( + sessionId = sessionId, + username = username, + serviceName = "ssh-connection", + algorithmName = algorithmName, + publicKeyBlob = publicKeyBlob, + serverHostKeyBlob = serverHostKeyBlob, + ) + + val parsed = UserauthPublickeyHostboundSignatureData(ByteBufferKaitaiStream(bytes)) + parsed._read() + + assertArrayEquals(sessionId, parsed.sessionIdentifier().data()) + assertArrayEquals(username.toByteArray(Charsets.UTF_8), parsed.userName().data()) + assertArrayEquals("ssh-connection".toByteArray(Charsets.US_ASCII), parsed.serviceName().data()) + assertArrayEquals( + "publickey-hostbound-v00@openssh.com".toByteArray(Charsets.US_ASCII), + parsed.methodName().data(), + ) + assertArrayEquals(algorithmName.toByteArray(Charsets.US_ASCII), parsed.publicKeyAlgorithmName().data()) + assertArrayEquals(publicKeyBlob, parsed.publicKeyBlob().data()) + assertArrayEquals(serverHostKeyBlob, parsed.serverHostKey().data()) + } + + @Test + fun `host-bound differs from standard publickey signature data`() { + val sessionId = byteArrayOf(1, 2, 3) + val publicKeyBlob = byteArrayOf(4, 5, 6) + val serverHostKeyBlob = byteArrayOf(7, 8, 9) + + val hostBoundBytes = buildHostBoundSignatureData( + sessionId = sessionId, + username = "user", + serviceName = "ssh-connection", + algorithmName = "ssh-ed25519", + publicKeyBlob = publicKeyBlob, + serverHostKeyBlob = serverHostKeyBlob, + ) + + // Host-bound signature data must be longer than standard (adds server_host_key) + // and method name differs. Just verify it parses and has the right method name. + val parsed = UserauthPublickeyHostboundSignatureData(ByteBufferKaitaiStream(hostBoundBytes)) + parsed._read() + + assertEquals( + "publickey-hostbound-v00@openssh.com", + String(parsed.methodName().data(), Charsets.US_ASCII), + ) + assertArrayEquals(serverHostKeyBlob, parsed.serverHostKey().data()) + } +} From 9da4156c408b1163361f8739d800bb4c30e15b97 Mon Sep 17 00:00:00 2001 From: Kenny Root Date: Tue, 28 Apr 2026 22:04:02 -0700 Subject: [PATCH 2/2] feat: negotiate RSA signature algorithm from server-sig-algs (RFC 8308) Parse server-sig-algs from SSH_MSG_EXT_INFO and use it to select the best RSA signing algorithm (rsa-sha2-512 > rsa-sha2-256 > ssh-rsa) when the server advertises the extension. Falls back to the key's default when server-sig-algs is absent, preserving current behavior for non-EXT_INFO servers. --- .../connectbot/sshlib/client/SshConnection.kt | 57 ++++++++-- .../connectbot/sshlib/crypto/Algorithms.kt | 12 ++ .../sshlib/crypto/RsaSignatureAlgorithm.kt | 16 +-- .../sshlib/ExtInfoProcessingTest.kt | 46 ++++++++ .../client/PortForwardingIntegrationTest.kt | 3 +- .../sshlib/client/SshClientIntegrationTest.kt | 106 ++++++++++++++++++ .../client/sftp/SftpClientIntegrationTest.kt | 1 + .../sshlib/crypto/AlgorithmsTest.kt | 34 ++++++ .../test/resources/openssh-server/Dockerfile | 5 +- 9 files changed, 260 insertions(+), 20 deletions(-) diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt index 07b2c9a..af69be9 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt @@ -243,6 +243,7 @@ class SshConnection( private var agentProvider: AgentProvider? = null private var serverHostKeyBlob: ByteArray? = null private var serverAdvertisesHostBound: Boolean = false + private var serverSigAlgs: Set? = null /** * Helper to manage a pending asynchronous operation that waits for a server response. @@ -546,7 +547,11 @@ class SshConnection( val publicKeyBlob = SshPublicKeyEncoder.encode(privateKey.jcaKeyPair, privateKey.keyType) - val sigAlgorithmName = privateKey.signatureAlgorithm + val sigAlgorithmName = if (privateKey.keyType == "ssh-rsa") { + negotiateRsaAlgorithm() + } else { + privateKey.signatureAlgorithm + } val sigEntry = SignatureEntry.fromSshName(sigAlgorithmName) ?: throw SshException("Unknown signature algorithm: $sigAlgorithmName") @@ -662,6 +667,18 @@ class SshConnection( return data.toByteArray() } + private fun keyBlobAlgorithmName(publicKeyBlob: ByteArray): String? { + if (publicKeyBlob.size < 4) return null + val len = ((publicKeyBlob[0].toInt() and 0xFF) shl 24) or + ((publicKeyBlob[1].toInt() and 0xFF) shl 16) or + ((publicKeyBlob[2].toInt() and 0xFF) shl 8) or + (publicKeyBlob[3].toInt() and 0xFF) + if (len <= 0 || len > publicKeyBlob.size - 4) return null + return String(publicKeyBlob, 4, len, Charsets.US_ASCII) + } + + private fun negotiateRsaAlgorithm(): String = SignatureEntry.negotiateRsaAlgorithm(serverSigAlgs) + /** * Authenticate using the strategy-based [AuthHandler] flow. * @@ -756,10 +773,15 @@ class SshConnection( key: AuthPublicKey, channel: Channel, ): InternalAuthResult { + val effectiveAlgorithmName = if (keyBlobAlgorithmName(key.publicKeyBlob) == "ssh-rsa") { + negotiateRsaAlgorithm() + } else { + key.algorithmName + } sendAuthRequest(username, "publickey") { val pubkeyAuth = UserauthRequestPublickey().apply { setHasSignature(0) - setPublicKeyAlgorithmName(createAsciiString(key.algorithmName)) + setPublicKeyAlgorithmName(createAsciiString(effectiveAlgorithmName)) setPublicKeyBlob(createByteString(key.publicKeyBlob)) _check() } @@ -778,19 +800,26 @@ class SshConnection( val hostKeyBlob = serverHostKeyBlob val useHostBound = serverAdvertisesHostBound && hostKeyBlob != null + val effectiveAlgorithmName = if (keyBlobAlgorithmName(key.publicKeyBlob) == "ssh-rsa") { + negotiateRsaAlgorithm() + } else { + key.algorithmName + } + val signatureData = if (useHostBound && hostKeyBlob != null) { - buildHostBoundSignatureData(sid, username, "ssh-connection", key.algorithmName, key.publicKeyBlob, hostKeyBlob) + buildHostBoundSignatureData(sid, username, "ssh-connection", effectiveAlgorithmName, key.publicKeyBlob, hostKeyBlob) } else { - buildSignatureData(sid, username, "ssh-connection", key.algorithmName, key.publicKeyBlob) + buildSignatureData(sid, username, "ssh-connection", effectiveAlgorithmName, key.publicKeyBlob) } - val signature = handler.onSignatureRequest(key, signatureData) ?: return false + val signingKey = if (effectiveAlgorithmName != key.algorithmName) key.copy(algorithmName = effectiveAlgorithmName) else key + val signature = handler.onSignatureRequest(signingKey, signatureData) ?: return false if (useHostBound && hostKeyBlob != null) { sendAuthRequest(username, "publickey-hostbound-v00@openssh.com") { val pubkeyAuth = UserauthRequestPublickeyHostbound().apply { setHasSignature(1) - setPublicKeyAlgorithmName(createAsciiString(key.algorithmName)) + setPublicKeyAlgorithmName(createAsciiString(effectiveAlgorithmName)) setPublicKeyBlob(createByteString(key.publicKeyBlob)) setServerHostKey(createByteString(hostKeyBlob)) setSignature(createByteString(signature)) @@ -802,7 +831,7 @@ class SshConnection( sendAuthRequest(username, "publickey") { val pubkeyAuth = UserauthRequestPublickey().apply { setHasSignature(1) - setPublicKeyAlgorithmName(createAsciiString(key.algorithmName)) + setPublicKeyAlgorithmName(createAsciiString(effectiveAlgorithmName)) setPublicKeyBlob(createByteString(key.publicKeyBlob)) setSignature(createByteString(signature)) _check() @@ -2131,9 +2160,17 @@ class SshConnection( SshEnums.MessageType.SSH_MSG_EXT_INFO -> { val extInfo = parseBody(packet) for (ext in extInfo.extensions()) { - if (ext.extensionName().value() == "publickey-hostbound@openssh.com") { - serverAdvertisesHostBound = true - logger.info("Server advertises publickey-hostbound@openssh.com") + when (ext.extensionName().value()) { + "publickey-hostbound@openssh.com" -> { + serverAdvertisesHostBound = true + logger.info("Server advertises publickey-hostbound@openssh.com") + } + + "server-sig-algs" -> { + val algs = String(ext.extensionValue().data(), Charsets.UTF_8) + serverSigAlgs = algs.split(",").filter { it.isNotEmpty() }.toSet() + logger.info("Server advertises server-sig-algs: $algs") + } } } } diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt index 83f290f..c3d12b5 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/Algorithms.kt @@ -295,6 +295,18 @@ internal enum class SignatureEntry( val defaultString: String = defaults.joinToString(",") { it.sshName } fun fromSshName(name: String): SignatureEntry? = entries.firstOrNull { it.sshName == name } + + private val rsaPreferenceOrder = listOf("rsa-sha2-512", "rsa-sha2-256", "ssh-rsa") + + /** + * Picks the best RSA signing algorithm given the server's advertised list. + * Returns "ssh-rsa" if [serverSigAlgs] is null (server didn't send the extension) + * or if no supported RSA algorithms were advertised. + */ + fun negotiateRsaAlgorithm(serverSigAlgs: Set?): String { + if (serverSigAlgs == null) return "ssh-rsa" + return rsaPreferenceOrder.firstOrNull { it in serverSigAlgs } ?: "ssh-rsa" + } } } diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/RsaSignatureAlgorithm.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/RsaSignatureAlgorithm.kt index 3f2e35e..12e2cfe 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/RsaSignatureAlgorithm.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/crypto/RsaSignatureAlgorithm.kt @@ -33,10 +33,7 @@ internal object RsaSignatureAlgorithm : SshSignatureAlgorithm { val spec = RSAPublicKeySpec(n, e) val jcaKey = KeyFactory.getInstance("RSA").generatePublic(spec) - val jcaAlgorithm = when (sig.algorithmName()) { - "rsa-sha2-512" -> "SHA512withRSA" - else -> "SHA256withRSA" - } + val jcaAlgorithm = toJcaAlgorithm(sig.algorithmName()) val sigBlob = sig.signatureBlob() as SshRsaSignatureBlob val verifier = Signature.getInstance(jcaAlgorithm) @@ -46,10 +43,7 @@ internal object RsaSignatureAlgorithm : SshSignatureAlgorithm { } override fun sign(algorithmName: String, privateKey: java.security.PrivateKey, data: ByteArray): ByteArray { - val jcaAlgorithm = when (algorithmName) { - "rsa-sha2-512" -> "SHA512withRSA" - else -> "SHA256withRSA" - } + val jcaAlgorithm = toJcaAlgorithm(algorithmName) val signer = Signature.getInstance(jcaAlgorithm) signer.initSign(privateKey) @@ -59,4 +53,10 @@ internal object RsaSignatureAlgorithm : SshSignatureAlgorithm { return encodeSshString(algorithmName.toByteArray(Charsets.US_ASCII)) + encodeSshString(sigBytes) } + + private fun toJcaAlgorithm(sshAlgorithmName: String): String = when (sshAlgorithmName) { + "rsa-sha2-512" -> "SHA512withRSA" + "rsa-sha2-256" -> "SHA256withRSA" + else -> "SHA1withRSA" + } } diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt index 18e8262..a4c5e6d 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/ExtInfoProcessingTest.kt @@ -23,6 +23,7 @@ import org.connectbot.sshlib.protocol.createByteString import org.connectbot.sshlib.protocol.toByteArray import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNull import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test @@ -99,4 +100,49 @@ class ExtInfoProcessingTest { assertEquals(2, parsed.numExtensions().toInt()) assertEquals(2, parsed.extensions().size) } + + private fun parseServerSigAlgs(bytes: ByteArray): Set? { + val parsed = SshMsgExtInfo(ByteBufferKaitaiStream(bytes)) + parsed._read() + val ext = parsed.extensions().firstOrNull { it.extensionName().value() == "server-sig-algs" } + ?: return null + val value = String(ext.extensionValue().data(), Charsets.UTF_8) + return value.split(",").filter { it.isNotEmpty() }.toSet() + } + + @Test + fun `parses server-sig-algs with multiple algorithms`() { + val bytes = buildExtInfo( + mapOf("server-sig-algs" to "rsa-sha2-256,rsa-sha2-512,ssh-ed25519".toByteArray()), + ) + val algs = parseServerSigAlgs(bytes) + assertEquals(setOf("rsa-sha2-256", "rsa-sha2-512", "ssh-ed25519"), algs) + } + + @Test + fun `server-sig-algs absent returns null`() { + val bytes = buildExtInfo(mapOf("no-flow-control" to byteArrayOf())) + assertNull(parseServerSigAlgs(bytes)) + } + + @Test + fun `server-sig-algs with single algorithm`() { + val bytes = buildExtInfo( + mapOf("server-sig-algs" to "rsa-sha2-256".toByteArray()), + ) + val algs = parseServerSigAlgs(bytes) + assertEquals(setOf("rsa-sha2-256"), algs) + } + + @Test + fun `server-sig-algs with empty value yields empty set`() { + val bytes = buildExtInfo(mapOf("server-sig-algs" to byteArrayOf())) + assertEquals(emptySet(), parseServerSigAlgs(bytes)) + } + + @Test + fun `server-sig-algs with leading and trailing commas is handled`() { + val bytes = buildExtInfo(mapOf("server-sig-algs" to ",rsa-sha2-256,".toByteArray())) + assertEquals(setOf("rsa-sha2-256"), parseServerSigAlgs(bytes)) + } } diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PortForwardingIntegrationTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PortForwardingIntegrationTest.kt index 0e08e28..2c6483a 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PortForwardingIntegrationTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/PortForwardingIntegrationTest.kt @@ -63,7 +63,8 @@ class PortForwardingIntegrationTest { @JvmStatic val opensshContainer: GenericContainer<*> = GenericContainer( ImageFromDockerfile("openssh-server-fwd-test", false) - .withFileFromClasspath(".", "openssh-server"), + .withFileFromClasspath(".", "openssh-server") + .withFileFromClasspath("test_rsa.pub", "keys/rsa_unencrypted.pub"), ) .withExposedPorts(22) .withLogConsumer(logConsumer) diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt index e8d1475..9fa220d 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/SshClientIntegrationTest.kt @@ -33,6 +33,7 @@ import org.connectbot.sshlib.SshClientConfig import org.connectbot.sshlib.SshException import org.connectbot.sshlib.SshSigning import org.connectbot.sshlib.blocking.BlockingSshClient +import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.Assertions.assertNull @@ -82,6 +83,7 @@ class SshClientIntegrationTest { val opensshContainer: GenericContainer<*> = GenericContainer( ImageFromDockerfile(opensshImageName(), false) .withFileFromClasspath(".", "openssh-server") + .withFileFromClasspath("test_rsa.pub", "keys/rsa_unencrypted.pub") .withBuildArg("OPENSSH_VERSION", OPENSSH_VERSION) .withBuildArg("DEBUG_CFLAGS", DEBUG_CFLAGS), ) @@ -91,6 +93,22 @@ class SshClientIntegrationTest { Wait.forLogMessage(".*Server listening.*", 1), ) + @Container + @JvmStatic + val opensshRsaOnlyContainer: GenericContainer<*> = GenericContainer( + ImageFromDockerfile(opensshImageName() + "-rsa", false) + .withFileFromClasspath(".", "openssh-server") + .withFileFromClasspath("test_rsa.pub", "keys/rsa_unencrypted.pub") + .withBuildArg("OPENSSH_VERSION", OPENSSH_VERSION) + .withBuildArg("DEBUG_CFLAGS", DEBUG_CFLAGS), + ) + .withExposedPorts(22) + .withLogConsumer(logConsumer) + .withCommand("/usr/sbin/sshd", "-D", "-e", "-o", "HostKeyAlgorithms=ssh-rsa", "-o", "PubkeyAcceptedAlgorithms=ssh-rsa") + .waitingFor( + Wait.forLogMessage(".*Server listening.*", 1), + ) + @JvmStatic fun encryptionAlgorithms() = listOf( "aes128-gcm@openssh.com", @@ -411,6 +429,94 @@ class SshClientIntegrationTest { private fun readTestKey(): String = javaClass.getResourceAsStream("/openssh-server/test_ed25519")!! .bufferedReader().readText() + private fun readRsaTestKey(): String = javaClass.getResourceAsStream("/keys/rsa_unencrypted")!! + .bufferedReader().readText() + + @Test + fun `should connect to server only supporting ssh-rsa`() = runBlocking { + val host = opensshRsaOnlyContainer.host + val port = opensshRsaOnlyContainer.getMappedPort(22) + + val client = SshClient( + SshClientConfig { + this.host = host + this.port = port + this.hostKeyVerifier = acceptAllVerifier + // Force ssh-rsa only by excluding rsa-sha2 variants + this.hostKeyAlgorithms = "ssh-rsa" + // Use a KEX algorithm that doesn't include ext-info-c by default + this.kexAlgorithms = "diffie-hellman-group14-sha256" + }, + ) + + try { + assertTrue(client.connect() is ConnectResult.Success, "Should connect to SSH server") + + val keyData = readRsaTestKey() + val pubKey = SshSigning.getPublicKey("ssh-rsa", keyData, null) + + val handler = object : AuthHandler { + override suspend fun onPublicKeysNeeded(): List = listOf(pubKey) + override suspend fun onSignatureRequest(key: AuthPublicKey, dataToSign: ByteArray): ByteArray = SshSigning.sign(key.algorithmName, keyData, null, dataToSign) + override suspend fun onKeyboardInteractivePrompt( + name: String, + instruction: String, + prompts: List, + ): List? = null + override suspend fun onPasswordNeeded(): String? = null + } + + val result = withTimeout(10_000) { client.authenticate(USERNAME, handler) } + assertTrue(result is AuthResult.Success, "Should authenticate with ssh-rsa") + assertTrue(client.isAuthenticated, "Client should be authenticated") + } finally { + client.disconnect() + } + } + + @Test + fun `should connect with rsa-sha2-512 by default when server supports it`() = runBlocking { + val host = opensshContainer.host + val port = opensshContainer.getMappedPort(22) + + val client = SshClient( + SshClientConfig { + this.host = host + this.port = port + this.hostKeyVerifier = acceptAllVerifier + }, + ) + + try { + assertTrue(client.connect() is ConnectResult.Success, "Should connect to SSH server") + + val keyData = readRsaTestKey() + // Even if we provide the key as ssh-rsa, it should negotiate rsa-sha2-512 + val pubKey = SshSigning.getPublicKey("ssh-rsa", keyData, null) + + val handler = object : AuthHandler { + override suspend fun onPublicKeysNeeded(): List = listOf(pubKey) + override suspend fun onSignatureRequest(key: AuthPublicKey, dataToSign: ByteArray): ByteArray { + // Verify that the negotiated algorithm is rsa-sha2-512 + assertEquals("rsa-sha2-512", key.algorithmName) + return SshSigning.sign(key.algorithmName, keyData, null, dataToSign) + } + override suspend fun onKeyboardInteractivePrompt( + name: String, + instruction: String, + prompts: List, + ): List? = null + override suspend fun onPasswordNeeded(): String? = null + } + + val result = withTimeout(10_000) { client.authenticate(USERNAME, handler) } + assertTrue(result is AuthResult.Success, "Should authenticate with rsa-sha2-512") + assertTrue(client.isAuthenticated, "Client should be authenticated") + } finally { + client.disconnect() + } + } + @Test fun `auth handler should authenticate with password when preferPasswordAuth is set`() { val host = opensshContainer.host diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/sftp/SftpClientIntegrationTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/sftp/SftpClientIntegrationTest.kt index 1f397c5..abae37d 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/sftp/SftpClientIntegrationTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/sftp/SftpClientIntegrationTest.kt @@ -62,6 +62,7 @@ class SftpClientIntegrationTest { val opensshContainer: GenericContainer<*> = GenericContainer( ImageFromDockerfile("openssh-sftp-test", false) .withFileFromClasspath(".", "openssh-server") + .withFileFromClasspath("test_rsa.pub", "keys/rsa_unencrypted.pub") .withBuildArg("OPENSSH_VERSION", OPENSSH_VERSION) .withBuildArg("DEBUG_CFLAGS", DEBUG_CFLAGS), ) diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/crypto/AlgorithmsTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/crypto/AlgorithmsTest.kt index fab2463..2246fb7 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/crypto/AlgorithmsTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/crypto/AlgorithmsTest.kt @@ -331,4 +331,38 @@ class AlgorithmsTest { ) } } + + @Test + fun `negotiateRsaAlgorithm returns ssh-rsa when serverSigAlgs is null`() { + assertEquals("ssh-rsa", SignatureEntry.negotiateRsaAlgorithm(null)) + } + + @Test + fun `negotiateRsaAlgorithm prefers rsa-sha2-512 when server supports it`() { + val serverAlgs = setOf("rsa-sha2-256", "rsa-sha2-512", "ssh-ed25519") + assertEquals("rsa-sha2-512", SignatureEntry.negotiateRsaAlgorithm(serverAlgs)) + } + + @Test + fun `negotiateRsaAlgorithm falls back to rsa-sha2-256 when 512 not supported`() { + val serverAlgs = setOf("rsa-sha2-256", "ssh-ed25519") + assertEquals("rsa-sha2-256", SignatureEntry.negotiateRsaAlgorithm(serverAlgs)) + } + + @Test + fun `negotiateRsaAlgorithm falls back to ssh-rsa as last resort`() { + val serverAlgs = setOf("ssh-rsa", "ssh-ed25519") + assertEquals("ssh-rsa", SignatureEntry.negotiateRsaAlgorithm(serverAlgs)) + } + + @Test + fun `negotiateRsaAlgorithm returns ssh-rsa when no RSA algorithm in server list`() { + val serverAlgs = setOf("ssh-ed25519", "ecdsa-sha2-nistp256") + assertEquals("ssh-rsa", SignatureEntry.negotiateRsaAlgorithm(serverAlgs)) + } + + @Test + fun `negotiateRsaAlgorithm returns ssh-rsa when serverSigAlgs is empty`() { + assertEquals("ssh-rsa", SignatureEntry.negotiateRsaAlgorithm(emptySet())) + } } diff --git a/sshlib/src/test/resources/openssh-server/Dockerfile b/sshlib/src/test/resources/openssh-server/Dockerfile index b2b1f1e..19b4080 100644 --- a/sshlib/src/test/resources/openssh-server/Dockerfile +++ b/sshlib/src/test/resources/openssh-server/Dockerfile @@ -50,6 +50,7 @@ RUN mkdir -p /run/openssh && \ sed -i 's/#PasswordAuthentication.*/PasswordAuthentication yes/' /etc/ssh/sshd_config && \ echo "KexAlgorithms +diffie-hellman-group14-sha256,diffie-hellman-group14-sha1,diffie-hellman-group18-sha512,diffie-hellman-group16-sha512,diffie-hellman-group-exchange-sha256,diffie-hellman-group-exchange-sha1,diffie-hellman-group1-sha1" >> /etc/ssh/sshd_config && \ echo "HostKeyAlgorithms +ssh-rsa" >> /etc/ssh/sshd_config && \ + echo "PubkeyAcceptedAlgorithms +ssh-rsa" >> /etc/ssh/sshd_config && \ echo "Ciphers +aes128-ctr,aes256-ctr,aes128-cbc,aes256-cbc,3des-cbc" >> /etc/ssh/sshd_config && \ echo "MACs +hmac-sha2-256,hmac-sha2-512,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com" >> /etc/ssh/sshd_config && \ echo "KbdInteractiveAuthentication yes" >> /etc/ssh/sshd_config && \ @@ -60,12 +61,14 @@ RUN mkdir -p /run/openssh && \ # Set up public key authentication for test user COPY test_ed25519.pub /tmp/test_ed25519.pub +COPY test_rsa.pub /tmp/test_rsa.pub RUN mkdir -p /home/$USERNAME/.ssh && \ cat /tmp/test_ed25519.pub >> /home/$USERNAME/.ssh/authorized_keys && \ + cat /tmp/test_rsa.pub >> /home/$USERNAME/.ssh/authorized_keys && \ chown -R $USERNAME:$USERNAME /home/$USERNAME/.ssh && \ chmod 700 /home/$USERNAME/.ssh && \ chmod 600 /home/$USERNAME/.ssh/authorized_keys && \ - rm /tmp/test_ed25519.pub + rm /tmp/test_ed25519.pub /tmp/test_rsa.pub # Configure PAM for sshd RUN mkdir -p /etc/pam.d && \