Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,96 @@ internal fun interface SessionBindVerifier {
fun verify(hostKeyBlob: ByteArray, signature: ByteArray, data: ByteArray): Boolean
}

internal data class BindingEntry(val hostKeyBlob: ByteArray, val sessionId: ByteArray) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as BindingEntry
if (!hostKeyBlob.contentEquals(other.hostKeyBlob)) return false
if (!sessionId.contentEquals(other.sessionId)) return false
return true
}

override fun hashCode(): Int {
var result = hostKeyBlob.contentHashCode()
result = 31 * result + sessionId.contentHashCode()
return result
}
}

internal data class SignedDataComponents(
val methodName: String,
val destUsername: String,
val serverHostKeyBlob: ByteArray?,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as SignedDataComponents
if (methodName != other.methodName) return false
if (destUsername != other.destUsername) return false
if (!(serverHostKeyBlob == null && other.serverHostKeyBlob == null) &&
(
serverHostKeyBlob == null || other.serverHostKeyBlob == null ||
!serverHostKeyBlob.contentEquals(other.serverHostKeyBlob)
)
) {
return false
}
return true
}

override fun hashCode(): Int {
var result = methodName.hashCode()
result = 31 * result + destUsername.hashCode()
result = 31 * result + (serverHostKeyBlob?.contentHashCode() ?: 0)
return result
}
}

internal fun buildAgentMessage(messageType: Int, payload: ByteArray): ByteArray {
val totalLength = 1L + payload.size
val stream = ByteBufferKaitaiStream(4 + totalLength)
stream.writeU4be(totalLength)
stream.writeU1(messageType)
stream.writeBytes(payload)
stream.seek(0)
return stream.readBytesFull()
}

internal fun isConstraintSatisfied(
constraints: List<DestinationConstraint>,
components: SignedDataComponents,
bindingList: List<BindingEntry>,
): Boolean {
val isForwarding = bindingList.isNotEmpty()
val forwardingHopKey = if (bindingList.size >= 2) {
bindingList[bindingList.size - 2].hostKeyBlob
} else if (bindingList.size == 1) {
bindingList[0].hostKeyBlob
} else {
null
}

return constraints.any { c ->
val fromMatches = if (!isForwarding) {
c.fromHostname.isEmpty() && c.fromKeyspecs.isEmpty()
} else {
forwardingHopKey != null && c.fromKeyspecs.any { spec ->
spec.keyBlob.contentEquals(forwardingHopKey)
}
}
if (!fromMatches) return@any false

val usernameMatches = c.toUsername.isEmpty() || c.toUsername == components.destUsername
if (!usernameMatches) return@any false

val hostKeyMatches = components.serverHostKeyBlob != null &&
c.toHostspecs.any { spec -> spec.keyBlob.contentEquals(components.serverHostKeyBlob) }
hostKeyMatches
}
}

internal class AgentProtocolHandler(
private val provider: AgentProvider,
private val sessionInfo: AgentSessionInfo,
Expand All @@ -57,27 +147,9 @@ internal class AgentProtocolHandler(
const val SSH_AGENTC_EXTENSION: Int = 27
const val SSH_AGENT_SUCCESS: Int = 6

private const val METHOD_PUBLICKEY = "publickey"
private const val METHOD_PUBLICKEY_HOSTBOUND = "publickey-hostbound-v00@openssh.com"
}

private data class BindingEntry(val hostKeyBlob: ByteArray, val sessionId: ByteArray) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as BindingEntry
if (!hostKeyBlob.contentEquals(other.hostKeyBlob)) return false
if (!sessionId.contentEquals(other.sessionId)) return false
return true
}

override fun hashCode(): Int {
var result = hostKeyBlob.contentHashCode()
result = 31 * result + sessionId.contentHashCode()
return result
}
}

private val bindingList: MutableList<BindingEntry> = mutableListOf()

suspend fun handleRequest(requestBytes: ByteArray): ByteArray {
Expand Down Expand Up @@ -166,13 +238,11 @@ internal class AgentProtocolHandler(
return createFailureResponse()
}

// For direct connections with standard publickey auth (no embedded server host key),
// use the session's server host key as the implicit destination.
if (bindingList.isEmpty() && components.serverHostKeyBlob == null) {
components = components.copy(serverHostKeyBlob = sessionInfo.serverHostKey)
}

if (!isConstraintSatisfied(constraints, components)) {
if (!isConstraintSatisfied(constraints, components, bindingList)) {
logger.warn("Destination constraint not satisfied for key")
return createFailureResponse()
}
Expand Down Expand Up @@ -206,37 +276,6 @@ internal class AgentProtocolHandler(
}
}

private data class SignedDataComponents(
val methodName: String,
val destUsername: String,
val serverHostKeyBlob: ByteArray?,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as SignedDataComponents
if (methodName != other.methodName) return false
if (destUsername != other.destUsername) return false
if (!serverHostKeyBlob.contentEquals(other.serverHostKeyBlob)) return false
return true
}

override fun hashCode(): Int {
var result = methodName.hashCode()
result = 31 * result + destUsername.hashCode()
result = 31 * result + serverHostKeyBlob.contentHashCode()
return result
}
}

