Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sshlib/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ package org.connectbot.sshlib.transport {
ctor public KtorTcpTransport(java.lang.String host, optional int port, optional org.connectbot.sshlib.transport.IpVersion ipVersion);
method public suspend java.lang.Object? close(kotlin.coroutines.Continuation<? super kotlin.Unit>);
method public suspend java.lang.Object? connect(kotlin.coroutines.Continuation<? super kotlin.Unit>);
method public java.net.InetSocketAddress? getLocalAddress();
method @InaccessibleFromKotlin public boolean isConnected();
method public suspend java.lang.Object? read(int count, kotlin.coroutines.Continuation<? super byte[]>);
method public suspend java.lang.Object? write(byte[] data, kotlin.coroutines.Continuation<? super kotlin.Unit>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class KtorTcpTransport internal constructor(
override val isConnected: Boolean
get() = socket?.isClosed == false

/**
* Local TCP address assigned to the connected socket, or `null` before
* connection, after close, or when the injected socket does not expose it.
*/
fun getLocalAddress(): java.net.InetSocketAddress? = (socket as? LocalAddressProvider)?.localAddress

Comment thread
kruton marked this conversation as resolved.
/**
* Connect to the remote host.
* Must be called before any read/write operations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.Closeable
import java.net.InetAddress
import java.net.InetSocketAddress as JavaInetSocketAddress

/**
* Resolves hostnames to IP addresses.
Expand All @@ -45,6 +46,10 @@ interface TransportSocket : Closeable {
fun openWriteChannel(autoFlush: Boolean = false): ByteWriteChannel
}

internal interface LocalAddressProvider {
val localAddress: JavaInetSocketAddress?
}

/**
* Creates TCP sockets.
*/
Expand All @@ -68,10 +73,15 @@ internal class KtorTcpSocketFactory(
}
}

private class KtorTransportSocket(private val socket: Socket) : TransportSocket {
private class KtorTransportSocket(private val socket: Socket) :
TransportSocket,
LocalAddressProvider {
override val isClosed: Boolean
get() = socket.isClosed

override val localAddress: JavaInetSocketAddress?
get() = (socket.localAddress as? InetSocketAddress)?.toJavaInetSocketAddress()

override fun openReadChannel(): ByteReadChannel = socket.openReadChannel()

override fun openWriteChannel(autoFlush: Boolean): ByteWriteChannel = socket.openWriteChannel(autoFlush)
Expand All @@ -80,3 +90,9 @@ private class KtorTransportSocket(private val socket: Socket) : TransportSocket
socket.close()
}
}

private fun InetSocketAddress.toJavaInetSocketAddress(): JavaInetSocketAddress = runCatching {
JavaInetSocketAddress(InetAddress.getByAddress(hostname, resolveAddress()), port)
}.getOrElse {
JavaInetSocketAddress.createUnresolved(hostname, port)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import java.net.InetAddress
import java.net.InetSocketAddress
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNull
import kotlin.test.assertTrue

@OptIn(ExperimentalCoroutinesApi::class)
Expand All @@ -47,6 +49,27 @@ class KtorTcpTransportTest {
assertEquals(addr, factory.connectionAttempts[0])
}

@Test
fun `local address is exposed for connected socket`() = runTest {
val addr = InetAddress.getByName("127.0.0.1")
val localAddress = InetSocketAddress(InetAddress.getByName("192.0.2.10"), 54321)
val socket = MockSocket(localAddress)

val resolver = MockAddressResolver(mapOf("example.com" to listOf(addr)))
val factory = MockTcpSocketFactory(mapOf(addr to { socket }))

val transport = KtorTcpTransport("example.com", 22, resolver, factory)

assertNull(transport.getLocalAddress())
transport.connect()

assertEquals(localAddress, transport.getLocalAddress())

transport.close()

assertNull(transport.getLocalAddress())
}

@Test
fun `happy eyeballs prefers IPv6`() = runTest {
val ipv4 = InetAddress.getByName("127.0.0.1")
Expand Down Expand Up @@ -266,7 +289,10 @@ class MockTcpSocketFactory(
}
}

class MockSocket : TransportSocket {
class MockSocket(
override val localAddress: InetSocketAddress? = null,
) : TransportSocket,
LocalAddressProvider {
private var closed = false

override fun close() {
Expand Down
Loading