From 4f472363aeb1f9cdb6309ca7299b08549ea27c68 Mon Sep 17 00:00:00 2001 From: Kenny Root Date: Tue, 5 May 2026 06:13:22 -0700 Subject: [PATCH] chore(KtorTcpTransport): add way to get localSocketAddress This is needed for connection tracking in ConnectBot. If the local address for the socket is indicated in a network event, ConnectBot will either wait for it to come back or disconnect immediately. --- sshlib/api.txt | 1 + .../sshlib/transport/KtorTcpTransport.kt | 6 ++++ .../sshlib/transport/TransportDependencies.kt | 18 +++++++++++- .../sshlib/transport/KtorTcpTransportTest.kt | 28 ++++++++++++++++++- 4 files changed, 51 insertions(+), 2 deletions(-) diff --git a/sshlib/api.txt b/sshlib/api.txt index c0beca32..e9b8a769 100644 --- a/sshlib/api.txt +++ b/sshlib/api.txt @@ -791,6 +791,7 @@ package org.connectbot.sshlib.transport { ctor public KtorTcpTransport(java.lang.String host, optional int port, optional org.connectbot.sshlib.transport.IpVersion ipVersion); method public suspend java.lang.Object? close(kotlin.coroutines.Continuation); method public suspend java.lang.Object? connect(kotlin.coroutines.Continuation); + method public java.net.InetSocketAddress? getLocalAddress(); method @InaccessibleFromKotlin public boolean isConnected(); method public suspend java.lang.Object? read(int count, kotlin.coroutines.Continuation); method public suspend java.lang.Object? write(byte[] data, kotlin.coroutines.Continuation); diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/KtorTcpTransport.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/KtorTcpTransport.kt index 3cc397e7..f36522f0 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/KtorTcpTransport.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/KtorTcpTransport.kt @@ -69,6 +69,12 @@ class KtorTcpTransport internal constructor( override val isConnected: Boolean get() = socket?.isClosed == false + /** + * Local TCP address assigned to the connected socket, or `null` before + * connection, after close, or when the injected socket does not expose it. + */ + fun getLocalAddress(): java.net.InetSocketAddress? = (socket as? LocalAddressProvider)?.localAddress + /** * Connect to the remote host. * Must be called before any read/write operations. diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/TransportDependencies.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/TransportDependencies.kt index f1f1971f..6efb5006 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/TransportDependencies.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/transport/TransportDependencies.kt @@ -28,6 +28,7 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import java.io.Closeable import java.net.InetAddress +import java.net.InetSocketAddress as JavaInetSocketAddress /** * Resolves hostnames to IP addresses. @@ -45,6 +46,10 @@ interface TransportSocket : Closeable { fun openWriteChannel(autoFlush: Boolean = false): ByteWriteChannel } +internal interface LocalAddressProvider { + val localAddress: JavaInetSocketAddress? +} + /** * Creates TCP sockets. */ @@ -68,10 +73,15 @@ internal class KtorTcpSocketFactory( } } -private class KtorTransportSocket(private val socket: Socket) : TransportSocket { +private class KtorTransportSocket(private val socket: Socket) : + TransportSocket, + LocalAddressProvider { override val isClosed: Boolean get() = socket.isClosed + override val localAddress: JavaInetSocketAddress? + get() = (socket.localAddress as? InetSocketAddress)?.toJavaInetSocketAddress() + override fun openReadChannel(): ByteReadChannel = socket.openReadChannel() override fun openWriteChannel(autoFlush: Boolean): ByteWriteChannel = socket.openWriteChannel(autoFlush) @@ -80,3 +90,9 @@ private class KtorTransportSocket(private val socket: Socket) : TransportSocket socket.close() } } + +private fun InetSocketAddress.toJavaInetSocketAddress(): JavaInetSocketAddress = runCatching { + JavaInetSocketAddress(InetAddress.getByAddress(hostname, resolveAddress()), port) +}.getOrElse { + JavaInetSocketAddress.createUnresolved(hostname, port) +} diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/transport/KtorTcpTransportTest.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/transport/KtorTcpTransportTest.kt index bbd426c6..a532791e 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/transport/KtorTcpTransportTest.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/transport/KtorTcpTransportTest.kt @@ -24,8 +24,10 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import java.net.InetAddress +import java.net.InetSocketAddress import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertNull import kotlin.test.assertTrue @OptIn(ExperimentalCoroutinesApi::class) @@ -47,6 +49,27 @@ class KtorTcpTransportTest { assertEquals(addr, factory.connectionAttempts[0]) } + @Test + fun `local address is exposed for connected socket`() = runTest { + val addr = InetAddress.getByName("127.0.0.1") + val localAddress = InetSocketAddress(InetAddress.getByName("192.0.2.10"), 54321) + val socket = MockSocket(localAddress) + + val resolver = MockAddressResolver(mapOf("example.com" to listOf(addr))) + val factory = MockTcpSocketFactory(mapOf(addr to { socket })) + + val transport = KtorTcpTransport("example.com", 22, resolver, factory) + + assertNull(transport.getLocalAddress()) + transport.connect() + + assertEquals(localAddress, transport.getLocalAddress()) + + transport.close() + + assertNull(transport.getLocalAddress()) + } + @Test fun `happy eyeballs prefers IPv6`() = runTest { val ipv4 = InetAddress.getByName("127.0.0.1") @@ -266,7 +289,10 @@ class MockTcpSocketFactory( } } -class MockSocket : TransportSocket { +class MockSocket( + override val localAddress: InetSocketAddress? = null, +) : TransportSocket, + LocalAddressProvider { private var closed = false override fun close() {