private fun ByteArray?.contentEquals(other: ByteArray?): Boolean {
if (this == null && other == null) return true
if (this == null || other == null) return false
return java.util.Arrays.equals(this, other)
}

private fun ByteArray?.contentHashCode(): Int = this?.contentHashCode() ?: 0

private fun parseSignedDataComponents(data: ByteArray): SignedDataComponents? = try {
val stream = ByteBufferKaitaiStream(data)
val sigData = UserauthPublickeySignatureDataAny(stream)
Expand All @@ -252,40 +291,6 @@ internal class AgentProtocolHandler(
null
}

private fun isConstraintSatisfied(
constraints: List<DestinationConstraint>,
components: SignedDataComponents,
): Boolean {
val isForwarding = bindingList.isNotEmpty()
// The forwarding hop key is the second-to-last binding (the host that relayed to us).
// The last binding is the destination's connection key, also represented in components.serverHostKeyBlob.
val forwardingHopKey = if (bindingList.size >= 2) {
bindingList[bindingList.size - 2].hostKeyBlob
} else if (bindingList.size == 1) {
bindingList[0].hostKeyBlob
} else {
null
}

return constraints.any { c ->
val fromMatches = if (!isForwarding) {
c.fromHostname.isEmpty() && c.fromKeyspecs.isEmpty()
} else {
forwardingHopKey != null && c.fromKeyspecs.any { spec ->
spec.keyBlob.contentEquals(forwardingHopKey)
}
}
if (!fromMatches) return@any false

val usernameMatches = c.toUsername.isEmpty() || c.toUsername == components.destUsername
if (!usernameMatches) return@any false

val hostKeyMatches = components.serverHostKeyBlob != null &&
c.toHostspecs.any { spec -> spec.keyBlob.contentEquals(components.serverHostKeyBlob) }
hostKeyMatches
}
}

private suspend fun handleExtension(message: SshAgentMessage): ByteArray {
logger.debug("Handling EXTENSION")

Expand All @@ -311,21 +316,18 @@ internal class AgentProtocolHandler(

val hostKeyBlob = bind.hostkey().data()
val sessionId = bind.sessionIdentifier().data()
val isForwarding = bind.isForwarding().toInt() != 0
val isForwarding = bind.isForwarding() != 0

// Replay protection: reject duplicate session IDs
if (bindingList.any { it.sessionId.contentEquals(sessionId) }) {
logger.warn("Session bind replay: session ID already recorded")
return createFailureResponse()
}

// For non-forwarding (origin) binds, verify the hostkey matches the connection's server key
if (!isForwarding && !hostKeyBlob.contentEquals(sessionInfo.serverHostKey)) {
logger.error("Session bind hostkey mismatch for non-forwarding bind")
return createFailureResponse()
}

// Cryptographically verify the session bind signature
if (!bindVerifier.verify(hostKeyBlob, bind.signature().data(), sessionId)) {
logger.error("Session bind signature verification failed")
return createFailureResponse()
Expand All @@ -343,22 +345,6 @@ internal class AgentProtocolHandler(
return payload
}

private fun buildAgentMessage(messageType: Int, payload: ByteArray): ByteArray {
val totalLength = 1 + payload.size
val buffer = ByteArray(4 + totalLength)

buffer[0] = ((totalLength shr 24) and 0xFF).toByte()
buffer[1] = ((totalLength shr 16) and 0xFF).toByte()
buffer[2] = ((totalLength shr 8) and 0xFF).toByte()
buffer[3] = (totalLength and 0xFF).toByte()

buffer[4] = messageType.toByte()

System.arraycopy(payload, 0, buffer, 5, payload.size)

return buffer
}

private fun createFailureResponse(): ByteArray = buildAgentMessage(SSH_AGENT_FAILURE, ByteArray(0))

private fun createSuccessResponse(): ByteArray = buildAgentMessage(SSH_AGENT_SUCCESS, ByteArray(0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,16 +706,6 @@ class SshConnection(
return data.toByteArray()
}

private fun keyBlobAlgorithmName(publicKeyBlob: ByteArray): String? {
if (publicKeyBlob.size < 4) return null
val len = ((publicKeyBlob[0].toInt() and 0xFF) shl 24) or
((publicKeyBlob[1].toInt() and 0xFF) shl 16) or
((publicKeyBlob[2].toInt() and 0xFF) shl 8) or
(publicKeyBlob[3].toInt() and 0xFF)
if (len <= 0 || len > publicKeyBlob.size - 4) return null
return String(publicKeyBlob, 4, len, Charsets.US_ASCII)
}

private fun negotiateRsaAlgorithm(): String = SignatureEntry.negotiateRsaAlgorithm(serverSigAlgs)

/**
Expand Down Expand Up @@ -2773,3 +2763,11 @@ internal fun selectPasswordMethods(
else -> emptyList()
}
}

internal fun keyBlobAlgorithmName(publicKeyBlob: ByteArray): String? {
if (publicKeyBlob.size < 4) return null
val stream = ByteBufferKaitaiStream(publicKeyBlob)
val len = stream.readU4be()
if (len <= 0 || len > publicKeyBlob.size - 4) return null
return String(stream.readBytes(len), Charsets.US_ASCII)
}
Loading
Loading